Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-10 17:19:34 +01:00)

Compare commits: db2ad11743 ... cd4955d4f3 (8 commits)
| SHA1 |
|---|
| cd4955d4f3 |
| c73984b213 |
| d8d2e5a2c2 |
| b056d7d28d |
| 95a884ea16 |
| b7ae526071 |
| cf6d0d1ccc |
| 709b6355c2 |
@@ -133,7 +133,7 @@ When you need to specify a tag, you typically use a structure with two fields:
**It defaults to `class`** if you omit it, which is common when defining the main target classes.
- `value`: The specific _value_ of the tag (e.g., `Myotis daubentonii`, `Good`, `Rain`).

**Example YAML Configuration using TagInfo (e.g., inside a filter rule):**
**Example YAML Configuration (e.g., inside a filter rule):**

```yaml
# ... inside a filtering configuration section ...
```
@@ -1,44 +1,47 @@
targets:
classes:
classes:
- name: myomys
tags:
- value: Myotis mystacinus
- name: pippip
tags:
- value: Pipistrellus pipistrellus
- name: eptser
tags:
- value: Eptesicus serotinus
- name: rhifer
tags:
- value: Rhinolophus ferrumequinum
roi:
name: anchor_bbox
anchor: top-left
generic_class:
detection_target:
name: bat
match_if:
name: all_of
conditions:
- name: has_tag
tag: { key: event, value: Echolocation }
- name: not
condition:
name: has_tag
tag: { key: class, value: Unknown }
assign_tags:
- key: class
value: Bat

filtering:
rules:
- match_type: all
tags:
- key: event
value: Echolocation
- match_type: exclude
tags:
- key: class
value: Unknown
classification_targets:
- name: myomys
tags:
- key: class
value: Myotis mystacinus
- name: pippip
tags:
- key: class
value: Pipistrellus pipistrellus
- name: eptser
tags:
- key: class
value: Eptesicus serotinus
- name: rhifer
tags:
- key: class
value: Rhinolophus ferrumequinum

roi:
name: anchor_bbox
anchor: top-left

preprocess:
audio:
samplerate: 256000
resample:
samplerate: 256000
enabled: True
method: "poly"
scale: false
center: true
duration: null

spectrogram:
stft:
@@ -48,66 +51,66 @@ preprocess:
frequencies:
max_freq: 120000
min_freq: 10000
pcen:
time_constant: 0.1
gain: 0.98
bias: 2
power: 0.5
scale: "amplitude"
size:
height: 128
resize_factor: 0.5
spectral_mean_substraction: true
peak_normalize: false
transforms:
- name: pcen
time_constant: 0.1
gain: 0.98
bias: 2
power: 0.5
- name: spectral_mean_substraction

postprocess:
nms_kernel_size: 9
detection_threshold: 0.01
min_freq: 10000
max_freq: 120000
top_k_per_sec: 200

labels:
sigma: 3

model:
input_height: 128
in_channels: 1
out_channels: 32
encoder:
layers:
- block_type: FreqCoordConvDown
- name: FreqCoordConvDown
out_channels: 32
- block_type: FreqCoordConvDown
- name: FreqCoordConvDown
out_channels: 64
- block_type: LayerGroup
- name: LayerGroup
layers:
- block_type: FreqCoordConvDown
- name: FreqCoordConvDown
out_channels: 128
- block_type: ConvBlock
- name: ConvBlock
out_channels: 256
bottleneck:
channels: 256
layers:
- block_type: SelfAttention
- name: SelfAttention
attention_channels: 256
decoder:
layers:
- block_type: FreqCoordConvUp
- name: FreqCoordConvUp
out_channels: 64
- block_type: FreqCoordConvUp
- name: FreqCoordConvUp
out_channels: 32
- block_type: LayerGroup
- name: LayerGroup
layers:
- block_type: FreqCoordConvUp
- name: FreqCoordConvUp
out_channels: 32
- block_type: ConvBlock
- name: ConvBlock
out_channels: 32

train:
learning_rate: 0.001
t_max: 100

labels:
sigma: 3

trainer:
max_epochs: 40

dataloaders:
train:
batch_size: 8
@@ -115,7 +118,7 @@ train:
shuffle: True

val:
batch_size: 8
batch_size: 1
num_workers: 2

loss:
@@ -134,32 +137,34 @@ train:

logger:
logger_type: csv
save_dir: outputs/log/
name: logs
# save_dir: outputs/log/
# name: logs

augmentations:
steps:
- augmentation_type: mix_audio
enabled: true
audio:
- name: mix_audio
probability: 0.2
min_weight: 0.3
max_weight: 0.7
- augmentation_type: add_echo
- name: add_echo
probability: 0.2
max_delay: 0.005
min_weight: 0.0
max_weight: 1.0
- augmentation_type: scale_volume
spectrogram:
- name: scale_volume
probability: 0.2
min_scaling: 0.0
max_scaling: 2.0
- augmentation_type: warp
- name: warp
probability: 0.2
delta: 0.04
- augmentation_type: mask_time
- name: mask_time
probability: 0.2
max_perc: 0.05
max_masks: 3
- augmentation_type: mask_freq
- name: mask_freq
probability: 0.2
max_perc: 0.10
max_masks: 3
justfile (12 changes)

@@ -92,19 +92,11 @@ clean-build:
clean: clean-build clean-pyc clean-test clean-docs

# Examples
# Preprocess example data.
example-preprocess OPTIONS="":
    batdetect2 preprocess \
        --base-dir . \
        --dataset-field datasets.train \
        --config example_data/config.yaml \
        {{OPTIONS}} \
        example_data/config.yaml example_data/preprocessed

# Train on example data.
example-train OPTIONS="":
    batdetect2 train \
        --val-dir example_data/preprocessed \
        --val-dataset example_data/dataset.yaml \
        --config example_data/config.yaml \
        {{OPTIONS}} \
        example_data/preprocessed
        example_data/dataset.yaml
@@ -17,7 +17,7 @@ dependencies = [
    "torch>=1.13.1,<2.5.0",
    "torchaudio>=1.13.1,<2.5.0",
    "torchvision>=0.14.0",
    "soundevent[audio,geometry,plot]>=2.8.1",
    "soundevent[audio,geometry,plot]>=2.9.1",
    "click>=8.1.7",
    "netcdf4>=1.6.5",
    "tqdm>=4.66.2",
@@ -1,7 +1,7 @@
from batdetect2.cli.base import cli
from batdetect2.cli.compat import detect
from batdetect2.cli.data import data
from batdetect2.cli.preprocess import preprocess
from batdetect2.cli.evaluate import evaluate_command
from batdetect2.cli.train import train_command

__all__ = [
@@ -9,7 +9,7 @@ __all__ = [
    "detect",
    "data",
    "train_command",
    "preprocess",
    "evaluate_command",
]
src/batdetect2/cli/evaluate.py (new file, 63 lines)

@@ -0,0 +1,63 @@
import sys
from pathlib import Path
from typing import Optional

import click
from loguru import logger

from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate.evaluate import evaluate
from batdetect2.train.lightning import load_model_from_checkpoint

__all__ = ["evaluate_command"]


@cli.command(name="evaluate")
@click.argument("model-path", type=click.Path(exists=True))
@click.argument("test_dataset", type=click.Path(exists=True))
@click.option("--output-dir", type=click.Path())
@click.option("--workers", type=int)
@click.option(
    "-v",
    "--verbose",
    count=True,
    help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def evaluate_command(
    model_path: Path,
    test_dataset: Path,
    output_dir: Optional[Path] = None,
    workers: Optional[int] = None,
    verbose: int = 0,
):
    logger.remove()
    if verbose == 0:
        log_level = "WARNING"
    elif verbose == 1:
        log_level = "INFO"
    else:
        log_level = "DEBUG"
    logger.add(sys.stderr, level=log_level)

    logger.info("Initiating evaluation process...")

    test_annotations = load_dataset_from_config(test_dataset)
    logger.debug(
        "Loaded {num_annotations} test examples",
        num_annotations=len(test_annotations),
    )

    model, train_config = load_model_from_checkpoint(model_path)

    df, results = evaluate(
        model,
        test_annotations,
        config=train_config,
        num_workers=workers,
    )

    print(results)

    if output_dir:
        df.to_csv(output_dir / "results.csv")
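For reference, a minimal sketch of exercising this new command programmatically with Click's test runner. The checkpoint and dataset paths below are placeholders, not files shipped with the repository:

```python
from click.testing import CliRunner

from batdetect2.cli.evaluate import evaluate_command

# Placeholder paths; point these at a real checkpoint and a test dataset config.
runner = CliRunner()
result = runner.invoke(
    evaluate_command,
    ["model.ckpt", "test_dataset.yaml", "--output-dir", "outputs/eval", "-v"],
)
print(result.output)
```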
@ -1,142 +0,0 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import yaml
|
||||
from loguru import logger
|
||||
|
||||
from batdetect2.cli.base import cli
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.train.preprocess import (
|
||||
TrainPreprocessConfig,
|
||||
load_train_preprocessing_config,
|
||||
preprocess_dataset,
|
||||
)
|
||||
|
||||
__all__ = ["preprocess"]
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument(
|
||||
"dataset_config",
|
||||
type=click.Path(exists=True),
|
||||
)
|
||||
@click.argument(
|
||||
"output",
|
||||
type=click.Path(),
|
||||
)
|
||||
@click.option(
|
||||
"--dataset-field",
|
||||
type=str,
|
||||
help=(
|
||||
"Specifies the key to access the dataset information within the "
|
||||
"dataset configuration file, if the information is nested inside a "
|
||||
"dictionary. If the dataset information is at the top level of the "
|
||||
"config file, you don't need to specify this."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--base-dir",
|
||||
type=click.Path(exists=True),
|
||||
help=(
|
||||
"The main directory where your audio recordings and annotation "
|
||||
"files are stored. This helps the program find your data, "
|
||||
"especially if the paths in your dataset configuration file "
|
||||
"are relative."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--config",
|
||||
type=click.Path(exists=True),
|
||||
help=(
|
||||
"Path to the configuration file. This file tells "
|
||||
"the program how to prepare your audio data before training, such "
|
||||
"as resampling or applying filters."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--config-field",
|
||||
type=str,
|
||||
help=(
|
||||
"If the preprocessing settings are inside a nested dictionary "
|
||||
"within the preprocessing configuration file, specify the key "
|
||||
"here to access them. If the preprocessing settings are at the "
|
||||
"top level, you don't need to specify this."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
help=(
|
||||
"The maximum number of computer cores to use when processing "
|
||||
"your audio data. Using more cores can speed up the preprocessing, "
|
||||
"but don't use more than your computer has available. By default, "
|
||||
"the program will use all available cores."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"-v",
|
||||
"--verbose",
|
||||
count=True,
|
||||
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
||||
)
|
||||
def preprocess(
|
||||
dataset_config: Path,
|
||||
output: Path,
|
||||
base_dir: Optional[Path] = None,
|
||||
config: Optional[Path] = None,
|
||||
config_field: Optional[str] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
dataset_field: Optional[str] = None,
|
||||
verbose: int = 0,
|
||||
):
|
||||
logger.remove()
|
||||
if verbose == 0:
|
||||
log_level = "WARNING"
|
||||
elif verbose == 1:
|
||||
log_level = "INFO"
|
||||
else:
|
||||
log_level = "DEBUG"
|
||||
logger.add(sys.stderr, level=log_level)
|
||||
|
||||
logger.info("Starting preprocessing.")
|
||||
|
||||
output = Path(output)
|
||||
logger.info("Will save outputs to {output}", output=output)
|
||||
|
||||
base_dir = base_dir or Path.cwd()
|
||||
logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
|
||||
|
||||
if config:
|
||||
logger.info(
|
||||
"Loading preprocessing config from: {config}", config=config
|
||||
)
|
||||
|
||||
conf = (
|
||||
load_train_preprocessing_config(config, field=config_field)
|
||||
if config is not None
|
||||
else TrainPreprocessConfig()
|
||||
)
|
||||
logger.debug(
|
||||
"Preprocessing config:\n{conf}",
|
||||
conf=yaml.dump(conf.model_dump()),
|
||||
)
|
||||
|
||||
dataset = load_dataset_from_config(
|
||||
dataset_config,
|
||||
field=dataset_field,
|
||||
base_dir=base_dir,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Loaded {num_examples} annotated clips from the configured dataset",
|
||||
num_examples=len(dataset),
|
||||
)
|
||||
|
||||
preprocess_dataset(
|
||||
dataset,
|
||||
conf,
|
||||
output=output,
|
||||
max_workers=num_workers,
|
||||
)
|
||||
@@ -20,6 +20,8 @@ __all__ = ["train_command"]
@click.argument("train_dataset", type=click.Path(exists=True))
@click.option("--val-dataset", type=click.Path(exists=True))
@click.option("--model-path", type=click.Path(exists=True))
@click.option("--ckpt-dir", type=click.Path(exists=True))
@click.option("--log-dir", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True))
@click.option("--config-field", type=str)
@click.option("--train-workers", type=int)
@@ -34,6 +36,8 @@ def train_command(
    train_dataset: Path,
    val_dataset: Optional[Path] = None,
    model_path: Optional[Path] = None,
    ckpt_dir: Optional[Path] = None,
    log_dir: Optional[Path] = None,
    config: Optional[Path] = None,
    config_field: Optional[str] = None,
    train_workers: int = 0,
@@ -83,4 +87,6 @@ def train_command(
        model_path=model_path,
        train_workers=train_workers,
        val_workers=val_workers,
        log_dir=log_dir,
        checkpoint_dir=ckpt_dir,
    )
src/batdetect2/data/_core.py (new file, 61 lines)

@@ -0,0 +1,61 @@
from typing import Generic, Protocol, Type, TypeVar

from pydantic import BaseModel

__all__ = [
    "Registry",
]

T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
T_Type = TypeVar("T_Type", covariant=True)


class LogicProtocol(Generic[T_Config, T_Type], Protocol):
    """A generic protocol for the logic classes (conditions or transforms)."""

    @classmethod
    def from_config(cls, config: T_Config) -> T_Type: ...


T_Proto = TypeVar("T_Proto", bound=LogicProtocol)


class Registry(Generic[T_Type]):
    """A generic class to create and manage a registry of items."""

    def __init__(self, name: str):
        self._name = name
        self._registry = {}

    def register(self, config_cls: Type[T_Config]):
        """A decorator factory to register a new item."""
        fields = config_cls.model_fields

        if "name" not in fields:
            raise ValueError("Configuration object must have a 'name' field.")

        name = fields["name"].default

        if not isinstance(name, str):
            raise ValueError("'name' field must be a string literal.")

        def decorator(logic_cls: Type[T_Proto]) -> Type[T_Proto]:
            self._registry[name] = logic_cls
            return logic_cls

        return decorator

    def build(self, config: BaseModel) -> T_Type:
        """Builds a logic instance from a config object."""

        name = getattr(config, "name")  # noqa: B009

        if name is None:
            raise ValueError("Config does not have a name field")

        if name not in self._registry:
            raise NotImplementedError(
                f"No {self._name} with name '{name}' is registered."
            )

        return self._registry[name].from_config(config)
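To illustrate how this registry is meant to be used, here is a minimal sketch; `GreeterConfig` and `Greeter` are made-up names for illustration, not part of the codebase:

```python
from typing import Literal

from pydantic import BaseModel

from batdetect2.data._core import Registry

_greeters: Registry["Greeter"] = Registry("greeter")


class GreeterConfig(BaseModel):
    # The literal default of `name` is what the registry keys on.
    name: Literal["greeter"] = "greeter"
    greeting: str = "hello"


@_greeters.register(GreeterConfig)
class Greeter:
    def __init__(self, greeting: str):
        self.greeting = greeting

    @classmethod
    def from_config(cls, config: GreeterConfig) -> "Greeter":
        return cls(greeting=config.greeting)


# `build` dispatches on the config's `name` field and calls `from_config`.
greeter = _greeters.build(GreeterConfig(greeting="hi"))
```

This is the same pattern the new `conditions.py` and `transforms.py` modules follow: a config class carrying a literal `name`, and a logic class exposing `from_config`.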
@@ -14,8 +14,9 @@ format-specific loading function to retrieve the annotations as a standard
"""

from pathlib import Path
from typing import Optional, Union
from typing import Annotated, Optional, Union

from pydantic import Field
from soundevent import data

from batdetect2.data.annotations.aoef import (
@@ -42,10 +43,13 @@ __all__ = [
]

AnnotationFormats = Union[
    BatDetect2MergedAnnotations,
    BatDetect2FilesAnnotations,
    AOEFAnnotations,
AnnotationFormats = Annotated[
    Union[
        BatDetect2MergedAnnotations,
        BatDetect2FilesAnnotations,
        AOEFAnnotations,
    ],
    Field(discriminator="format"),
]
"""Type Alias representing all supported data source configurations.
@ -8,8 +8,6 @@ from typing import Callable, List, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.targets import get_term_from_key
|
||||
|
||||
PathLike = Union[Path, str, os.PathLike]
|
||||
|
||||
__all__ = []
|
||||
@ -92,15 +90,15 @@ def annotation_to_sound_event(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=get_term_from_key(label_key),
|
||||
key=label_key, # type: ignore
|
||||
value=annotation.label,
|
||||
),
|
||||
data.Tag(
|
||||
term=get_term_from_key(event_key),
|
||||
key=event_key, # type: ignore
|
||||
value=annotation.event,
|
||||
),
|
||||
data.Tag(
|
||||
term=get_term_from_key(individual_key),
|
||||
key=individual_key, # type: ignore
|
||||
value=str(annotation.individual),
|
||||
),
|
||||
],
|
||||
@ -125,7 +123,7 @@ def file_annotation_to_clip(
|
||||
time_expansion=file_annotation.time_exp,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=get_term_from_key(label_key),
|
||||
key=label_key, # type: ignore
|
||||
value=file_annotation.label,
|
||||
)
|
||||
],
|
||||
@ -157,7 +155,8 @@ def file_annotation_to_clip_annotation(
|
||||
notes=notes,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=get_term_from_key(label_key), value=file_annotation.label
|
||||
key=label_key, # type: ignore
|
||||
value=file_annotation.label,
|
||||
)
|
||||
],
|
||||
sound_events=[
|
||||
|
||||
287
src/batdetect2/data/conditions.py
Normal file
287
src/batdetect2/data/conditions.py
Normal file
@ -0,0 +1,287 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Annotated, List, Literal, Sequence, Union
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.data._core import Registry
|
||||
|
||||
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
|
||||
|
||||
_conditions: Registry[SoundEventCondition] = Registry("condition")
|
||||
|
||||
|
||||
class HasTagConfig(BaseConfig):
|
||||
name: Literal["has_tag"] = "has_tag"
|
||||
tag: data.Tag
|
||||
|
||||
|
||||
@_conditions.register(HasTagConfig)
|
||||
class HasTag:
|
||||
def __init__(self, tag: data.Tag):
|
||||
self.tag = tag
|
||||
|
||||
def __call__(
|
||||
self, sound_event_annotation: data.SoundEventAnnotation
|
||||
) -> bool:
|
||||
return self.tag in sound_event_annotation.tags
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: HasTagConfig):
|
||||
return cls(tag=config.tag)
|
||||
|
||||
|
||||
class HasAllTagsConfig(BaseConfig):
|
||||
name: Literal["has_all_tags"] = "has_all_tags"
|
||||
tags: List[data.Tag]
|
||||
|
||||
|
||||
@_conditions.register(HasAllTagsConfig)
|
||||
class HasAllTags:
|
||||
def __init__(self, tags: List[data.Tag]):
|
||||
if not tags:
|
||||
raise ValueError("Need to specify at least one tag")
|
||||
|
||||
self.tags = set(tags)
|
||||
|
||||
def __call__(
|
||||
self, sound_event_annotation: data.SoundEventAnnotation
|
||||
) -> bool:
|
||||
return self.tags.issubset(sound_event_annotation.tags)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: HasAllTagsConfig):
|
||||
return cls(tags=config.tags)
|
||||
|
||||
|
||||
class HasAnyTagConfig(BaseConfig):
|
||||
name: Literal["has_any_tag"] = "has_any_tag"
|
||||
tags: List[data.Tag]
|
||||
|
||||
|
||||
@_conditions.register(HasAnyTagConfig)
|
||||
class HasAnyTag:
|
||||
def __init__(self, tags: List[data.Tag]):
|
||||
if not tags:
|
||||
raise ValueError("Need to specify at least one tag")
|
||||
|
||||
self.tags = set(tags)
|
||||
|
||||
def __call__(
|
||||
self, sound_event_annotation: data.SoundEventAnnotation
|
||||
) -> bool:
|
||||
return bool(self.tags.intersection(sound_event_annotation.tags))
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: HasAnyTagConfig):
|
||||
return cls(tags=config.tags)
|
||||
|
||||
|
||||
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
|
||||
|
||||
|
||||
class DurationConfig(BaseConfig):
|
||||
name: Literal["duration"] = "duration"
|
||||
operator: Operator
|
||||
seconds: float
|
||||
|
||||
|
||||
def _build_comparator(
|
||||
operator: Operator, value: float
|
||||
) -> Callable[[float], bool]:
|
||||
if operator == "gt":
|
||||
return lambda x: x > value
|
||||
|
||||
if operator == "gte":
|
||||
return lambda x: x >= value
|
||||
|
||||
if operator == "lt":
|
||||
return lambda x: x < value
|
||||
|
||||
if operator == "lte":
|
||||
return lambda x: x <= value
|
||||
|
||||
if operator == "eq":
|
||||
return lambda x: x == value
|
||||
|
||||
raise ValueError(f"Invalid operator {operator}")
|
||||
|
||||
|
||||
@_conditions.register(DurationConfig)
|
||||
class Duration:
|
||||
def __init__(self, operator: Operator, seconds: float):
|
||||
self.operator = operator
|
||||
self.seconds = seconds
|
||||
self._comparator = _build_comparator(self.operator, self.seconds)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
) -> bool:
|
||||
geometry = sound_event_annotation.sound_event.geometry
|
||||
|
||||
if geometry is None:
|
||||
return False
|
||||
|
||||
start_time, _, end_time, _ = compute_bounds(geometry)
|
||||
duration = end_time - start_time
|
||||
|
||||
return self._comparator(duration)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: DurationConfig):
|
||||
return cls(operator=config.operator, seconds=config.seconds)
|
||||
|
||||
|
||||
class FrequencyConfig(BaseConfig):
|
||||
name: Literal["frequency"] = "frequency"
|
||||
boundary: Literal["low", "high"]
|
||||
operator: Operator
|
||||
hertz: float
|
||||
|
||||
|
||||
@_conditions.register(FrequencyConfig)
|
||||
class Frequency:
|
||||
def __init__(
|
||||
self,
|
||||
operator: Operator,
|
||||
boundary: Literal["low", "high"],
|
||||
hertz: float,
|
||||
):
|
||||
self.operator = operator
|
||||
self.hertz = hertz
|
||||
self.boundary = boundary
|
||||
self._comparator = _build_comparator(self.operator, self.hertz)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
) -> bool:
|
||||
geometry = sound_event_annotation.sound_event.geometry
|
||||
|
||||
if geometry is None:
|
||||
return False
|
||||
|
||||
# Automatically false if geometry does not have a frequency range
|
||||
if isinstance(geometry, (data.TimeInterval, data.TimeStamp)):
|
||||
return False
|
||||
|
||||
_, low_freq, _, high_freq = compute_bounds(geometry)
|
||||
|
||||
if self.boundary == "low":
|
||||
return self._comparator(low_freq)
|
||||
|
||||
return self._comparator(high_freq)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: FrequencyConfig):
|
||||
return cls(
|
||||
operator=config.operator,
|
||||
boundary=config.boundary,
|
||||
hertz=config.hertz,
|
||||
)
|
||||
|
||||
|
||||
class AllOfConfig(BaseConfig):
|
||||
name: Literal["all_of"] = "all_of"
|
||||
conditions: Sequence["SoundEventConditionConfig"]
|
||||
|
||||
|
||||
@_conditions.register(AllOfConfig)
|
||||
class AllOf:
|
||||
def __init__(self, conditions: List[SoundEventCondition]):
|
||||
self.conditions = conditions
|
||||
|
||||
def __call__(
|
||||
self, sound_event_annotation: data.SoundEventAnnotation
|
||||
) -> bool:
|
||||
return all(c(sound_event_annotation) for c in self.conditions)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: AllOfConfig):
|
||||
conditions = [
|
||||
build_sound_event_condition(cond) for cond in config.conditions
|
||||
]
|
||||
return cls(conditions)
|
||||
|
||||
|
||||
class AnyOfConfig(BaseConfig):
|
||||
name: Literal["any_of"] = "any_of"
|
||||
conditions: List["SoundEventConditionConfig"]
|
||||
|
||||
|
||||
@_conditions.register(AnyOfConfig)
|
||||
class AnyOf:
|
||||
def __init__(self, conditions: List[SoundEventCondition]):
|
||||
self.conditions = conditions
|
||||
|
||||
def __call__(
|
||||
self, sound_event_annotation: data.SoundEventAnnotation
|
||||
) -> bool:
|
||||
return any(c(sound_event_annotation) for c in self.conditions)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: AnyOfConfig):
|
||||
conditions = [
|
||||
build_sound_event_condition(cond) for cond in config.conditions
|
||||
]
|
||||
return cls(conditions)
|
||||
|
||||
|
||||
class NotConfig(BaseConfig):
|
||||
name: Literal["not"] = "not"
|
||||
condition: "SoundEventConditionConfig"
|
||||
|
||||
|
||||
@_conditions.register(NotConfig)
|
||||
class Not:
|
||||
def __init__(self, condition: SoundEventCondition):
|
||||
self.condition = condition
|
||||
|
||||
def __call__(
|
||||
self, sound_event_annotation: data.SoundEventAnnotation
|
||||
) -> bool:
|
||||
return not self.condition(sound_event_annotation)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: NotConfig):
|
||||
condition = build_sound_event_condition(config.condition)
|
||||
return cls(condition)
|
||||
|
||||
|
||||
SoundEventConditionConfig = Annotated[
|
||||
Union[
|
||||
HasTagConfig,
|
||||
HasAllTagsConfig,
|
||||
HasAnyTagConfig,
|
||||
DurationConfig,
|
||||
FrequencyConfig,
|
||||
AllOfConfig,
|
||||
AnyOfConfig,
|
||||
NotConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_sound_event_condition(
|
||||
config: SoundEventConditionConfig,
|
||||
) -> SoundEventCondition:
|
||||
return _conditions.build(config)
|
||||
|
||||
|
||||
def filter_clip_annotation(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
condition: SoundEventCondition,
|
||||
) -> data.ClipAnnotation:
|
||||
return clip_annotation.model_copy(
|
||||
update=dict(
|
||||
sound_events=[
|
||||
sound_event
|
||||
for sound_event in clip_annotation.sound_events
|
||||
if condition(sound_event)
|
||||
]
|
||||
)
|
||||
)
|
||||
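As a quick illustration of how these condition configs are meant to be combined and applied, here is a sketch based on the classes above. The tag values are examples only, and it assumes `data.Tag` accepts key/value construction as it does in the YAML configs earlier in this diff:

```python
from soundevent import data

from batdetect2.data.conditions import (
    AllOfConfig,
    HasTagConfig,
    NotConfig,
    build_sound_event_condition,
)

# "Echolocation event that is not tagged as Unknown" — mirrors the
# `match_if` rule in the example config above.
config = AllOfConfig(
    conditions=[
        HasTagConfig(tag=data.Tag(key="event", value="Echolocation")),
        NotConfig(
            condition=HasTagConfig(tag=data.Tag(key="class", value="Unknown"))
        ),
    ]
)

condition = build_sound_event_condition(config)
# `condition` is a callable: condition(sound_event_annotation) -> bool
```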
@ -19,7 +19,7 @@ The core components are:
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Annotated, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
@ -31,6 +31,17 @@ from batdetect2.data.annotations import (
|
||||
AnnotationFormats,
|
||||
load_annotated_dataset,
|
||||
)
|
||||
from batdetect2.data.conditions import (
|
||||
SoundEventConditionConfig,
|
||||
build_sound_event_condition,
|
||||
filter_clip_annotation,
|
||||
)
|
||||
from batdetect2.data.transforms import (
|
||||
ApplyAll,
|
||||
SoundEventTransformConfig,
|
||||
build_sound_event_transform,
|
||||
transform_clip_annotation,
|
||||
)
|
||||
from batdetect2.targets.terms import data_source
|
||||
|
||||
__all__ = [
|
||||
@ -52,79 +63,68 @@ sources.
|
||||
|
||||
|
||||
class DatasetConfig(BaseConfig):
|
||||
"""Configuration model defining the structure of a BatDetect2 dataset.
|
||||
|
||||
This class is typically loaded from a YAML file and describes the components
|
||||
of the dataset, including metadata and a list of data sources.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name : str
|
||||
A descriptive name for the dataset (e.g., "UK_Bats_Project_2024").
|
||||
description : str
|
||||
A longer description of the dataset's contents, origin, purpose, etc.
|
||||
sources : List[AnnotationFormats]
|
||||
A list defining the different data sources contributing to this
|
||||
dataset. Each item in the list must conform to one of the Pydantic
|
||||
models defined in the `AnnotationFormats` type union. The specific
|
||||
model used for each source is determined by the mandatory `format`
|
||||
field within the source's configuration, allowing BatDetect2 to use the
|
||||
correct parser for different annotation styles.
|
||||
"""
|
||||
"""Configuration model defining the structure of a BatDetect2 dataset."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
sources: List[
|
||||
Annotated[AnnotationFormats, Field(..., discriminator="format")]
|
||||
]
|
||||
sources: List[AnnotationFormats]
|
||||
|
||||
sound_event_filter: Optional[SoundEventConditionConfig] = None
|
||||
sound_event_transforms: List[SoundEventTransformConfig] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
|
||||
def load_dataset(
|
||||
dataset: DatasetConfig,
|
||||
config: DatasetConfig,
|
||||
base_dir: Optional[Path] = None,
|
||||
) -> Dataset:
|
||||
"""Load all clip annotations from the sources defined in a DatasetConfig.
|
||||
|
||||
Iterates through each data source specified in the `dataset_config`,
|
||||
delegates the loading and parsing of that source's annotations to
|
||||
`batdetect2.data.annotations.load_annotated_dataset` (which handles
|
||||
different data formats), and aggregates all resulting `ClipAnnotation`
|
||||
objects into a single flat list.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataset_config : DatasetConfig
|
||||
The configuration object describing the dataset and its sources.
|
||||
base_dir : Path, optional
|
||||
An optional base directory path. If provided, relative paths for
|
||||
metadata files or data directories within the `dataset_config`'s
|
||||
sources might be resolved relative to this directory. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dataset (List[data.ClipAnnotation])
|
||||
A flat list containing all loaded `ClipAnnotation` metadata objects
|
||||
from all specified sources.
|
||||
|
||||
Raises
|
||||
------
|
||||
Exception
|
||||
Can raise various exceptions during the delegated loading process
|
||||
(`load_annotated_dataset`) if files are not found, cannot be parsed
|
||||
according to the specified format, or other I/O errors occur.
|
||||
"""
|
||||
"""Load all clip annotations from the sources defined in a DatasetConfig."""
|
||||
clip_annotations = []
|
||||
for source in dataset.sources:
|
||||
|
||||
condition = (
|
||||
build_sound_event_condition(config.sound_event_filter)
|
||||
if config.sound_event_filter is not None
|
||||
else None
|
||||
)
|
||||
|
||||
transform = (
|
||||
ApplyAll(
|
||||
[
|
||||
build_sound_event_transform(step)
|
||||
for step in config.sound_event_transforms
|
||||
]
|
||||
)
|
||||
if config.sound_event_transforms
|
||||
else None
|
||||
)
|
||||
|
||||
for source in config.sources:
|
||||
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
|
||||
|
||||
logger.debug(
|
||||
"Loaded {num_examples} from dataset source '{source_name}'",
|
||||
num_examples=len(annotated_source.clip_annotations),
|
||||
source_name=source.name,
|
||||
)
|
||||
clip_annotations.extend(
|
||||
insert_source_tag(clip_annotation, source)
|
||||
for clip_annotation in annotated_source.clip_annotations
|
||||
)
|
||||
|
||||
for clip_annotation in annotated_source.clip_annotations:
|
||||
clip_annotation = insert_source_tag(clip_annotation, source)
|
||||
|
||||
if condition is not None:
|
||||
clip_annotation = filter_clip_annotation(
|
||||
clip_annotation,
|
||||
condition,
|
||||
)
|
||||
|
||||
if transform is not None:
|
||||
clip_annotation = transform_clip_annotation(
|
||||
clip_annotation,
|
||||
transform,
|
||||
)
|
||||
|
||||
clip_annotations.append(clip_annotation)
|
||||
|
||||
return clip_annotations
|
||||
|
||||
|
||||
@ -161,7 +161,6 @@ def insert_source_tag(
|
||||
)
|
||||
|
||||
|
||||
# TODO: add documentation
|
||||
def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
|
||||
return load_config(path=path, schema=DatasetConfig, field=field)
|
||||
|
||||
|
||||
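A minimal sketch of loading a dataset described by a `DatasetConfig` YAML file, using `load_dataset_from_config` as the CLI code in this diff does. The dataset path below is the example file referenced in the justfile:

```python
from pathlib import Path

from batdetect2.data import load_dataset_from_config

# Loads and aggregates ClipAnnotation objects from all configured sources,
# applying the optional sound_event_filter and sound_event_transforms.
dataset = load_dataset_from_config(
    "example_data/dataset.yaml",
    base_dir=Path("."),
)
print(f"Loaded {len(dataset)} annotated clips")
```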
@@ -10,16 +10,8 @@ from batdetect2.typing.targets import TargetProtocol
def iterate_over_sound_events(
    dataset: Dataset,
    targets: TargetProtocol,
    apply_filter: bool = True,
    apply_transform: bool = True,
    exclude_generic: bool = True,
) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]:
    """Iterate over sound events in a dataset, applying filtering and
    transformations.

    This generator function processes sound event annotations from a given
    dataset, allowing for optional filtering, transformation, and exclusion of
    unclassifiable (generic) events based on the provided target definitions.
    """Iterate over sound events in a dataset.

    Parameters
    ----------
@@ -29,18 +21,6 @@ def iterate_over_sound_events(
    targets : TargetProtocol
        An object implementing the `TargetProtocol`, which provides methods
        for filtering, transforming, and encoding sound events.
    apply_filter : bool, optional
        If True, sound events will be filtered using `targets.filter()`.
        Only events for which `targets.filter()` returns True will be yielded.
        Defaults to True.
    apply_transform : bool, optional
        If True, sound events will be transformed using `targets.transform()`
        before being yielded. Defaults to True.
    exclude_generic : bool, optional
        If True, sound events that result in a `None` class name after
        `targets.encode()` will be excluded. This is typically used to
        filter out events that cannot be mapped to a specific target class.
        Defaults to True.

    Yields
    ------
@@ -63,17 +43,9 @@ def iterate_over_sound_events(
    """
    for clip_annotation in dataset:
        for sound_event_annotation in clip_annotation.sound_events:
            if apply_filter:
                if not targets.filter(sound_event_annotation):
                    continue

            if apply_transform:
                sound_event_annotation = targets.transform(
                    sound_event_annotation
                )

            class_name = targets.encode_class(sound_event_annotation)
            if class_name is None and exclude_generic:
            if not targets.filter(sound_event_annotation):
                continue

            class_name = targets.encode_class(sound_event_annotation)

            yield class_name, sound_event_annotation
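A sketch of how the slimmed-down generator is consumed; `dataset` and `targets` are assumed to already exist (see the signature above), and no import is shown because the module path is not visible in this diff:

```python
# Yields (class_name, sound_event_annotation) pairs; class_name may be None
# for events that do not map to a specific target class.
class_counts: dict = {}

for class_name, annotation in iterate_over_sound_events(dataset, targets):
    class_counts[class_name] = class_counts.get(class_name, 0) + 1

print(class_counts)
```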
250
src/batdetect2/data/transforms.py
Normal file
250
src/batdetect2/data/transforms.py
Normal file
@ -0,0 +1,250 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Annotated, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.data._core import Registry
|
||||
from batdetect2.data.conditions import (
|
||||
SoundEventCondition,
|
||||
SoundEventConditionConfig,
|
||||
build_sound_event_condition,
|
||||
)
|
||||
|
||||
SoundEventTransform = Callable[
|
||||
[data.SoundEventAnnotation],
|
||||
data.SoundEventAnnotation,
|
||||
]
|
||||
|
||||
_transforms: Registry[SoundEventTransform] = Registry("transform")
|
||||
|
||||
|
||||
class SetFrequencyBoundConfig(BaseConfig):
|
||||
name: Literal["set_frequency"] = "set_frequency"
|
||||
boundary: Literal["low", "high"] = "low"
|
||||
hertz: float
|
||||
|
||||
|
||||
@_transforms.register(SetFrequencyBoundConfig)
|
||||
class SetFrequencyBound:
|
||||
def __init__(self, hertz: float, boundary: Literal["low", "high"] = "low"):
|
||||
self.hertz = hertz
|
||||
self.boundary = boundary
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
) -> data.SoundEventAnnotation:
|
||||
sound_event = sound_event_annotation.sound_event
|
||||
geometry = sound_event.geometry
|
||||
|
||||
if geometry is None:
|
||||
return sound_event_annotation
|
||||
|
||||
if not isinstance(geometry, data.BoundingBox):
|
||||
return sound_event_annotation
|
||||
|
||||
start_time, low_freq, end_time, high_freq = geometry.coordinates
|
||||
|
||||
if self.boundary == "low":
|
||||
low_freq = self.hertz
|
||||
high_freq = max(high_freq, low_freq)
|
||||
|
||||
elif self.boundary == "high":
|
||||
high_freq = self.hertz
|
||||
low_freq = min(high_freq, low_freq)
|
||||
|
||||
geometry = data.BoundingBox(
|
||||
coordinates=[start_time, low_freq, end_time, high_freq],
|
||||
)
|
||||
|
||||
sound_event = sound_event.model_copy(update=dict(geometry=geometry))
|
||||
return sound_event_annotation.model_copy(
|
||||
update=dict(sound_event=sound_event)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: SetFrequencyBoundConfig):
|
||||
return cls(hertz=config.hertz, boundary=config.boundary)
|
||||
|
||||
|
||||
class ApplyIfConfig(BaseConfig):
|
||||
name: Literal["apply_if"] = "apply_if"
|
||||
transform: "SoundEventTransformConfig"
|
||||
condition: SoundEventConditionConfig
|
||||
|
||||
|
||||
@_transforms.register(ApplyIfConfig)
|
||||
class ApplyIf:
|
||||
def __init__(
|
||||
self,
|
||||
condition: SoundEventCondition,
|
||||
transform: SoundEventTransform,
|
||||
):
|
||||
self.condition = condition
|
||||
self.transform = transform
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
) -> data.SoundEventAnnotation:
|
||||
if not self.condition(sound_event_annotation):
|
||||
return sound_event_annotation
|
||||
|
||||
return self.transform(sound_event_annotation)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ApplyIfConfig):
|
||||
transform = build_sound_event_transform(config.transform)
|
||||
condition = build_sound_event_condition(config.condition)
|
||||
return cls(condition=condition, transform=transform)
|
||||
|
||||
|
||||
class ReplaceTagConfig(BaseConfig):
|
||||
name: Literal["replace_tag"] = "replace_tag"
|
||||
original: data.Tag
|
||||
replacement: data.Tag
|
||||
|
||||
|
||||
@_transforms.register(ReplaceTagConfig)
|
||||
class ReplaceTag:
|
||||
def __init__(
|
||||
self,
|
||||
original: data.Tag,
|
||||
replacement: data.Tag,
|
||||
):
|
||||
self.original = original
|
||||
self.replacement = replacement
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
) -> data.SoundEventAnnotation:
|
||||
tags = []
|
||||
|
||||
for tag in sound_event_annotation.tags:
|
||||
if tag == self.original:
|
||||
tags.append(self.replacement)
|
||||
else:
|
||||
tags.append(tag)
|
||||
|
||||
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ReplaceTagConfig):
|
||||
return cls(original=config.original, replacement=config.replacement)
|
||||
|
||||
|
||||
class MapTagValueConfig(BaseConfig):
|
||||
name: Literal["map_tag_value"] = "map_tag_value"
|
||||
tag_key: str
|
||||
value_mapping: Dict[str, str]
|
||||
target_key: Optional[str] = None
|
||||
|
||||
|
||||
@_transforms.register(MapTagValueConfig)
|
||||
class MapTagValue:
|
||||
def __init__(
|
||||
self,
|
||||
tag_key: str,
|
||||
value_mapping: Dict[str, str],
|
||||
target_key: Optional[str] = None,
|
||||
):
|
||||
self.tag_key = tag_key
|
||||
self.value_mapping = value_mapping
|
||||
self.target_key = target_key
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
) -> data.SoundEventAnnotation:
|
||||
tags = []
|
||||
|
||||
for tag in sound_event_annotation.tags:
|
||||
if tag.key != self.tag_key:
|
||||
tags.append(tag)
|
||||
continue
|
||||
|
||||
value = self.value_mapping.get(tag.value)
|
||||
|
||||
if value is None:
|
||||
tags.append(tag)
|
||||
continue
|
||||
|
||||
if self.target_key is None:
|
||||
tags.append(tag.model_copy(update=dict(value=value)))
|
||||
else:
|
||||
tags.append(
|
||||
data.Tag(
|
||||
key=self.target_key, # type: ignore
|
||||
value=value,
|
||||
)
|
||||
)
|
||||
|
||||
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: MapTagValueConfig):
|
||||
return cls(
|
||||
tag_key=config.tag_key,
|
||||
value_mapping=config.value_mapping,
|
||||
target_key=config.target_key,
|
||||
)
|
||||
|
||||
|
||||
class ApplyAllConfig(BaseConfig):
|
||||
name: Literal["apply_all"] = "apply_all"
|
||||
steps: List["SoundEventTransformConfig"] = Field(default_factory=list)
|
||||
|
||||
|
||||
@_transforms.register(ApplyAllConfig)
|
||||
class ApplyAll:
|
||||
def __init__(self, steps: List[SoundEventTransform]):
|
||||
self.steps = steps
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
) -> data.SoundEventAnnotation:
|
||||
for step in self.steps:
|
||||
sound_event_annotation = step(sound_event_annotation)
|
||||
|
||||
return sound_event_annotation
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ApplyAllConfig):
|
||||
steps = [build_sound_event_transform(step) for step in config.steps]
|
||||
return cls(steps)
|
||||
|
||||
|
||||
SoundEventTransformConfig = Annotated[
|
||||
Union[
|
||||
SetFrequencyBoundConfig,
|
||||
ReplaceTagConfig,
|
||||
MapTagValueConfig,
|
||||
ApplyIfConfig,
|
||||
ApplyAllConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_sound_event_transform(
|
||||
config: SoundEventTransformConfig,
|
||||
) -> SoundEventTransform:
|
||||
return _transforms.build(config)
|
||||
|
||||
|
||||
def transform_clip_annotation(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
transform: SoundEventTransform,
|
||||
) -> data.ClipAnnotation:
|
||||
return clip_annotation.model_copy(
|
||||
update=dict(
|
||||
sound_events=[
|
||||
transform(sound_event)
|
||||
for sound_event in clip_annotation.sound_events
|
||||
]
|
||||
)
|
||||
)
|
||||
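And a corresponding sketch for the transform registry. The config classes and the builder are the ones defined in the new `transforms.py`; the frequency value and the tag mapping are illustrative only:

```python
from batdetect2.data.transforms import (
    ApplyAllConfig,
    MapTagValueConfig,
    SetFrequencyBoundConfig,
    build_sound_event_transform,
)

# Clamp the low frequency of every bounding box and remap one class label.
config = ApplyAllConfig(
    steps=[
        SetFrequencyBoundConfig(boundary="low", hertz=10_000),
        MapTagValueConfig(
            tag_key="class",
            value_mapping={"Myotis mystacinus": "myomys"},  # example mapping only
        ),
    ]
)

transform = build_sound_event_transform(config)
# `transform` maps a SoundEventAnnotation to a new SoundEventAnnotation.
```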
62
src/batdetect2/evaluate/dataframe.py
Normal file
62
src/batdetect2/evaluate/dataframe.py
Normal file
@ -0,0 +1,62 @@
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.typing.evaluate import MatchEvaluation
|
||||
|
||||
|
||||
def extract_matches_dataframe(matches: List[MatchEvaluation]) -> pd.DataFrame:
|
||||
data = []
|
||||
|
||||
for match in matches:
|
||||
gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
|
||||
pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None
|
||||
|
||||
sound_event_annotation = match.sound_event_annotation
|
||||
|
||||
if sound_event_annotation is not None:
|
||||
geometry = sound_event_annotation.sound_event.geometry
|
||||
assert geometry is not None
|
||||
gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
|
||||
compute_bounds(geometry)
|
||||
)
|
||||
|
||||
if match.pred_geometry is not None:
|
||||
pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
|
||||
compute_bounds(match.pred_geometry)
|
||||
)
|
||||
|
||||
data.append(
|
||||
{
|
||||
("recording", "uuid"): match.clip.recording.uuid,
|
||||
("clip", "uuid"): match.clip.uuid,
|
||||
("clip", "start_time"): match.clip.start_time,
|
||||
("clip", "end_time"): match.clip.end_time,
|
||||
("gt", "uuid"): match.sound_event_annotation.uuid
|
||||
if match.sound_event_annotation is not None
|
||||
else None,
|
||||
("gt", "class"): match.gt_class,
|
||||
("gt", "det"): match.gt_det,
|
||||
("gt", "start_time"): gt_start_time,
|
||||
("gt", "end_time"): gt_end_time,
|
||||
("gt", "low_freq"): gt_low_freq,
|
||||
("gt", "high_freq"): gt_high_freq,
|
||||
("pred", "score"): match.pred_score,
|
||||
("pred", "class"): match.pred_class,
|
||||
("pred", "class_score"): match.pred_class_score,
|
||||
("pred", "start_time"): pred_start_time,
|
||||
("pred", "end_time"): pred_end_time,
|
||||
("pred", "low_freq"): pred_low_freq,
|
||||
("pred", "high_freq"): pred_high_freq,
|
||||
("match", "affinity"): match.affinity,
|
||||
**{
|
||||
("pred_class_score", key): value
|
||||
for key, value in match.pred_class_scores.items()
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
df = pd.DataFrame(data)
|
||||
df.columns = pd.MultiIndex.from_tuples(df.columns) # type: ignore
|
||||
return df
|
||||
100
src/batdetect2/evaluate/evaluate.py
Normal file
100
src/batdetect2/evaluate/evaluate.py
Normal file
@ -0,0 +1,100 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import pandas as pd
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.evaluate.dataframe import extract_matches_dataframe
|
||||
from batdetect2.evaluate.match import match_all_predictions
|
||||
from batdetect2.evaluate.metrics import (
|
||||
ClassificationAccuracy,
|
||||
ClassificationMeanAveragePrecision,
|
||||
DetectionAveragePrecision,
|
||||
)
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.plotting.clips import build_audio_loader
|
||||
from batdetect2.postprocess import get_raw_predictions
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.train.config import FullTrainingConfig
|
||||
from batdetect2.train.dataset import ValidationDataset
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
from batdetect2.train.train import build_val_loader
|
||||
|
||||
|
||||
def evaluate(
|
||||
model: Model,
|
||||
test_annotations: List[data.ClipAnnotation],
|
||||
config: Optional[FullTrainingConfig] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
) -> Tuple[pd.DataFrame, dict]:
|
||||
config = config or FullTrainingConfig()
|
||||
|
||||
audio_loader = build_audio_loader(config.preprocess.audio)
|
||||
|
||||
preprocessor = build_preprocessor(config.preprocess)
|
||||
|
||||
targets = build_targets(config.targets)
|
||||
|
||||
labeller = build_clip_labeler(
|
||||
targets,
|
||||
min_freq=preprocessor.min_freq,
|
||||
max_freq=preprocessor.max_freq,
|
||||
config=config.train.labels,
|
||||
)
|
||||
|
||||
loader = build_val_loader(
|
||||
test_annotations,
|
||||
audio_loader=audio_loader,
|
||||
labeller=labeller,
|
||||
preprocessor=preprocessor,
|
||||
config=config.train,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
|
||||
dataset: ValidationDataset = loader.dataset # type: ignore
|
||||
|
||||
clip_annotations = []
|
||||
predictions = []
|
||||
|
||||
for batch in loader:
|
||||
outputs = model.detector(batch.spec)
|
||||
|
||||
clip_annotations = [
|
||||
dataset.clip_annotations[int(example_idx)]
|
||||
for example_idx in batch.idx
|
||||
]
|
||||
|
||||
predictions = get_raw_predictions(
|
||||
outputs,
|
||||
clips=[
|
||||
clip_annotation.clip for clip_annotation in clip_annotations
|
||||
],
|
||||
targets=targets,
|
||||
postprocessor=model.postprocessor,
|
||||
)
|
||||
|
||||
clip_annotations.extend(clip_annotations)
|
||||
predictions.extend(predictions)
|
||||
|
||||
matches = match_all_predictions(
|
||||
clip_annotations,
|
||||
predictions,
|
||||
targets=targets,
|
||||
config=config.evaluation.match,
|
||||
)
|
||||
|
||||
df = extract_matches_dataframe(matches)
|
||||
|
||||
metrics = [
|
||||
DetectionAveragePrecision(),
|
||||
ClassificationMeanAveragePrecision(class_names=targets.class_names),
|
||||
ClassificationAccuracy(class_names=targets.class_names),
|
||||
]
|
||||
|
||||
results = {
|
||||
name: value
|
||||
for metric in metrics
|
||||
for name, value in metric(matches).items()
|
||||
}
|
||||
|
||||
return df, results
|
||||
@ -1,6 +1,5 @@
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import List, Literal, Optional, Protocol, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -9,7 +8,6 @@ from soundevent import data
|
||||
from soundevent.evaluation import compute_affinity
|
||||
from soundevent.evaluation import match_geometries as optimal_match
|
||||
from soundevent.geometry import compute_bounds
|
||||
from torch.multiprocessing import Pool
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.typing import (
|
||||
@ -284,7 +282,7 @@ def match_sound_events_and_raw_predictions(
|
||||
config = config or MatchConfig()
|
||||
|
||||
target_sound_events = [
|
||||
targets.transform(sound_event_annotation)
|
||||
sound_event_annotation
|
||||
for sound_event_annotation in clip_annotation.sound_events
|
||||
if targets.filter(sound_event_annotation)
|
||||
and sound_event_annotation.sound_event.geometry is not None
|
||||
@ -430,17 +428,19 @@ def match_all_predictions(
|
||||
config: Optional[MatchConfig] = None,
|
||||
) -> List[MatchEvaluation]:
|
||||
logger.info("Matching all annotations and predictions...")
|
||||
with Pool() as p:
|
||||
all_matches = p.starmap(
|
||||
partial(
|
||||
match_sound_events_and_raw_predictions,
|
||||
targets=targets,
|
||||
config=config,
|
||||
),
|
||||
zip(clip_annotations, predictions),
|
||||
return [
|
||||
match
|
||||
for clip_annotation, raw_predictions in zip(
|
||||
clip_annotations,
|
||||
predictions,
|
||||
)
|
||||
|
||||
return [match for matches in all_matches for match in matches]
|
||||
for match in match_sound_events_and_raw_predictions(
|
||||
clip_annotation,
|
||||
raw_predictions,
|
||||
targets=targets,
|
||||
config=config,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -29,7 +29,6 @@ provided here.
from typing import List, Optional

import torch
from lightning import LightningModule
from pydantic import Field
from soundevent.data import PathLike

@@ -68,7 +67,10 @@ from batdetect2.postprocess import PostprocessConfig, build_postprocessor
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.typing.models import DetectionModel
from batdetect2.typing.postprocess import DetectionsArray, PostprocessorProtocol
from batdetect2.typing.postprocess import (
    DetectionsTensor,
    PostprocessorProtocol,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol

@@ -102,7 +104,16 @@ __all__ = [
]


class Model(LightningModule):
class ModelConfig(BaseConfig):
    model: BackboneConfig = Field(default_factory=BackboneConfig)
    preprocess: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
    targets: TargetConfig = Field(default_factory=TargetConfig)


class Model(torch.nn.Module):
    detector: DetectionModel
    preprocessor: PreprocessorProtocol
    postprocessor: PostprocessorProtocol
@@ -114,43 +125,39 @@ class Model(LightningModule):
        preprocessor: PreprocessorProtocol,
        postprocessor: PostprocessorProtocol,
        targets: TargetProtocol,
        config: ModelConfig,
    ):
        super().__init__()
        self.detector = detector
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor
        self.targets = targets
        self.save_hyperparameters()
        self.config = config

    def forward(self, wav: torch.Tensor) -> List[DetectionsArray]:
    def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
        spec = self.preprocessor(wav)
        outputs = self.detector(spec)
        return self.postprocessor(outputs)


class ModelConfig(BaseConfig):
    model: BackboneConfig = Field(default_factory=BackboneConfig)
    preprocess: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
    targets: TargetConfig = Field(default_factory=TargetConfig)


def build_model(config: Optional[ModelConfig] = None):
    config = config or ModelConfig()

    targets = build_targets(config=config.targets)

    preprocessor = build_preprocessor(config=config.preprocess)

    postprocessor = build_postprocessor(
        preprocessor=preprocessor,
        config=config.postprocess,
    )

    detector = build_detector(
        num_classes=len(targets.class_names),
        config=config.model,
    )
    return Model(
        config=config,
        detector=detector,
        postprocessor=postprocessor,
        preprocessor=preprocessor,
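A minimal sketch of how the refactored `Model` is built and called. The import path of `build_model` and the waveform shape are assumptions for illustration; the real input layout is determined by the preprocessor:

```python
import torch

from batdetect2.models import build_model  # import path assumed

# Build a model with the default config; after this change it is a plain
# torch.nn.Module rather than a LightningModule.
model = build_model()
model.eval()

# One second of audio at an assumed 256 kHz sample rate, batch of 1.
wav = torch.randn(1, 256_000)

with torch.no_grad():
    detections = model(wav)  # list of per-clip detection tensors
```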
@ -56,7 +56,7 @@ __all__ = [
|
||||
|
||||
|
||||
class SelfAttentionConfig(BaseConfig):
|
||||
block_type: Literal["SelfAttention"] = "SelfAttention"
|
||||
name: Literal["SelfAttention"] = "SelfAttention"
|
||||
attention_channels: int
|
||||
temperature: float = 1
|
||||
|
||||
@ -178,7 +178,7 @@ class SelfAttention(nn.Module):
|
||||
class ConvConfig(BaseConfig):
|
||||
"""Configuration for a basic ConvBlock."""
|
||||
|
||||
block_type: Literal["ConvBlock"] = "ConvBlock"
|
||||
name: Literal["ConvBlock"] = "ConvBlock"
|
||||
"""Discriminator field indicating the block type."""
|
||||
|
||||
out_channels: int
|
||||
@ -300,7 +300,7 @@ class VerticalConv(nn.Module):
|
||||
class FreqCoordConvDownConfig(BaseConfig):
|
||||
"""Configuration for a FreqCoordConvDownBlock."""
|
||||
|
||||
block_type: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
|
||||
name: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
|
||||
"""Discriminator field indicating the block type."""
|
||||
|
||||
out_channels: int
|
||||
@ -390,7 +390,7 @@ class FreqCoordConvDownBlock(nn.Module):
|
||||
class StandardConvDownConfig(BaseConfig):
|
||||
"""Configuration for a StandardConvDownBlock."""
|
||||
|
||||
block_type: Literal["StandardConvDown"] = "StandardConvDown"
|
||||
name: Literal["StandardConvDown"] = "StandardConvDown"
|
||||
"""Discriminator field indicating the block type."""
|
||||
|
||||
out_channels: int
|
||||
@ -460,7 +460,7 @@ class StandardConvDownBlock(nn.Module):
|
||||
class FreqCoordConvUpConfig(BaseConfig):
|
||||
"""Configuration for a FreqCoordConvUpBlock."""
|
||||
|
||||
block_type: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
|
||||
name: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
|
||||
"""Discriminator field indicating the block type."""
|
||||
|
||||
out_channels: int
|
||||
@ -569,7 +569,7 @@ class FreqCoordConvUpBlock(nn.Module):
|
||||
class StandardConvUpConfig(BaseConfig):
|
||||
"""Configuration for a StandardConvUpBlock."""
|
||||
|
||||
block_type: Literal["StandardConvUp"] = "StandardConvUp"
|
||||
name: Literal["StandardConvUp"] = "StandardConvUp"
|
||||
"""Discriminator field indicating the block type."""
|
||||
|
||||
out_channels: int
|
||||
@ -664,13 +664,13 @@ LayerConfig = Annotated[
|
||||
SelfAttentionConfig,
|
||||
"LayerGroupConfig",
|
||||
],
|
||||
Field(discriminator="block_type"),
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configuration models."""
|
||||
|
||||
|
||||
class LayerGroupConfig(BaseConfig):
|
||||
block_type: Literal["LayerGroup"] = "LayerGroup"
|
||||
name: Literal["LayerGroup"] = "LayerGroup"
|
||||
layers: List[LayerConfig]
|
||||
|
||||
|
||||
@ -686,7 +686,7 @@ def build_layer_from_config(
    parameters derived from the config and the current pipeline state
    (`input_height`, `in_channels`).

    It uses the `block_type` field within the `config` object to determine
    It uses the `name` field within the `config` object to determine
    which block class to instantiate.

    Parameters
@ -698,7 +698,7 @@ def build_layer_from_config(
    config : LayerConfig
        A Pydantic configuration object for the desired block (e.g., an
        instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
        by its `block_type` field.
        by its `name` field.

    Returns
    -------
@ -711,11 +711,11 @@ def build_layer_from_config(
    Raises
    ------
    NotImplementedError
        If the `config.block_type` does not correspond to a known block type.
        If the `config.name` does not correspond to a known block type.
    ValueError
        If parameters derived from the config are invalid for the block.
    """
    if config.block_type == "ConvBlock":
    if config.name == "ConvBlock":
        return (
            ConvBlock(
                in_channels=in_channels,
@ -727,7 +727,7 @@ def build_layer_from_config(
            input_height,
        )

    if config.block_type == "FreqCoordConvDown":
    if config.name == "FreqCoordConvDown":
        return (
            FreqCoordConvDownBlock(
                in_channels=in_channels,
@ -740,7 +740,7 @@ def build_layer_from_config(
            input_height // 2,
        )

    if config.block_type == "StandardConvDown":
    if config.name == "StandardConvDown":
        return (
            StandardConvDownBlock(
                in_channels=in_channels,
@ -752,7 +752,7 @@ def build_layer_from_config(
            input_height // 2,
        )

    if config.block_type == "FreqCoordConvUp":
    if config.name == "FreqCoordConvUp":
        return (
            FreqCoordConvUpBlock(
                in_channels=in_channels,
@ -765,7 +765,7 @@ def build_layer_from_config(
            input_height * 2,
        )

    if config.block_type == "StandardConvUp":
    if config.name == "StandardConvUp":
        return (
            StandardConvUpBlock(
                in_channels=in_channels,
@ -777,7 +777,7 @@ def build_layer_from_config(
            input_height * 2,
        )

    if config.block_type == "SelfAttention":
    if config.name == "SelfAttention":
        return (
            SelfAttention(
                in_channels=in_channels,
@ -788,7 +788,7 @@ def build_layer_from_config(
            input_height,
        )

    if config.block_type == "LayerGroup":
    if config.name == "LayerGroup":
        current_channels = in_channels
        current_height = input_height

@ -804,4 +804,4 @@ def build_layer_from_config(

    return nn.Sequential(*blocks), current_channels, current_height

    raise NotImplementedError(f"Unknown block type {config.block_type}")
    raise NotImplementedError(f"Unknown block type {config.name}")

@ -128,7 +128,7 @@ class Bottleneck(nn.Module):

BottleneckLayerConfig = Annotated[
    Union[SelfAttentionConfig,],
    Field(discriminator="block_type"),
    Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""


@ -47,7 +47,7 @@ DecoderLayerConfig = Annotated[
        StandardConvUpConfig,
        LayerGroupConfig,
    ],
    Field(discriminator="block_type"),
    Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""

@ -63,7 +63,7 @@ class DecoderConfig(BaseConfig):
    layers : List[DecoderLayerConfig]
        An ordered list of configuration objects, each defining one layer or
        block in the decoder sequence. Each item must be a valid block
        config including a `block_type` field and necessary parameters like
        config including a `name` field and necessary parameters like
        `out_channels`. Input channels for each layer are inferred sequentially.
        The list must contain at least one layer.
    """
@ -249,9 +249,9 @@ def build_decoder(
    ------
    ValueError
        If `in_channels` or `input_height` are not positive, or if the layer
        configuration is invalid (e.g., empty list, unknown `block_type`).
        configuration is invalid (e.g., empty list, unknown `name`).
    NotImplementedError
        If `build_layer_from_config` encounters an unknown `block_type`.
        If `build_layer_from_config` encounters an unknown `name`.
    """
    config = config or DEFAULT_DECODER_CONFIG


@ -49,7 +49,7 @@ EncoderLayerConfig = Annotated[
        StandardConvDownConfig,
        LayerGroupConfig,
    ],
    Field(discriminator="block_type"),
    Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Encoder."""

@ -66,7 +66,7 @@ class EncoderConfig(BaseConfig):
        An ordered list of configuration objects, each defining one layer or
        block in the encoder sequence. Each item must be a valid block config
        (e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
        `StandardConvDownConfig`) including a `block_type` field and necessary
        `StandardConvDownConfig`) including a `name` field and necessary
        parameters like `out_channels`. Input channels for each layer are
        inferred sequentially. The list must contain at least one layer.
    """
@ -287,9 +287,9 @@ def build_encoder(
    ------
    ValueError
        If `in_channels` or `input_height` are not positive, or if the layer
        configuration is invalid (e.g., empty list, unknown `block_type`).
        configuration is invalid (e.g., empty list, unknown `name`).
    NotImplementedError
        If `build_layer_from_config` encounters an unknown `block_type`.
        If `build_layer_from_config` encounters an unknown `name`.
    """
    if in_channels <= 0 or input_height <= 0:
        raise ValueError("in_channels and input_height must be positive.")

@ -1,6 +1,14 @@
"""Computes spectrograms from audio waveforms with configurable parameters."""

from typing import Annotated, Callable, List, Literal, Optional, Union
from typing import (
    Annotated,
    Callable,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
)

import numpy as np
import torch
@ -306,7 +314,7 @@ class SpectrogramConfig(BaseConfig):
    stft: STFTConfig = Field(default_factory=STFTConfig)
    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
    size: ResizeConfig = Field(default_factory=ResizeConfig)
    transforms: List[SpectrogramTransform] = Field(
    transforms: Sequence[SpectrogramTransform] = Field(
        default_factory=lambda: [
            PcenConfig(),
            SpectralMeanSubstractionConfig(),

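Because `transforms` is now an ordered sequence, each entry selects its transform by `name` (in YAML) and carries its own parameters, matching the `preprocess.spectrogram.transforms` block at the top of this compare. A hedged sketch of building the equivalent object in Python; the import path is an assumption based on the file touched here, and may differ in your checkout:

```python
# Assumed import location for the spectrogram config models edited in this diff.
from batdetect2.preprocess.spectrogram import (
    PcenConfig,
    SpectralMeanSubstractionConfig,
    SpectrogramConfig,
)

config = SpectrogramConfig(
    transforms=[
        # Applied in order: PCEN first, then per-band spectral mean subtraction.
        PcenConfig(time_constant=0.1, gain=0.98, bias=2, power=0.5),
        SpectralMeanSubstractionConfig(),
    ]
)
```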
@ -1,53 +1,26 @@
"""Main entry point for the BatDetect2 Target Definition subsystem.

This package (`batdetect2.targets`) provides the tools and configurations
necessary to define precisely what the BatDetect2 model should learn to detect,
classify, and localize from audio data. It involves several conceptual steps,
managed through configuration files and culminating in an executable pipeline:

1. **Terms (`.terms`)**: Defining vocabulary for annotation tags.
2. **Filtering (`.filtering`)**: Selecting relevant sound event annotations.
3. **Transformation (`.transform`)**: Modifying tags (standardization,
   derivation).
4. **ROI Mapping (`.roi`)**: Defining how annotation geometry (ROIs) maps to
   target position and size representations, and back.
5. **Class Definition (`.classes`)**: Mapping tags to target class names
   (encoding) and mapping predicted names back to tags (decoding).

This module exposes the key components for users to configure and utilize this
target definition pipeline, primarily through the `TargetConfig` data structure
and the `Targets` class (implementing `TargetProtocol`), which encapsulates the
configured processing steps. The main way to create a functional `Targets`
object is via the `build_targets` or `load_targets` functions.
"""
"""BatDetect2 Target Definition system."""

from collections import Counter
from typing import Iterable, List, Optional, Tuple

from loguru import logger
from pydantic import Field
from pydantic import Field, field_validator
from soundevent import data

from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.conditions import (
    SoundEventCondition,
    build_sound_event_condition,
)
from batdetect2.targets.classes import (
    ClassesConfig,
    DEFAULT_CLASSES,
    DEFAULT_GENERIC_CLASS,
    SoundEventDecoder,
    SoundEventEncoder,
    TargetClass,
    build_generic_class_tags,
    TargetClassConfig,
    build_sound_event_decoder,
    build_sound_event_encoder,
    get_class_names_from_config,
    load_classes_config,
    load_decoder_from_config,
    load_encoder_from_config,
)
from batdetect2.targets.filtering import (
    FilterConfig,
    FilterRule,
    SoundEventFilter,
    build_sound_event_filter,
    load_filter_config,
    load_filter_from_config,
)
from batdetect2.targets.rois import (
    AnchorBBoxMapperConfig,
@ -55,114 +28,53 @@ from batdetect2.targets.rois import (
    ROITargetMapper,
    build_roi_mapper,
)
from batdetect2.targets.terms import (
    TagInfo,
    TermInfo,
    TermRegistry,
    call_type,
    default_term_registry,
    get_tag_from_info,
    get_term_from_key,
    individual,
    register_term,
)
from batdetect2.targets.transform import (
    DerivationRegistry,
    DeriveTagRule,
    MapValueRule,
    ReplaceRule,
    SoundEventTransformation,
    TransformConfig,
    build_transformation_from_config,
    default_derivation_registry,
    get_derivation,
    load_transformation_config,
    load_transformation_from_config,
    register_derivation,
)
from batdetect2.targets.terms import call_type, individual
from batdetect2.typing.targets import Position, Size, TargetProtocol

__all__ = [
    "ClassesConfig",
    "DEFAULT_TARGET_CONFIG",
    "DeriveTagRule",
    "FilterConfig",
    "FilterRule",
    "MapValueRule",
    "AnchorBBoxMapperConfig",
    "ROITargetMapper",
    "ReplaceRule",
    "SoundEventDecoder",
    "SoundEventEncoder",
    "SoundEventFilter",
    "SoundEventTransformation",
    "TagInfo",
    "TargetClass",
    "TargetClassConfig",
    "TargetConfig",
    "Targets",
    "TermInfo",
    "TransformConfig",
    "build_generic_class_tags",
    "build_roi_mapper",
    "build_sound_event_decoder",
    "build_sound_event_encoder",
    "build_sound_event_filter",
    "build_transformation_from_config",
    "call_type",
    "get_class_names_from_config",
    "get_derivation",
    "get_tag_from_info",
    "get_term_from_key",
    "individual",
    "load_classes_config",
    "load_decoder_from_config",
    "load_encoder_from_config",
    "load_filter_config",
    "load_filter_from_config",
    "load_target_config",
    "load_transformation_config",
    "load_transformation_from_config",
    "register_derivation",
    "register_term",
]


class TargetConfig(BaseConfig):
    """Unified configuration for the entire target definition pipeline.
    detection_target: TargetClassConfig = Field(default=DEFAULT_GENERIC_CLASS)

    This model aggregates the configurations for semantic processing (filtering,
    transformation, class definition) and geometric processing (ROI mapping).
    It serves as the primary input for building a complete `Targets` object
    via `build_targets` or `load_targets`.

    Attributes
    ----------
    filtering : FilterConfig, optional
        Configuration for filtering sound event annotations based on tags.
        If None or omitted, no filtering is applied.
    transforms : TransformConfig, optional
        Configuration for transforming annotation tags
        (mapping, derivation, etc.). If None or omitted, no tag transformations
        are applied.
    classes : ClassesConfig
        Configuration defining the specific target classes, their tag matching
        rules for encoding, their representative tags for decoding
        (`output_tags`), and the definition of the generic class tags.
        This section is mandatory.
    roi : ROIConfig, optional
        Configuration defining how geometric ROIs (e.g., bounding boxes) are
        mapped to target representations (reference point, scaled size).
        Controls `position`, `time_scale`, `frequency_scale`. If None or
        omitted, default ROI mapping settings are used.
    """

    filtering: FilterConfig = Field(default_factory=FilterConfig)
    transforms: TransformConfig = Field(default_factory=TransformConfig)
    classes: ClassesConfig = Field(
        default_factory=lambda: DEFAULT_CLASSES_CONFIG
    classification_targets: List[TargetClassConfig] = Field(
        default_factory=lambda: DEFAULT_CLASSES
    )

    roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)

    @field_validator("classification_targets")
    def check_unique_class_names(cls, v: List[TargetClassConfig]):
        """Ensure all defined class names are unique."""
        names = [c.name for c in v]

        if len(names) != len(set(names)):
            name_counts = Counter(names)
            duplicates = [
                name for name, count in name_counts.items() if count > 1
            ]
            raise ValueError(
                "Class names must be unique. Found duplicates: "
                f"{', '.join(duplicates)}"
            )
        return v

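Two practical consequences of the new fields: classification targets can be declared with the short `tags` form (or an explicit `match_if` condition), and `check_unique_class_names` rejects duplicate names at validation time. A hedged sketch, with import locations assumed from the files changed in this compare:

```python
from pydantic import ValidationError
from soundevent import data

# Assumed import locations, based on the modules edited in this diff.
from batdetect2.targets import TargetConfig
from batdetect2.targets.classes import TargetClassConfig

pippip = TargetClassConfig(
    name="pippip",
    tags=[data.Tag(key="class", value="Pipistrellus pipistrellus")],
)

try:
    TargetConfig(
        classification_targets=[
            pippip,
            # Reusing the name triggers the validator above.
            TargetClassConfig(
                name="pippip",
                tags=[data.Tag(key="class", value="Pipistrellus pygmaeus")],
            ),
        ]
    )
except ValidationError as error:
    print(error)  # mentions: Class names must be unique. Found duplicates: pippip
```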
def load_target_config(
    path: data.PathLike,
@ -238,8 +150,7 @@ class Targets(TargetProtocol):
        roi_mapper: ROITargetMapper,
        class_names: list[str],
        generic_class_tags: List[data.Tag],
        filter_fn: Optional[SoundEventFilter] = None,
        transform_fn: Optional[SoundEventTransformation] = None,
        filter_fn: Optional[SoundEventCondition] = None,
        roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
    ):
        """Initialize the Targets object.
@ -272,7 +183,6 @@ class Targets(TargetProtocol):
        self._filter_fn = filter_fn
        self._encode_fn = encode_fn
        self._decode_fn = decode_fn
        self._transform_fn = transform_fn
        self._roi_mapper_overrides = roi_mapper_overrides or {}

        for class_name in self._roi_mapper_overrides:
@ -344,27 +254,6 @@ class Targets(TargetProtocol):
        """
        return self._decode_fn(class_label)

    def transform(
        self, sound_event: data.SoundEventAnnotation
    ) -> data.SoundEventAnnotation:
        """Apply the configured tag transformations to an annotation.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation whose tags should be transformed.

        Returns
        -------
        data.SoundEventAnnotation
            A new annotation object with the transformed tags. If no
            transformations were configured, the original annotation object is
            returned.
        """
        if self._transform_fn:
            return self._transform_fn(sound_event)
        return sound_event

    def encode_roi(
        self, sound_event: data.SoundEventAnnotation
    ) -> tuple[Position, Size]:
@ -430,113 +319,14 @@ class Targets(TargetProtocol):
        return self._roi_mapper.decode(position, size)


DEFAULT_CLASSES = [
    TargetClass(
        tags=[TagInfo(value="Myotis mystacinus")],
        name="myomys",
    ),
    TargetClass(
        tags=[TagInfo(value="Myotis alcathoe")],
        name="myoalc",
    ),
    TargetClass(
        tags=[TagInfo(value="Eptesicus serotinus")],
        name="eptser",
    ),
    TargetClass(
        tags=[TagInfo(value="Pipistrellus nathusii")],
        name="pipnat",
    ),
    TargetClass(
        tags=[TagInfo(value="Barbastellus barbastellus")],
        name="barbar",
    ),
    TargetClass(
        tags=[TagInfo(value="Myotis nattereri")],
        name="myonat",
    ),
    TargetClass(
        tags=[TagInfo(value="Myotis daubentonii")],
        name="myodau",
    ),
    TargetClass(
        tags=[TagInfo(value="Myotis brandtii")],
        name="myobra",
    ),
    TargetClass(
        tags=[TagInfo(value="Pipistrellus pipistrellus")],
        name="pippip",
    ),
    TargetClass(
        tags=[TagInfo(value="Myotis bechsteinii")],
        name="myobec",
    ),
    TargetClass(
        tags=[TagInfo(value="Pipistrellus pygmaeus")],
        name="pippyg",
    ),
    TargetClass(
        tags=[TagInfo(value="Rhinolophus hipposideros")],
        name="rhihip",
    ),
    TargetClass(
        tags=[TagInfo(value="Nyctalus leisleri")],
        name="nyclei",
        roi=AnchorBBoxMapperConfig(anchor="top-left"),
    ),
    TargetClass(
        tags=[TagInfo(value="Rhinolophus ferrumequinum")],
        name="rhifer",
        roi=AnchorBBoxMapperConfig(anchor="top-left"),
    ),
    TargetClass(
        tags=[TagInfo(value="Plecotus auritus")],
        name="pleaur",
    ),
    TargetClass(
        tags=[TagInfo(value="Nyctalus noctula")],
        name="nycnoc",
    ),
    TargetClass(
        tags=[TagInfo(value="Plecotus austriacus")],
        name="pleaus",
    ),
]


DEFAULT_CLASSES_CONFIG: ClassesConfig = ClassesConfig(
    classes=DEFAULT_CLASSES,
    generic_class=[TagInfo(value="Bat")],
)


DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
    filtering=FilterConfig(
        rules=[
            FilterRule(
                match_type="all",
                tags=[TagInfo(key="event", value="Echolocation")],
            ),
            FilterRule(
                match_type="exclude",
                tags=[
                    TagInfo(key="event", value="Feeding"),
                    TagInfo(key="event", value="Unknown"),
                    TagInfo(key="event", value="Not Bat"),
                ],
            ),
        ]
    ),
    classes=DEFAULT_CLASSES_CONFIG,
    classification_targets=DEFAULT_CLASSES,
    detection_target=DEFAULT_GENERIC_CLASS,
    roi=AnchorBBoxMapperConfig(),
)


def build_targets(
    config: Optional[TargetConfig] = None,
    term_registry: TermRegistry = default_term_registry,
    derivation_registry: DerivationRegistry = default_derivation_registry,
) -> Targets:
def build_targets(config: Optional[TargetConfig] = None) -> Targets:
    """Build a Targets object from a loaded TargetConfig.

    This factory function takes the unified `TargetConfig` and constructs all
@ -550,13 +340,6 @@ def build_targets(
    ----------
    config : TargetConfig
        The loaded and validated unified target configuration object.
    term_registry : TermRegistry, optional
        The TermRegistry instance to use for resolving term keys. Defaults
        to the global `batdetect2.targets.terms.term_registry`.
    derivation_registry : DerivationRegistry, optional
        The DerivationRegistry instance to use for resolving derivation
        function names. Defaults to the global
        `batdetect2.targets.transform.derivation_registry`.

    Returns
    -------
@ -577,40 +360,18 @@ def build_targets(
        lambda: config.to_yaml_string(),
    )

    filter_fn = (
        build_sound_event_filter(
            config.filtering,
            term_registry=term_registry,
        )
        if config.filtering
        else None
    )
    encode_fn = build_sound_event_encoder(
        config.classes,
        term_registry=term_registry,
    )
    decode_fn = build_sound_event_decoder(
        config.classes,
        term_registry=term_registry,
    )
    transform_fn = (
        build_transformation_from_config(
            config.transforms,
            term_registry=term_registry,
            derivation_registry=derivation_registry,
        )
        if config.transforms
        else None
    )
    filter_fn = build_sound_event_condition(config.detection_target.match_if)
    encode_fn = build_sound_event_encoder(config.classification_targets)
    decode_fn = build_sound_event_decoder(config.classification_targets)

    roi_mapper = build_roi_mapper(config.roi)
    class_names = get_class_names_from_config(config.classes)
    generic_class_tags = build_generic_class_tags(
        config.classes,
        term_registry=term_registry,
    )
    class_names = get_class_names_from_config(config.classification_targets)

    generic_class_tags = config.detection_target.assign_tags

    roi_overrides = {
        class_config.name: build_roi_mapper(class_config.roi)
        for class_config in config.classes.classes
        for class_config in config.classification_targets
        if class_config.roi is not None
    }

@ -621,7 +382,6 @@ def build_targets(
        class_names=class_names,
        roi_mapper=roi_mapper,
        generic_class_tags=generic_class_tags,
        transform_fn=transform_fn,
        roi_mapper_overrides=roi_overrides,
    )

@ -629,8 +389,6 @@ def build_targets(
def load_targets(
    config_path: data.PathLike,
    field: Optional[str] = None,
    term_registry: TermRegistry = default_term_registry,
    derivation_registry: DerivationRegistry = default_derivation_registry,
) -> Targets:
    """Load a Targets object directly from a configuration file.

@ -645,11 +403,6 @@ def load_targets(
    field : str, optional
        Dot-separated path to a nested section within the file containing
        the target configuration. If None, the entire file content is used.
    term_registry : TermRegistry, optional
        The TermRegistry instance to use. Defaults to the global default.
    derivation_registry : DerivationRegistry, optional
        The DerivationRegistry instance to use. Defaults to the global
        default.

    Returns
    -------
@ -670,11 +423,7 @@ def load_targets(
        config_path,
        field=field,
    )
    return build_targets(
        config,
        term_registry=term_registry,
        derivation_registry=derivation_registry,
    )
    return build_targets(config)


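With the registries removed from both signatures, building targets is now a one-argument call. A hedged usage sketch, assuming these names are importable from `batdetect2.targets` as this diff suggests:

```python
# Assumed public exports, per this diff.
from batdetect2.targets import DEFAULT_TARGET_CONFIG, build_targets, load_targets

# Build from the bundled defaults...
targets = build_targets(DEFAULT_TARGET_CONFIG)
print(targets.class_names)

# ...or straight from a YAML file (hypothetical path); `field` selects a
# nested section of the file if the target config is not at the top level.
targets = load_targets("targets.yaml", field="targets")
```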
def iterate_encoded_sound_events(
@ -690,8 +439,6 @@ def iterate_encoded_sound_events(
        if geometry is None:
            continue

        sound_event = targets.transform(sound_event)

        class_name = targets.encode_class(sound_event)
        position, size = targets.encode_roi(sound_event)
|
||||
|
||||
|
||||
@ -1,253 +1,172 @@
|
||||
from collections import Counter
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Literal, Optional, Set, Tuple
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic import Field, PrivateAttr, computed_field, model_validator
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.targets.rois import ROIMapperConfig
|
||||
from batdetect2.targets.terms import (
|
||||
GENERIC_CLASS_KEY,
|
||||
TagInfo,
|
||||
TermRegistry,
|
||||
default_term_registry,
|
||||
get_tag_from_info,
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.data.conditions import (
|
||||
AllOfConfig,
|
||||
HasAllTagsConfig,
|
||||
HasAnyTagConfig,
|
||||
HasTagConfig,
|
||||
NotConfig,
|
||||
SoundEventCondition,
|
||||
SoundEventConditionConfig,
|
||||
build_sound_event_condition,
|
||||
)
|
||||
from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig
|
||||
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_SPECIES_LIST",
|
||||
"build_generic_class_tags",
|
||||
"build_sound_event_decoder",
|
||||
"build_sound_event_encoder",
|
||||
"get_class_names_from_config",
|
||||
"load_classes_config",
|
||||
"load_decoder_from_config",
|
||||
"load_encoder_from_config",
|
||||
]
|
||||
|
||||
|
||||
DEFAULT_SPECIES_LIST = [
|
||||
"Barbastella barbastellus",
|
||||
"Eptesicus serotinus",
|
||||
"Myotis alcathoe",
|
||||
"Myotis bechsteinii",
|
||||
"Myotis brandtii",
|
||||
"Myotis daubentonii",
|
||||
"Myotis mystacinus",
|
||||
"Myotis nattereri",
|
||||
"Nyctalus leisleri",
|
||||
"Nyctalus noctula",
|
||||
"Pipistrellus nathusii",
|
||||
"Pipistrellus pipistrellus",
|
||||
"Pipistrellus pygmaeus",
|
||||
"Plecotus auritus",
|
||||
"Plecotus austriacus",
|
||||
"Rhinolophus ferrumequinum",
|
||||
"Rhinolophus hipposideros",
|
||||
]
|
||||
"""A default list of common bat species names found in the UK."""
|
||||
|
||||
|
||||
class TargetClass(BaseConfig):
|
||||
"""Defines criteria for encoding annotations and decoding predictions.
|
||||
|
||||
Each instance represents one potential output class for the classification
|
||||
model. It specifies:
|
||||
1. A unique `name` for the class.
|
||||
2. The tag conditions (`tags` and `match_type`) an annotation must meet to
|
||||
be assigned this class name during training data preparation (encoding).
|
||||
3. An optional, alternative set of tags (`output_tags`) to be used when
|
||||
converting a model's prediction of this class name back into annotation
|
||||
tags (decoding).
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name : str
|
||||
The unique name assigned to this target class (e.g., 'pippip',
|
||||
'myodau', 'noise'). This name is used as the label during model
|
||||
training and is the expected output from the model's prediction.
|
||||
Should be unique across all TargetClass definitions in a configuration.
|
||||
tags : List[TagInfo]
|
||||
A list of one or more tags (defined using `TagInfo`) used to identify
|
||||
if an existing annotation belongs to this class during encoding (data
|
||||
preparation for training). The `match_type` attribute determines how
|
||||
these tags are evaluated.
|
||||
match_type : Literal["all", "any"], default="all"
|
||||
Determines how the `tags` list is evaluated during encoding:
|
||||
- "all": The annotation must have *all* the tags listed to match.
|
||||
- "any": The annotation must have *at least one* of the tags listed
|
||||
to match.
|
||||
output_tags: Optional[List[TagInfo]], default=None
|
||||
An optional list of tags (defined using `TagInfo`) to be assigned to a
|
||||
new annotation when the model predicts this class `name`. If `None`
|
||||
(default), the tags listed in the `tags` field will be used for
|
||||
decoding. If provided, this list overrides the `tags` field for the
|
||||
purpose of decoding predictions back into meaningful annotation tags.
|
||||
This allows, for example, training on broader categories but decoding
|
||||
to more specific representative tags.
|
||||
"""
|
||||
class TargetClassConfig(BaseConfig):
|
||||
"""Defines a target class of sound events."""
|
||||
|
||||
name: str
|
||||
tags: List[TagInfo] = Field(min_length=1)
|
||||
match_type: Literal["all", "any"] = Field(default="all")
|
||||
output_tags: Optional[List[TagInfo]] = None
|
||||
roi: Optional[ROIMapperConfig] = None
|
||||
|
||||
|
||||
def _get_default_classes() -> List[TargetClass]:
|
||||
"""Generate a list of default target classes.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[TargetClass]
|
||||
A list of TargetClass objects, one for each species in
|
||||
DEFAULT_SPECIES_LIST. The class names are simplified versions of the
|
||||
species names.
|
||||
"""
|
||||
return [
|
||||
TargetClass(
|
||||
name=_get_default_class_name(value),
|
||||
tags=[TagInfo(key=GENERIC_CLASS_KEY, value=value)],
|
||||
)
|
||||
for value in DEFAULT_SPECIES_LIST
|
||||
]
|
||||
|
||||
|
||||
def _get_default_class_name(species: str) -> str:
|
||||
"""Generate a default class name from a species name.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
species : str
|
||||
The species name (e.g., "Myotis daubentonii").
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
A simplified class name (e.g., "myodau").
|
||||
The genus and species names are converted to lowercase,
|
||||
the first three letters of each are taken, and concatenated.
|
||||
"""
|
||||
genus, species = species.strip().split(" ")
|
||||
return f"{genus.lower()[:3]}{species.lower()[:3]}"
|
||||
|
||||
|
||||
def _get_default_generic_class() -> List[TagInfo]:
|
||||
"""Generate the default list of TagInfo objects for the generic class.
|
||||
|
||||
Provides a default set of tags used to represent the generic "Bat" category
|
||||
when decoding predictions that didn't match a specific class.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[TagInfo]
|
||||
A list containing default TagInfo objects, typically representing
|
||||
`call_type: Echolocation` and `order: Chiroptera`.
|
||||
"""
|
||||
return [
|
||||
TagInfo(key="call_type", value="Echolocation"),
|
||||
TagInfo(key="order", value="Chiroptera"),
|
||||
]
|
||||
|
||||
|
||||
class ClassesConfig(BaseConfig):
|
||||
"""Configuration defining target classes and the generic fallback category.
|
||||
|
||||
Holds the ordered list of specific target class definitions (`TargetClass`)
|
||||
and defines the tags representing the generic category for sounds that pass
|
||||
filtering but do not match any specific class.
|
||||
|
||||
The order of `TargetClass` objects in the `classes` list defines the
|
||||
priority for classification during encoding. The system checks annotations
|
||||
against these definitions sequentially and assigns the name of the *first*
|
||||
matching class.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
classes : List[TargetClass]
|
||||
An ordered list of specific target class definitions. The order
|
||||
determines matching priority (first match wins). Defaults to a
|
||||
standard set of classes via `get_default_classes`.
|
||||
generic_class : List[TagInfo]
|
||||
A list of tags defining the "generic" or "unclassified but relevant"
|
||||
category (e.g., representing a generic 'Bat' call that wasn't
|
||||
assigned to a specific species). These tags are typically assigned
|
||||
during decoding when a sound event was detected and passed filtering
|
||||
but did not match any specific class rule defined in the `classes` list.
|
||||
Defaults to a standard set of tags via `get_default_generic_class`.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If validation fails (e.g., non-unique class names in the `classes`
|
||||
list).
|
||||
|
||||
Notes
|
||||
-----
|
||||
- It is crucial that the `name` attribute of each `TargetClass` in the
|
||||
`classes` list is unique. This configuration includes a validator to
|
||||
enforce this uniqueness.
|
||||
- The `generic_class` tags provide a baseline identity for relevant sounds
|
||||
that don't fit into more specific defined categories.
|
||||
"""
|
||||
|
||||
classes: List[TargetClass] = Field(default_factory=_get_default_classes)
|
||||
|
||||
generic_class: List[TagInfo] = Field(
|
||||
default_factory=_get_default_generic_class
|
||||
condition_input: Optional[SoundEventConditionConfig] = Field(
|
||||
alias="match_if",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@field_validator("classes")
|
||||
def check_unique_class_names(cls, v: List[TargetClass]):
|
||||
"""Ensure all defined class names are unique."""
|
||||
names = [c.name for c in v]
|
||||
tags: Optional[List[data.Tag]] = Field(default=None, exclude=True)
|
||||
|
||||
if len(names) != len(set(names)):
|
||||
name_counts = Counter(names)
|
||||
duplicates = [
|
||||
name for name, count in name_counts.items() if count > 1
|
||||
]
|
||||
assign_tags: List[data.Tag] = Field(default_factory=list)
|
||||
|
||||
roi: Optional[ROIMapperConfig] = None
|
||||
|
||||
_match_if: SoundEventConditionConfig = PrivateAttr()
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _process_tags(self) -> "TargetClassConfig":
|
||||
if self.tags and self.condition_input:
|
||||
raise ValueError("Use either 'tags' or 'match_if', not both.")
|
||||
|
||||
if self.condition_input is not None:
|
||||
self._match_if = self.condition_input
|
||||
return self
|
||||
|
||||
if self.tags is None:
|
||||
raise ValueError(
|
||||
"Class names must be unique. Found duplicates: "
|
||||
f"{', '.join(duplicates)}"
|
||||
f"Class '{self.name}' must have a 'tags' or 'match_if' rule."
|
||||
)
|
||||
return v
|
||||
|
||||
self._match_if = HasAllTagsConfig(tags=self.tags)
|
||||
|
||||
if not self.assign_tags:
|
||||
self.assign_tags = self.tags
|
||||
|
||||
return self
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def match_if(self) -> SoundEventConditionConfig:
|
||||
return self._match_if
|
||||
|
||||
|
||||
def is_target_class(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
tags: Set[data.Tag],
|
||||
match_all: bool = True,
|
||||
) -> bool:
|
||||
"""Check if a sound event annotation matches a set of required tags.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to check.
|
||||
required_tags : Set[data.Tag]
|
||||
A set of `soundevent.data.Tag` objects that define the class criteria.
|
||||
match_all : bool, default=True
|
||||
If True, checks if *all* `required_tags` are present in the
|
||||
annotation's tags (subset check). If False, checks if *at least one*
|
||||
of the `required_tags` is present (intersection check).
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation meets the tag criteria, False otherwise.
|
||||
"""
|
||||
annotation_tags = set(sound_event_annotation.tags)
|
||||
|
||||
if match_all:
|
||||
return tags <= annotation_tags
|
||||
|
||||
return bool(tags & annotation_tags)
|
||||
DEFAULT_GENERIC_CLASS = TargetClassConfig(
|
||||
name="bat",
|
||||
match_if=AllOfConfig(
|
||||
conditions=[
|
||||
HasTagConfig(tag=data.Tag(key="event", value="Echolocation")),
|
||||
NotConfig(
|
||||
condition=HasAnyTagConfig(
|
||||
tags=[
|
||||
data.Tag(key="event", value="Feeding"),
|
||||
data.Tag(key="event", value="Unknown"),
|
||||
data.Tag(key="event", value="Not Bat"),
|
||||
]
|
||||
)
|
||||
),
|
||||
]
|
||||
),
|
||||
assign_tags=[
|
||||
data.Tag(key="call_type", value="Echolocation"),
|
||||
data.Tag(key="order", value="Chiroptera"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def get_class_names_from_config(config: ClassesConfig) -> List[str]:
|
||||
DEFAULT_CLASSES = [
|
||||
TargetClassConfig(
|
||||
name="myomys",
|
||||
tags=[data.Tag(key="class", value="Myotis mystacinus")],
|
||||
),
|
||||
TargetClassConfig(
|
||||
name="myoalc",
|
||||
tags=[data.Tag(key="class", value="Myotis alcathoe")],
|
||||
),
|
||||
TargetClassConfig(
|
||||
name="eptser",
|
||||
tags=[data.Tag(key="class", value="Eptesicus serotinus")],
|
||||
),
|
||||
TargetClassConfig(
|
||||
name="pipnat",
|
||||
tags=[data.Tag(key="class", value="Pipistrellus nathusii")],
|
||||
),
|
||||
TargetClassConfig(
|
||||
name="barbar",
|
||||
tags=[data.Tag(key="class", value="Barbastellus barbastellus")],
|
||||
),
|
||||
TargetClassConfig(
|
||||
name="myonat",
|
||||
tags=[data.Tag(key="class", value="Myotis nattereri")],
|
||||
),
|
||||
TargetClassConfig(
|
||||
name="myodau",
|
||||
tags=[data.Tag(key="class", value="Myotis daubentonii")],
|
||||
),
|
||||
TargetClassConfig(
|
||||
name="myobra",
|
||||
tags=[data.Tag(key="class", value="Myotis brandtii")],
|
||||
),
|
||||
TargetClassConfig(
|
||||
name="pippip",
|
||||
tags=[data.Tag(key="class", value="Pipistrellus pipistrellus")],
|
||||
),
|
||||
TargetClassConfig(
|
||||
name="myobec",
|
||||
tags=[data.Tag(key="class", value="Myotis bechsteinii")],
|
||||
),
|
||||
TargetClassConfig(
|
||||
name="pippyg",
|
||||
tags=[data.Tag(key="class", value="Pipistrellus pygmaeus")],
|
||||
),
|
||||
TargetClassConfig(
|
||||
name="rhihip",
|
||||
tags=[data.Tag(key="class", value="Rhinolophus hipposideros")],
|
||||
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
||||
),
|
||||
TargetClassConfig(
|
||||
name="nyclei",
|
||||
tags=[data.Tag(key="class", value="Nyctalus leisleri")],
|
||||
),
|
||||
TargetClassConfig(
|
||||
name="rhifer",
|
||||
tags=[data.Tag(key="class", value="Rhinolophus ferrumequinum")],
|
||||
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
||||
),
|
||||
TargetClassConfig(
|
||||
name="pleaur",
|
||||
tags=[data.Tag(key="class", value="Plecotus auritus")],
|
||||
),
|
||||
TargetClassConfig(
|
||||
name="nycnoc",
|
||||
tags=[data.Tag(key="class", value="Nyctalus noctula")],
|
||||
),
|
||||
TargetClassConfig(
|
||||
name="pleaus",
|
||||
tags=[data.Tag(key="class", value="Plecotus austriacus")],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_class_names_from_config(configs: List[TargetClassConfig]) -> List[str]:
|
||||
"""Extract the list of class names from a ClassesConfig object.
|
||||
|
||||
Parameters
|
||||
@ -260,340 +179,60 @@ def get_class_names_from_config(config: ClassesConfig) -> List[str]:
|
||||
List[str]
|
||||
An ordered list of unique class names defined in the configuration.
|
||||
"""
|
||||
return [class_info.name for class_info in config.classes]
|
||||
|
||||
|
||||
def _encode_with_multiple_classifiers(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
classifiers: List[Tuple[str, Callable[[data.SoundEventAnnotation], bool]]],
|
||||
) -> Optional[str]:
|
||||
"""Encode an annotation by checking against a list of classifiers.
|
||||
|
||||
Internal helper function used by the `SoundEventEncoder`. It iterates
|
||||
through the provided list of (class_name, classifier_function) pairs.
|
||||
Returns the name associated with the first classifier function that
|
||||
returns True for the given annotation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to encode.
|
||||
classifiers : List[Tuple[str, Callable[[data.SoundEventAnnotation], bool]]]
|
||||
An ordered list where each tuple contains a class name and a function
|
||||
that returns True if the annotation matches that class. The order
|
||||
determines priority.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str or None
|
||||
The name of the first matching class, or None if no classifier matches.
|
||||
"""
|
||||
for class_name, classifier in classifiers:
|
||||
if classifier(sound_event_annotation):
|
||||
return class_name
|
||||
|
||||
return None
|
||||
return [class_info.name for class_info in configs]
|
||||
|
||||
|
||||
def build_sound_event_encoder(
|
||||
config: ClassesConfig,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
configs: List[TargetClassConfig],
|
||||
) -> SoundEventEncoder:
|
||||
"""Build a sound event encoder function from the classes configuration.
|
||||
"""Build a sound event encoder function from the classes configuration."""
|
||||
conditions = {
|
||||
class_config.name: build_sound_event_condition(class_config.match_if)
|
||||
for class_config in configs
|
||||
}
|
||||
|
||||
The returned encoder function iterates through the class definitions in the
|
||||
order specified in the config. It assigns an annotation the name of the
|
||||
first class definition it matches.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : ClassesConfig
|
||||
The loaded and validated classes configuration object.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance used to look up term keys specified in the
|
||||
`TagInfo` objects within the configuration. Defaults to the global
|
||||
`batdetect2.targets.terms.registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventEncoder
|
||||
A callable function that takes a `SoundEventAnnotation` and returns
|
||||
an optional string representing the matched class name, or None if no
|
||||
class matches.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If a term key specified in the configuration is not found in the
|
||||
provided `term_registry`.
|
||||
"""
|
||||
binary_classifiers = [
|
||||
(
|
||||
class_info.name,
|
||||
partial(
|
||||
is_target_class,
|
||||
tags={
|
||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||
for tag_info in class_info.tags
|
||||
},
|
||||
match_all=class_info.match_type == "all",
|
||||
),
|
||||
)
|
||||
for class_info in config.classes
|
||||
]
|
||||
|
||||
return partial(
|
||||
_encode_with_multiple_classifiers,
|
||||
classifiers=binary_classifiers,
|
||||
)
|
||||
return SoundEventClassifier(conditions)
|
||||
|
||||
|
||||
def _decode_class(
|
||||
name: str,
|
||||
mapping: Dict[str, List[data.Tag]],
|
||||
raise_on_error: bool = True,
|
||||
) -> List[data.Tag]:
|
||||
"""Decode a class name into a list of representative tags using a mapping.
|
||||
class SoundEventClassifier:
|
||||
def __init__(self, mapping: Dict[str, SoundEventCondition]):
|
||||
self.mapping = mapping
|
||||
|
||||
Internal helper function used by the `SoundEventDecoder`. Looks up the
|
||||
provided class `name` in the `mapping` dictionary.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The class name to decode.
|
||||
mapping : Dict[str, List[data.Tag]]
|
||||
A dictionary mapping class names to lists of `soundevent.data.Tag`
|
||||
objects.
|
||||
raise_on_error : bool, default=True
|
||||
If True, raises a ValueError if the `name` is not found in the
|
||||
`mapping`. If False, returns an empty list if the `name` is not found.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.Tag]
|
||||
The list of tags associated with the class name, or an empty list if
|
||||
not found and `raise_on_error` is False.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `name` is not found in `mapping` and `raise_on_error` is True.
|
||||
"""
|
||||
if name not in mapping and raise_on_error:
|
||||
raise ValueError(f"Class {name} not found in mapping.")
|
||||
|
||||
if name not in mapping:
|
||||
return []
|
||||
|
||||
return mapping[name]
|
||||
def __call__(
|
||||
self, sound_event_annotation: data.SoundEventAnnotation
|
||||
) -> Optional[str]:
|
||||
for name, condition in self.mapping.items():
|
||||
if condition(sound_event_annotation):
|
||||
return name
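The encoder keeps the old first-match-wins semantics: conditions are checked in the order the classification targets were declared, and `None` is returned when nothing matches, leaving the event to the generic detection class. A hedged sketch of how the new encoder is used (import path assumed from this diff):

```python
# Assumed import path, per this diff.
from batdetect2.targets.classes import DEFAULT_CLASSES, build_sound_event_encoder

encoder = build_sound_event_encoder(DEFAULT_CLASSES)

# `annotation` is a soundevent SoundEventAnnotation; the call returns the name
# of the first matching classification target, or None for a generic bat call.
# label = encoder(annotation)  # e.g. "pippip" or None
```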
|
||||
|
||||
|
||||
def build_sound_event_decoder(
|
||||
config: ClassesConfig,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
configs: List[TargetClassConfig],
|
||||
raise_on_unmapped: bool = False,
|
||||
) -> SoundEventDecoder:
|
||||
"""Build a sound event decoder function from the classes configuration.
|
||||
|
||||
Creates a callable `SoundEventDecoder` that maps a class name string
|
||||
back to a list of representative `soundevent.data.Tag` objects based on
|
||||
the `ClassesConfig`. It uses the `output_tags` field if provided in a
|
||||
`TargetClass`, otherwise falls back to the `tags` field.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : ClassesConfig
|
||||
The loaded and validated classes configuration object.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance used to look up term keys. Defaults to the
|
||||
global `batdetect2.targets.terms.registry`.
|
||||
raise_on_unmapped : bool, default=False
|
||||
If True, the returned decoder function will raise a ValueError if asked
|
||||
to decode a class name that is not in the configuration. If False, it
|
||||
will return an empty list for unmapped names.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventDecoder
|
||||
A callable function that takes a class name string and returns a list
|
||||
of `soundevent.data.Tag` objects.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If a term key specified in the configuration (`output_tags`, `tags`, or
|
||||
`generic_class`) is not found in the provided `term_registry`.
|
||||
"""
|
||||
mapping = {}
|
||||
for class_info in config.classes:
|
||||
tags_to_use = (
|
||||
class_info.output_tags
|
||||
if class_info.output_tags is not None
|
||||
else class_info.tags
|
||||
)
|
||||
mapping[class_info.name] = [
|
||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||
for tag_info in tags_to_use
|
||||
]
|
||||
|
||||
return partial(
|
||||
_decode_class,
|
||||
mapping=mapping,
|
||||
raise_on_error=raise_on_unmapped,
|
||||
)
|
||||
"""Build a sound event decoder function from the classes configuration."""
|
||||
mapping = {
|
||||
class_config.name: class_config.assign_tags for class_config in configs
|
||||
}
|
||||
return TagDecoder(mapping, raise_on_unknown=raise_on_unmapped)
|
||||
|
||||
|
||||
def build_generic_class_tags(
|
||||
config: ClassesConfig,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> List[data.Tag]:
|
||||
"""Extract and build the list of tags for the generic class from config.
|
||||
class TagDecoder:
|
||||
def __init__(
|
||||
self,
|
||||
mapping: Dict[str, List[data.Tag]],
|
||||
raise_on_unknown: bool = True,
|
||||
):
|
||||
self.mapping = mapping
|
||||
self.raise_on_unknown = raise_on_unknown
|
||||
|
||||
Converts the list of `TagInfo` objects defined in `config.generic_class`
|
||||
into a list of `soundevent.data.Tag` objects using the term registry.
|
||||
def __call__(self, class_name: str) -> List[data.Tag]:
|
||||
tags = self.mapping.get(class_name)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : ClassesConfig
|
||||
The loaded classes configuration object.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance for term lookups. Defaults to the global
|
||||
`batdetect2.targets.terms.registry`.
|
||||
if tags is None:
|
||||
if self.raise_on_unknown:
|
||||
raise ValueError("Invalid class name")
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.Tag]
|
||||
The list of fully constructed tags representing the generic class.
|
||||
tags = []
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If a term key specified in `config.generic_class` is not found in the
|
||||
provided `term_registry`.
|
||||
"""
|
||||
return [
|
||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||
for tag_info in config.generic_class
|
||||
]
|
||||
|
||||
|
||||
def load_classes_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> ClassesConfig:
|
||||
"""Load the target classes configuration from a file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (YAML).
|
||||
field : str, optional
|
||||
If the classes configuration is nested under a specific key in the
|
||||
file, specify the key here. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ClassesConfig
|
||||
The loaded and validated classes configuration object.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
pydantic.ValidationError
|
||||
If the config file structure does not match the ClassesConfig schema
|
||||
or if class names are not unique.
|
||||
"""
|
||||
return load_config(path, schema=ClassesConfig, field=field)
|
||||
|
||||
|
||||
def load_encoder_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> SoundEventEncoder:
|
||||
"""Load a class encoder function directly from a configuration file.
|
||||
|
||||
This is a convenience function that combines loading the `ClassesConfig`
|
||||
from a file and building the final `SoundEventEncoder` function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (e.g., YAML).
|
||||
field : str, optional
|
||||
If the classes configuration is nested under a specific key in the
|
||||
file, specify the key here. Defaults to None.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance used for term lookups. Defaults to the
|
||||
global `batdetect2.targets.terms.registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventEncoder
|
||||
The final encoder function ready to classify annotations.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
pydantic.ValidationError
|
||||
If the config file structure does not match the ClassesConfig schema
|
||||
or if class names are not unique.
|
||||
KeyError
|
||||
If a term key specified in the configuration is not found in the
|
||||
provided `term_registry` during the build process.
|
||||
"""
|
||||
config = load_classes_config(path, field=field)
|
||||
return build_sound_event_encoder(config, term_registry=term_registry)
|
||||
|
||||
|
||||
def load_decoder_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
raise_on_unmapped: bool = False,
|
||||
) -> SoundEventDecoder:
|
||||
"""Load a class decoder function directly from a configuration file.
|
||||
|
||||
This is a convenience function that combines loading the `ClassesConfig`
|
||||
from a file and building the final `SoundEventDecoder` function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (e.g., YAML).
|
||||
field : str, optional
|
||||
If the classes configuration is nested under a specific key in the
|
||||
file, specify the key here. Defaults to None.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance used for term lookups. Defaults to the
|
||||
global `batdetect2.targets.terms.registry`.
|
||||
raise_on_unmapped : bool, default=False
|
||||
If True, the returned decoder function will raise a ValueError if asked
|
||||
to decode a class name that is not in the configuration. If False, it
|
||||
will return an empty list for unmapped names.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventDecoder
|
||||
The final decoder function ready to convert class names back into tags.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
pydantic.ValidationError
|
||||
If the config file structure does not match the ClassesConfig schema
|
||||
or if class names are not unique.
|
||||
KeyError
|
||||
If a term key specified in the configuration is not found in the
|
||||
provided `term_registry` during the build process.
|
||||
"""
|
||||
config = load_classes_config(path, field=field)
|
||||
return build_sound_event_decoder(
|
||||
config,
|
||||
term_registry=term_registry,
|
||||
raise_on_unmapped=raise_on_unmapped,
|
||||
)
|
||||
return tags
|
||||
|
||||
@ -1,307 +0,0 @@
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import List, Literal, Optional, Set
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.targets.terms import (
|
||||
TagInfo,
|
||||
TermRegistry,
|
||||
default_term_registry,
|
||||
get_tag_from_info,
|
||||
)
|
||||
from batdetect2.typing.targets import SoundEventFilter
|
||||
|
||||
__all__ = [
|
||||
"FilterConfig",
|
||||
"FilterRule",
|
||||
"build_sound_event_filter",
|
||||
"build_filter_from_rule",
|
||||
"load_filter_config",
|
||||
"load_filter_from_config",
|
||||
]
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FilterRule(BaseConfig):
|
||||
"""Defines a single rule for filtering sound event annotations.
|
||||
|
||||
Based on the `match_type`, this rule checks if the tags associated with a
|
||||
sound event annotation meet certain criteria relative to the `tags` list
|
||||
defined in this rule.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
match_type : Literal["any", "all", "exclude", "equal"]
|
||||
Determines how the `tags` list is used:
|
||||
- "any": Pass if the annotation has at least one tag from the list.
|
||||
- "all": Pass if the annotation has all tags from the list (it can
|
||||
have others too).
|
||||
- "exclude": Pass if the annotation has none of the tags from the list.
|
||||
- "equal": Pass if the annotation's tags are exactly the same set as
|
||||
provided in the list.
|
||||
tags : List[TagInfo]
|
||||
A list of tags (defined using TagInfo for configuration) that this
|
||||
rule operates on.
|
||||
"""
|
||||
|
||||
match_type: Literal["any", "all", "exclude", "equal"]
|
||||
tags: List[TagInfo]
|
||||
|
||||
|
||||
def has_any_tag(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
tags: Set[data.Tag],
|
||||
) -> bool:
|
||||
"""Check if the annotation has at least one of the specified tags.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to check.
|
||||
tags : Set[data.Tag]
|
||||
The set of tags to look for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation has one or more tags from the specified set,
|
||||
False otherwise.
|
||||
"""
|
||||
sound_event_tags = set(sound_event_annotation.tags)
|
||||
return bool(tags & sound_event_tags)
|
||||
|
||||
|
||||
def contains_tags(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
tags: Set[data.Tag],
|
||||
) -> bool:
|
||||
"""Check if the annotation contains all of the specified tags.
|
||||
|
||||
The annotation may have additional tags beyond those specified.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to check.
|
||||
tags : Set[data.Tag]
|
||||
The set of tags that must all be present in the annotation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation's tags are a superset of the specified tags,
|
||||
False otherwise.
|
||||
"""
|
||||
sound_event_tags = set(sound_event_annotation.tags)
|
||||
return tags <= sound_event_tags
|
||||
|
||||
|
||||
def does_not_have_tags(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
tags: Set[data.Tag],
|
||||
):
|
||||
"""Check if the annotation has none of the specified tags.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to check.
|
||||
tags : Set[data.Tag]
|
||||
The set of tags that must *not* be present in the annotation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation has zero tags in common with the specified set,
|
||||
False otherwise.
|
||||
"""
|
||||
return not has_any_tag(sound_event_annotation, tags)
|
||||
|
||||
|
||||
def equal_tags(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
tags: Set[data.Tag],
|
||||
) -> bool:
|
||||
"""Check if the annotation's tags are exactly equal to the specified set.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to check.
|
||||
tags : Set[data.Tag]
|
||||
The exact set of tags the annotation must have.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation's tags set is identical to the specified set,
|
||||
False otherwise.
|
||||
"""
|
||||
sound_event_tags = set(sound_event_annotation.tags)
|
||||
return tags == sound_event_tags
|
||||
|
||||
|
||||
def build_filter_from_rule(
|
||||
rule: FilterRule,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> SoundEventFilter:
|
||||
"""Creates a callable filter function from a single FilterRule.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rule : FilterRule
|
||||
The filter rule configuration object.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventFilter
|
||||
A function that takes a SoundEventAnnotation and returns True if it
|
||||
passes the rule, False otherwise.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the rule contains an invalid `match_type`.
|
||||
"""
|
||||
tag_set = {
|
||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||
for tag_info in rule.tags
|
||||
}
|
||||
|
||||
if rule.match_type == "any":
|
||||
return partial(has_any_tag, tags=tag_set)
|
||||
|
||||
if rule.match_type == "all":
|
||||
return partial(contains_tags, tags=tag_set)
|
||||
|
||||
if rule.match_type == "exclude":
|
||||
return partial(does_not_have_tags, tags=tag_set)
|
||||
|
||||
if rule.match_type == "equal":
|
||||
return partial(equal_tags, tags=tag_set)
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid match type {rule.match_type}. Valid types "
|
||||
"are: 'any', 'all', 'exclude' and 'equal'"
|
||||
)
|
||||
|
||||
|
||||
def _passes_all_filters(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
filters: List[SoundEventFilter],
|
||||
) -> bool:
|
||||
"""Check if the annotation passes all provided filters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to check.
|
||||
filters : List[SoundEventFilter]
|
||||
A list of filter functions to apply.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation passes all filters, False otherwise.
|
||||
"""
|
||||
for filter_fn in filters:
|
||||
if not filter_fn(sound_event_annotation):
|
||||
logging.debug(
|
||||
f"Sound event annotation {sound_event_annotation.uuid} "
|
||||
f"excluded due to rule {filter_fn}",
|
||||
)
|
||||
return False
|
||||
|
||||
return True


class FilterConfig(BaseConfig):
    """Configuration model for defining a list of filter rules.

    Attributes
    ----------
    rules : List[FilterRule]
        A list of FilterRule objects. An annotation must pass all rules in
        this list to be considered valid by the filter built from this config.
    """

    rules: List[FilterRule] = Field(default_factory=list)


def build_sound_event_filter(
    config: FilterConfig,
    term_registry: TermRegistry = default_term_registry,
) -> SoundEventFilter:
    """Builds a merged filter function from a FilterConfig object.

    Creates individual filter functions for each rule in the configuration
    and merges them using AND logic.

    Parameters
    ----------
    config : FilterConfig
        The configuration object containing the list of filter rules.

    Returns
    -------
    SoundEventFilter
        A single callable filter function that applies all defined rules.
    """
    filters = [
        build_filter_from_rule(rule, term_registry=term_registry)
        for rule in config.rules
    ]
    return partial(_passes_all_filters, filters=filters)


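# Illustrative usage sketch (not part of the library): build a filter from an
# in-memory FilterConfig and keep only the annotations that pass it. The tag
# keys and values below are assumptions chosen for demonstration, and the
# sketch assumes FilterRule and TagInfo are in scope in this module.
def _example_apply_filter(
    annotations: List[data.SoundEventAnnotation],
) -> List[data.SoundEventAnnotation]:
    config = FilterConfig(
        rules=[
            # Keep only annotations tagged as echolocation events...
            FilterRule(
                match_type="all",
                tags=[TagInfo(key="event", value="Echolocation")],
            ),
            # ...and drop anything labelled with an unknown class.
            FilterRule(
                match_type="exclude",
                tags=[TagInfo(key="class", value="Unknown")],
            ),
        ]
    )
    keep = build_sound_event_filter(config)
    return [annotation for annotation in annotations if keep(annotation)]

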
def load_filter_config(
|
||||
path: data.PathLike, field: Optional[str] = None
|
||||
) -> FilterConfig:
|
||||
"""Loads the filter configuration from a file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (YAML).
|
||||
field : Optional[str], optional
|
||||
If the filter configuration is nested under a specific key in the
|
||||
file, specify the key here. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
FilterConfig
|
||||
The loaded and validated filter configuration object.
|
||||
"""
|
||||
return load_config(path, schema=FilterConfig, field=field)
|
||||
|
||||
|
||||
def load_filter_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> SoundEventFilter:
|
||||
"""Loads filter configuration from a file and builds the filter function.
|
||||
|
||||
This is a convenience function that combines loading the configuration
|
||||
and building the final callable filter function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (YAML).
|
||||
field : Optional[str], optional
|
||||
If the filter configuration is nested under a specific key in the
|
||||
file, specify the key here. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventFilter
|
||||
The final merged filter function ready to be used.
|
||||
"""
|
||||
config = load_filter_config(path=path, field=field)
|
||||
return build_sound_event_filter(config, term_registry=term_registry)
|
||||
@ -20,7 +20,7 @@ selecting and configuring the desired mapper. This module separates the
|
||||
*geometric* aspect of target definition from *semantic* classification.
|
||||
"""
|
||||
|
||||
from typing import Annotated, List, Literal, Optional, Protocol, Tuple, Union
|
||||
from typing import Annotated, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
@ -30,7 +30,7 @@ from batdetect2.configs import BaseConfig
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.preprocess.audio import build_audio_loader
|
||||
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
||||
from batdetect2.typing.targets import Position, Size
|
||||
from batdetect2.typing.targets import Position, ROITargetMapper, Size
|
||||
from batdetect2.utils.arrays import spec_to_xarray
|
||||
|
||||
__all__ = [
|
||||
@ -83,73 +83,6 @@ DEFAULT_ANCHOR = "bottom-left"
|
||||
"""Default reference position within the geometry ('bottom-left' corner)."""
|
||||
|
||||
|
||||
class ROITargetMapper(Protocol):
|
||||
"""Protocol defining the interface for ROI-to-target mapping.
|
||||
|
||||
Specifies the `encode` and `decode` methods required for converting a
|
||||
`soundevent.data.SoundEvent` into a target representation (a reference
|
||||
position and a size vector) and for recovering an approximate ROI from that
|
||||
representation.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
dimension_names : List[str]
|
||||
A list containing the names of the dimensions in the `Size` array
|
||||
returned by `encode` and expected by `decode`.
|
||||
"""
|
||||
|
||||
dimension_names: List[str]
|
||||
|
||||
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
|
||||
"""Encode a SoundEvent's geometry into a position and size.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEvent
|
||||
The input sound event, which must have a geometry attribute.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[Position, Size]
|
||||
A tuple containing:
|
||||
- The reference position as (time, frequency) coordinates.
|
||||
- A NumPy array with the calculated size dimensions.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the sound event does not have a geometry.
|
||||
"""
|
||||
...
|
||||
|
||||
def decode(self, position: Position, size: Size) -> data.Geometry:
|
||||
"""Decode a position and size back into a geometric ROI.
|
||||
|
||||
Performs the inverse mapping: takes a reference position and size
|
||||
dimensions and reconstructs a geometric representation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
position : Position
|
||||
The reference position (time, frequency).
|
||||
size : Size
|
||||
NumPy array containing the size dimensions, matching the order
|
||||
and meaning specified by `dimension_names`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
soundevent.data.Geometry
|
||||
The reconstructed geometry, typically a `BoundingBox`.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the `size` array has an unexpected shape or if reconstruction
|
||||
fails.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class AnchorBBoxMapperConfig(BaseConfig):
|
||||
"""Configuration for `AnchorBBoxMapper`.
|
||||
|
||||
@ -475,7 +408,10 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
|
||||
|
||||
|
||||
ROIMapperConfig = Annotated[
|
||||
Union[AnchorBBoxMapperConfig, PeakEnergyBBoxMapperConfig],
|
||||
Union[
|
||||
AnchorBBoxMapperConfig,
|
||||
PeakEnergyBBoxMapperConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""A discriminated union of all supported ROI mapper configurations.
|
||||
@ -553,7 +489,7 @@ def _build_bounding_box(
|
||||
) -> data.BoundingBox:
|
||||
"""Construct a BoundingBox from a reference point, size, and position type.
|
||||
|
||||
Internal helper for `BBoxEncoder.recover_roi`. Calculates the box
|
||||
Internal helper for `BBoxEncoder.decode`. Calculates the box
|
||||
coordinates [start_time, low_freq, end_time, high_freq] based on where
|
||||
the input `pos` (time, freq) is located relative to the box (e.g.,
|
||||
center, corner).
|
||||
|
||||
@ -1,34 +1,11 @@
|
||||
"""Manages the vocabulary (Terms and Tags) for defining training targets.
|
||||
"""Manages the vocabulary for defining training targets."""
|
||||
|
||||
This module provides the necessary tools to declare, register, and manage the
|
||||
set of `soundevent.data.Term` objects used throughout the `batdetect2.targets`
|
||||
sub-package. It establishes a consistent vocabulary for filtering,
|
||||
transforming, and classifying sound events based on their annotations (Tags).
|
||||
|
||||
The core component is the `TermRegistry`, which maps unique string keys
|
||||
(aliases) to specific `Term` definitions. This allows users to refer to complex
|
||||
terms using simple, consistent keys in configuration files and code.
|
||||
|
||||
Terms can be pre-defined, loaded from the `soundevent.terms` library, defined
|
||||
programmatically, or loaded from external configuration files (e.g., YAML).
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from inspect import getmembers
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from soundevent import data, terms
|
||||
|
||||
from batdetect2.configs import load_config
|
||||
|
||||
__all__ = [
|
||||
"call_type",
|
||||
"individual",
|
||||
"data_source",
|
||||
"get_tag_from_info",
|
||||
"TermInfo",
|
||||
"TagInfo",
|
||||
]
|
||||
|
||||
# The default key used to reference the 'generic_class' term.
|
||||
@ -96,430 +73,3 @@ terms.register_term_set(
|
||||
),
|
||||
override_existing=True,
|
||||
)
|
||||
|
||||
|
||||
class TermRegistry(Mapping[str, data.Term]):
|
||||
"""Manages a registry mapping unique keys to Term definitions.
|
||||
|
||||
This class acts as the central repository for the vocabulary of terms
|
||||
used within the target definition process. It allows registering terms
|
||||
with simple string keys and retrieving them consistently.
|
||||
"""
|
||||
|
||||
def __init__(self, terms: Optional[Dict[str, data.Term]] = None):
|
||||
"""Initializes the TermRegistry.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
terms : dict[str, soundevent.data.Term], optional
|
||||
An optional dictionary of initial key-to-Term mappings
|
||||
to populate the registry with. Defaults to an empty registry.
|
||||
"""
|
||||
self._terms: Dict[str, data.Term] = terms or {}
|
||||
|
||||
def __getitem__(self, key: str) -> data.Term:
|
||||
return self._terms[key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._terms)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._terms)
|
||||
|
||||
def add_term(self, key: str, term: data.Term) -> None:
|
||||
"""Adds a Term object to the registry with the specified key.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The unique string key to associate with the term.
|
||||
term : soundevent.data.Term
|
||||
The soundevent.data.Term object to register.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If a term with the provided key already exists in the
|
||||
registry.
|
||||
"""
|
||||
if key in self._terms:
|
||||
raise KeyError("A term with the provided key already exists.")
|
||||
|
||||
self._terms[key] = term
|
||||
|
||||
def get_term(self, key: str) -> data.Term:
|
||||
"""Retrieves a registered term by its unique key.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The unique string key of the term to retrieve.
|
||||
|
||||
Returns
|
||||
-------
|
||||
soundevent.data.Term
|
||||
The corresponding soundevent.data.Term object.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If no term with the specified key is found, with a
|
||||
helpful message suggesting listing available keys.
|
||||
"""
|
||||
try:
|
||||
return self._terms[key]
|
||||
except KeyError as err:
|
||||
raise KeyError(
|
||||
"No term found for key "
|
||||
f"'{key}'. Ensure it is registered or loaded. "
|
||||
f"Available keys: {', '.join(self.get_keys())}"
|
||||
) from err
|
||||
|
||||
def add_custom_term(
|
||||
self,
|
||||
key: str,
|
||||
name: Optional[str] = None,
|
||||
uri: Optional[str] = None,
|
||||
label: Optional[str] = None,
|
||||
definition: Optional[str] = None,
|
||||
) -> data.Term:
|
||||
"""Creates a new Term from attributes and adds it to the registry.
|
||||
|
||||
This is useful for defining terms directly in code or when loading
|
||||
from configuration files where only attributes are provided.
|
||||
|
||||
If optional fields (`name`, `label`, `definition`) are not provided,
|
||||
reasonable defaults are used (`key` for name/label, "Unknown" for
|
||||
definition).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The unique string key for the new term.
|
||||
name : str, optional
|
||||
The name for the new term (defaults to `key`).
|
||||
uri : str, optional
|
||||
The URI for the new term (optional).
|
||||
label : str, optional
|
||||
The display label for the new term (defaults to `key`).
|
||||
definition : str, optional
|
||||
The definition for the new term (defaults to "Unknown").
|
||||
|
||||
Returns
|
||||
-------
|
||||
soundevent.data.Term
|
||||
The newly created and registered soundevent.data.Term object.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If a term with the provided key already exists.
|
||||
"""
|
||||
term = data.Term(
|
||||
name=name or key,
|
||||
label=label or key,
|
||||
uri=uri,
|
||||
definition=definition or "Unknown",
|
||||
)
|
||||
self.add_term(key, term)
|
||||
return term
|
||||
|
||||
def get_keys(self) -> List[str]:
|
||||
"""Returns a list of all keys currently registered.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[str]
|
||||
A list of strings representing the keys of all registered terms.
|
||||
"""
|
||||
return list(self._terms.keys())
|
||||
|
||||
def get_terms(self) -> List[data.Term]:
|
||||
"""Returns a list of all registered terms.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[soundevent.data.Term]
|
||||
A list containing all registered Term objects.
|
||||
"""
|
||||
return list(self._terms.values())
|
||||
|
||||
def remove_key(self, key: str) -> None:
|
||||
del self._terms[key]
|
||||
|
||||
|
||||
default_term_registry = TermRegistry(
|
||||
terms=dict(
|
||||
[
|
||||
*getmembers(terms, lambda x: isinstance(x, data.Term)),
|
||||
("event", call_type),
|
||||
("species", terms.scientific_name),
|
||||
("individual", individual),
|
||||
("data_source", data_source),
|
||||
(GENERIC_CLASS_KEY, generic_class),
|
||||
]
|
||||
)
|
||||
)
|
||||
"""The default, globally accessible TermRegistry instance.
|
||||
|
||||
It is pre-populated with standard terms from `soundevent.terms` and common
|
||||
terms defined in this module (`call_type`, `individual`, `generic_class`).
|
||||
Functions in this module use this registry by default unless another instance
|
||||
is explicitly passed.
|
||||
"""
|


def get_term_from_key(
    key: str,
    term_registry: Optional[TermRegistry] = None,
) -> data.Term:
    """Convenience function to retrieve a term by key.

    The key is first looked up in the `soundevent.terms` library; if no term
    is found there, the lookup falls back to the provided registry. Uses the
    global `default_term_registry` unless a specific `term_registry` instance
    is provided.

    Parameters
    ----------
    key : str
        The unique key of the term to retrieve.
    term_registry : TermRegistry, optional
        The TermRegistry instance to search in. Defaults to the global
        `default_term_registry`.

    Returns
    -------
    soundevent.data.Term
        The corresponding soundevent.data.Term object.

    Raises
    ------
    KeyError
        If the key is not found in the specified registry.
    """
    term = terms.get_term(key)

    if term:
        return term

    term_registry = term_registry or default_term_registry
    return term_registry.get_term(key)


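# Illustrative sketch (not part of the library): registering a custom term and
# retrieving it through its key. The key and definition are made-up examples.
def _example_register_custom_term() -> data.Term:
    registry = TermRegistry()
    registry.add_custom_term(
        "recording_site",
        label="Recording Site",
        definition="Identifier of the site where the recording was made.",
    )
    # The key can now be used wherever a term key is expected.
    return registry.get_term("recording_site")

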
def get_term_keys(
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> List[str]:
|
||||
"""Convenience function to get all registered keys from a registry.
|
||||
|
||||
Uses the global default registry unless a specific `term_registry`
|
||||
instance is provided.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance to query. Defaults to the global `default_term_registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[str]
|
||||
A list of strings representing the keys of all registered terms.
|
||||
"""
|
||||
return term_registry.get_keys()
|
||||
|
||||
|
||||
def get_terms(
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> List[data.Term]:
|
||||
"""Convenience function to get all registered terms from a registry.
|
||||
|
||||
Uses the global default registry unless a specific `term_registry`
|
||||
instance is provided.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance to query. Defaults to the global `registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[soundevent.data.Term]
|
||||
A list containing all registered Term objects.
|
||||
"""
|
||||
return term_registry.get_terms()
|
||||
|
||||
|
||||
class TagInfo(BaseModel):
|
||||
"""Represents information needed to define a specific Tag.
|
||||
|
||||
This model is typically used in configuration files (e.g., YAML) to
|
||||
specify tags used for filtering, target class definition, or associating
|
||||
tags with output classes. It links a tag value to a term definition
|
||||
via the term's registry key.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
value : str
|
||||
The value of the tag (e.g., "Myotis myotis", "Echolocation").
|
||||
key : str, default="class"
|
||||
The key (alias) of the term associated with this tag, as
|
||||
registered in the TermRegistry. Defaults to "class", implying
|
||||
it represents a classification target label by default.
|
||||
"""
|
||||
|
||||
value: str
|
||||
key: str = GENERIC_CLASS_KEY


def get_tag_from_info(
    tag_info: TagInfo,
    term_registry: Optional[TermRegistry] = None,
) -> data.Tag:
    """Creates a soundevent.data.Tag object from TagInfo data.

    Looks up the term using the key in the provided `tag_info` from the
    specified registry and constructs a Tag object.

    Parameters
    ----------
    tag_info : TagInfo
        The TagInfo object containing the value and term key.
    term_registry : TermRegistry, optional
        The TermRegistry instance to use for term lookup. Defaults to the
        global `default_term_registry`.

    Returns
    -------
    soundevent.data.Tag
        A soundevent.data.Tag object corresponding to the input info.

    Raises
    ------
    KeyError
        If the term key specified in `tag_info.key` is not found
        in the registry.
    """
    term_registry = term_registry or default_term_registry
    term = get_term_from_key(tag_info.key, term_registry=term_registry)
    return data.Tag(term=term, value=tag_info.value)


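# Illustrative sketch (not part of the library): turning a configuration-style
# TagInfo into a concrete soundevent Tag. The species value is an example only.
def _example_tag_from_info() -> data.Tag:
    # `key` defaults to the generic class key, so this resolves to the
    # classification term with the value "Pipistrellus pipistrellus".
    info = TagInfo(value="Pipistrellus pipistrellus")
    return get_tag_from_info(info)

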
class TermInfo(BaseModel):
|
||||
"""Represents the definition of a Term within a configuration file.
|
||||
|
||||
This model allows users to define custom terms directly in configuration
|
||||
files (e.g., YAML) which can then be loaded into the TermRegistry.
|
||||
It mirrors the parameters of `TermRegistry.add_custom_term`.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
key : str
|
||||
The unique key (alias) that will be used to register and
|
||||
reference this term.
|
||||
label : str, optional
|
||||
The optional display label for the term. Defaults to `key`
|
||||
if not provided during registration.
|
||||
name : str, optional
|
||||
The optional formal name for the term. Defaults to `key`
|
||||
if not provided during registration.
|
||||
uri : str, optional
|
||||
The optional URI identifying the term (e.g., from a standard
|
||||
vocabulary).
|
||||
definition : str, optional
|
||||
The optional textual definition of the term. Defaults to
|
||||
"Unknown" if not provided during registration.
|
||||
"""
|
||||
|
||||
key: str
|
||||
label: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
uri: Optional[str] = None
|
||||
definition: Optional[str] = None
|
||||
|
||||
|
||||
class TermConfig(BaseModel):
|
||||
"""Pydantic schema for loading a list of term definitions from config.
|
||||
|
||||
This model typically corresponds to a section in a configuration file
|
||||
(e.g., YAML) containing a list of term definitions to be registered.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
terms : list[TermInfo]
|
||||
A list of TermInfo objects, each defining a term to be
|
||||
registered. Defaults to an empty list.
|
||||
|
||||
Examples
|
||||
--------
|
||||
Example YAML structure:
|
||||
|
||||
```yaml
|
||||
terms:
|
||||
- key: species
|
||||
uri: dwc:scientificName
|
||||
label: Scientific Name
|
||||
- key: my_custom_term
|
||||
name: My Custom Term
|
||||
definition: Describes a specific project attribute.
|
||||
# ... more TermInfo definitions
|
||||
```
|
||||
"""
|
||||
|
||||
terms: List[TermInfo] = Field(default_factory=list)
|
||||
|
||||
|
||||
def load_terms_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> Dict[str, data.Term]:
|
||||
"""Loads term definitions from a configuration file and registers them.
|
||||
|
||||
Parses a configuration file (e.g., YAML) using the TermConfig schema,
|
||||
extracts the list of TermInfo definitions, and adds each one as a
|
||||
custom term to the specified TermRegistry instance.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
The path to the configuration file.
|
||||
field : str, optional
|
||||
Optional key indicating a specific section within the config
|
||||
file where the 'terms' list is located. If None, expects the
|
||||
list directly at the top level or within a structure matching
|
||||
TermConfig schema.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance to add the loaded terms to. Defaults to
|
||||
the global `registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, soundevent.data.Term]
|
||||
A dictionary mapping the keys of the newly added terms to their
|
||||
corresponding Term objects.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
pydantic.ValidationError
|
||||
If the config file structure does not match the TermConfig schema.
|
||||
KeyError
|
||||
If a term key loaded from the config conflicts with a key
|
||||
already present in the registry.
|
||||
"""
|
||||
data = load_config(path, schema=TermConfig, field=field)
|
||||
return {
|
||||
info.key: term_registry.add_custom_term(
|
||||
info.key,
|
||||
name=info.name,
|
||||
uri=info.uri,
|
||||
label=info.label,
|
||||
definition=info.definition,
|
||||
)
|
||||
for info in data.terms
|
||||
}
|
||||
|
||||
|
||||
def register_term(
|
||||
key: str, term: data.Term, registry: TermRegistry = default_term_registry
|
||||
) -> None:
|
||||
registry.add_term(key, term)
|
||||
|
||||
@ -1,708 +0,0 @@
|
||||
import importlib
|
||||
from functools import partial
|
||||
from typing import (
|
||||
Annotated,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.targets.terms import (
|
||||
TagInfo,
|
||||
TermRegistry,
|
||||
get_tag_from_info,
|
||||
get_term_from_key,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DerivationRegistry",
|
||||
"DeriveTagRule",
|
||||
"MapValueRule",
|
||||
"ReplaceRule",
|
||||
"SoundEventTransformation",
|
||||
"TransformConfig",
|
||||
"build_transform_from_rule",
|
||||
"build_transformation_from_config",
|
||||
"default_derivation_registry",
|
||||
"get_derivation",
|
||||
"load_transformation_config",
|
||||
"load_transformation_from_config",
|
||||
"register_derivation",
|
||||
]
|
||||
|
||||
SoundEventTransformation = Callable[
|
||||
[data.SoundEventAnnotation], data.SoundEventAnnotation
|
||||
]
|
||||
"""Type alias for a sound event transformation function.
|
||||
|
||||
A function that accepts a sound event annotation object and returns a
|
||||
(potentially) modified sound event annotation object. Transformations
|
||||
should generally return a copy of the annotation rather than modifying
|
||||
it in place.
|
||||
"""
|
||||
|
||||
|
||||
Derivation = Callable[[str], str]
|
||||
"""Type alias for a derivation function.
|
||||
|
||||
A function that accepts a single string (typically a tag value) and returns
|
||||
a new string (the derived value).
|
||||
"""
|
||||
|
||||
|
||||
class MapValueRule(BaseConfig):
|
||||
"""Configuration for mapping specific values of a source term.
|
||||
|
||||
This rule replaces tags matching a specific term and one of the
|
||||
original values with a new tag (potentially having a different term)
|
||||
containing the corresponding replacement value. Useful for standardizing
|
||||
or grouping tag values.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
rule_type : Literal["map_value"]
|
||||
Discriminator field identifying this rule type.
|
||||
source_term_key : str
|
||||
The key (registered in `TermRegistry`) of the term whose tags' values
|
||||
should be checked against the `value_mapping`.
|
||||
value_mapping : Dict[str, str]
|
||||
A dictionary mapping original string values to replacement string
|
||||
values. Only tags whose value is a key in this dictionary will be
|
||||
affected.
|
||||
target_term_key : str, optional
|
||||
The key (registered in `TermRegistry`) for the term of the *output*
|
||||
tag. If None (default), the output tag uses the same term as the
|
||||
source (`source_term_key`). If provided, the term of the affected
|
||||
tag is changed to this target term upon replacement.
|
||||
"""
|
||||
|
||||
rule_type: Literal["map_value"] = "map_value"
|
||||
source_term_key: str
|
||||
value_mapping: Dict[str, str]
|
||||
target_term_key: Optional[str] = None
|
||||
|
||||
|
||||
def map_value_transform(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
source_term: data.Term,
|
||||
target_term: data.Term,
|
||||
mapping: Dict[str, str],
|
||||
) -> data.SoundEventAnnotation:
|
||||
"""Apply a value mapping transformation to an annotation's tags.
|
||||
|
||||
Iterates through the annotation's tags. If a tag matches the `source_term`
|
||||
and its value is found in the `mapping`, it is replaced by a new tag with
|
||||
the `target_term` and the mapped value. Other tags are kept unchanged.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to transform.
|
||||
source_term : data.Term
|
||||
The term of tags whose values should be mapped.
|
||||
target_term : data.Term
|
||||
The term to use for the newly created tags after mapping.
|
||||
mapping : Dict[str, str]
|
||||
The dictionary mapping original values to new values.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.SoundEventAnnotation
|
||||
A new annotation object with the transformed tags.
|
||||
"""
|
||||
tags = []
|
||||
|
||||
for tag in sound_event_annotation.tags:
|
||||
if tag.term != source_term or tag.value not in mapping:
|
||||
tags.append(tag)
|
||||
continue
|
||||
|
||||
new_value = mapping[tag.value]
|
||||
tags.append(data.Tag(term=target_term, value=new_value))
|
||||
|
||||
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
||||
|
||||
|
||||
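# Illustrative sketch (not part of the library): using map_value_transform
# directly to standardise tag values. It assumes the classification term is
# registered under the key "class"; the mapped values are examples only.
def _example_map_values(
    annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
    class_term = get_term_from_key("class")
    return map_value_transform(
        annotation,
        source_term=class_term,
        target_term=class_term,
        mapping={
            # Collapse spelling variants onto a single canonical value.
            "Pipistrellus pip": "Pipistrellus pipistrellus",
            "P. pipistrellus": "Pipistrellus pipistrellus",
        },
    )

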
class DeriveTagRule(BaseConfig):
|
||||
"""Configuration for deriving a new tag from an existing tag's value.
|
||||
|
||||
This rule applies a specified function (`derivation_function`) to the
|
||||
value of tags matching the `source_term_key`. It then adds a new tag
|
||||
with the `target_term_key` and the derived value.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
rule_type : Literal["derive_tag"]
|
||||
Discriminator field identifying this rule type.
|
||||
source_term_key : str
|
||||
The key (registered in `TermRegistry`) of the term whose tag values
|
||||
will be used as input to the derivation function.
|
||||
derivation_function : str
|
||||
The name/key identifying the derivation function to use. This can be
|
||||
a key registered in the `DerivationRegistry` or, if
|
||||
`import_derivation` is True, a full Python path like
|
||||
`'my_module.my_submodule.my_function'`.
|
||||
target_term_key : str, optional
|
||||
The key (registered in `TermRegistry`) for the term of the new tag
|
||||
that will be created with the derived value. If None (default), the
|
||||
derived tag uses the same term as the source (`source_term_key`),
|
||||
effectively performing an in-place value transformation.
|
||||
import_derivation : bool, default=False
|
||||
If True, treat `derivation_function` as a Python import path and
|
||||
attempt to dynamically import it if not found in the registry.
|
||||
Requires the function to be accessible in the Python environment.
|
||||
keep_source : bool, default=True
|
||||
If True, the original source tag (whose value was used for derivation)
|
||||
is kept in the annotation's tag list alongside the newly derived tag.
|
||||
If False, the original source tag is removed.
|
||||
"""
|
||||
|
||||
rule_type: Literal["derive_tag"] = "derive_tag"
|
||||
source_term_key: str
|
||||
derivation_function: str
|
||||
target_term_key: Optional[str] = None
|
||||
import_derivation: bool = False
|
||||
keep_source: bool = True
|
||||
|
||||
|
||||
def derivation_tag_transform(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
source_term: data.Term,
|
||||
target_term: data.Term,
|
||||
derivation: Derivation,
|
||||
keep_source: bool = True,
|
||||
) -> data.SoundEventAnnotation:
|
||||
"""Apply a derivation transformation to an annotation's tags.
|
||||
|
||||
Iterates through the annotation's tags. For each tag matching the
|
||||
`source_term`, its value is passed to the `derivation` function.
|
||||
A new tag is created with the `target_term` and the derived value,
|
||||
and added to the output tag list. The original source tag is kept
|
||||
or discarded based on `keep_source`. Other tags are kept unchanged.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to transform.
|
||||
source_term : data.Term
|
||||
The term of tags whose values serve as input for the derivation.
|
||||
target_term : data.Term
|
||||
The term to use for the newly created derived tags.
|
||||
derivation : Derivation
|
||||
The function to apply to the source tag's value.
|
||||
keep_source : bool, default=True
|
||||
Whether to keep the original source tag in the output.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.SoundEventAnnotation
|
||||
A new annotation object with the transformed tags (including derived
|
||||
ones).
|
||||
"""
|
||||
tags = []
|
||||
|
||||
for tag in sound_event_annotation.tags:
|
||||
if tag.term != source_term:
|
||||
tags.append(tag)
|
||||
continue
|
||||
|
||||
if keep_source:
|
||||
tags.append(tag)
|
||||
|
||||
new_value = derivation(tag.value)
|
||||
tags.append(data.Tag(term=target_term, value=new_value))
|
||||
|
||||
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
||||
|
||||
|
||||
class ReplaceRule(BaseConfig):
|
||||
"""Configuration for exactly replacing one specific tag with another.
|
||||
|
||||
This rule looks for an exact match of the `original` tag (both term and
|
||||
value) and replaces it with the specified `replacement` tag.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
rule_type : Literal["replace"]
|
||||
Discriminator field identifying this rule type.
|
||||
original : TagInfo
|
||||
The exact tag to search for, defined using its value and term key.
|
||||
replacement : TagInfo
|
||||
The tag to substitute in place of the original tag, defined using
|
||||
its value and term key.
|
||||
"""
|
||||
|
||||
rule_type: Literal["replace"] = "replace"
|
||||
original: TagInfo
|
||||
replacement: TagInfo
|
||||
|
||||
|
||||
def replace_tag_transform(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
source: data.Tag,
|
||||
target: data.Tag,
|
||||
) -> data.SoundEventAnnotation:
|
||||
"""Apply an exact tag replacement transformation.
|
||||
|
||||
Iterates through the annotation's tags. If a tag exactly matches the
|
||||
`source` tag, it is replaced by the `target` tag. Other tags are kept
|
||||
unchanged.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to transform.
|
||||
source : data.Tag
|
||||
The exact tag to find and replace.
|
||||
target : data.Tag
|
||||
The tag to replace the source tag with.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.SoundEventAnnotation
|
||||
A new annotation object with the replaced tag (if found).
|
||||
"""
|
||||
tags = []
|
||||
|
||||
for tag in sound_event_annotation.tags:
|
||||
if tag == source:
|
||||
tags.append(target)
|
||||
else:
|
||||
tags.append(tag)
|
||||
|
||||
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
||||
|
||||
|
||||
class TransformConfig(BaseConfig):
|
||||
"""Configuration model for defining a sequence of transformation rules.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
rules : List[Union[ReplaceRule, MapValueRule, DeriveTagRule]]
|
||||
A list of transformation rules to apply. The rules are applied
|
||||
sequentially in the order they appear in the list. The output of
|
||||
one rule becomes the input for the next. The `rule_type` field
|
||||
discriminates between the different rule models.
|
||||
"""
|
||||
|
||||
rules: List[
|
||||
Annotated[
|
||||
Union[ReplaceRule, MapValueRule, DeriveTagRule],
|
||||
Field(discriminator="rule_type"),
|
||||
]
|
||||
] = Field(
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
|
||||
class DerivationRegistry(Mapping[str, Derivation]):
|
||||
"""A registry for managing named derivation functions.
|
||||
|
||||
Derivation functions are callables that take a string value and return
|
||||
a transformed string value, used by `DeriveTagRule`. This registry
|
||||
allows functions to be registered with a key and retrieved later.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize an empty DerivationRegistry."""
|
||||
self._derivations: Dict[str, Derivation] = {}
|
||||
|
||||
def __getitem__(self, key: str) -> Derivation:
|
||||
"""Retrieve a derivation function by key."""
|
||||
return self._derivations[key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of registered derivation functions."""
|
||||
return len(self._derivations)
|
||||
|
||||
def __iter__(self):
|
||||
"""Return an iterator over the keys of registered functions."""
|
||||
return iter(self._derivations)
|
||||
|
||||
def register(self, key: str, derivation: Derivation) -> None:
|
||||
"""Register a derivation function with a unique key.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The unique key to associate with the derivation function.
|
||||
derivation : Derivation
|
||||
The callable derivation function (takes str, returns str).
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If a derivation function with the same key is already registered.
|
||||
"""
|
||||
if key in self._derivations:
|
||||
raise KeyError(
|
||||
f"A derivation with the provided key {key} already exists"
|
||||
)
|
||||
|
||||
self._derivations[key] = derivation
|
||||
|
||||
def get_derivation(self, key: str) -> Derivation:
|
||||
"""Retrieve a derivation function by its registered key.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The key of the derivation function to retrieve.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Derivation
|
||||
The requested derivation function.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If no derivation function with the specified key is registered.
|
||||
"""
|
||||
try:
|
||||
return self._derivations[key]
|
||||
except KeyError as err:
|
||||
raise KeyError(
|
||||
f"No derivation with key {key} is registered."
|
||||
) from err
|
||||
|
||||
def get_keys(self) -> List[str]:
|
||||
"""Get a list of all registered derivation function keys.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[str]
|
||||
The keys of all registered functions.
|
||||
"""
|
||||
return list(self._derivations.keys())
|
||||
|
||||
def get_derivations(self) -> List[Derivation]:
|
||||
"""Get a list of all registered derivation functions.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[Derivation]
|
||||
The registered derivation function objects.
|
||||
"""
|
||||
return list(self._derivations.values())
|
||||
|
||||
|
||||
default_derivation_registry = DerivationRegistry()
|
||||
"""Global instance of the DerivationRegistry.
|
||||
|
||||
Register custom derivation functions here to make them available by key
|
||||
in `DeriveTagRule` configuration.
|
||||
"""
|
||||
|
||||
|
||||
def get_derivation(
|
||||
key: str,
|
||||
import_derivation: bool = False,
|
||||
registry: Optional[DerivationRegistry] = None,
|
||||
):
|
||||
"""Retrieve a derivation function by key, optionally importing it.
|
||||
|
||||
First attempts to find the function in the provided `registry`.
|
||||
If not found and `import_derivation` is True, attempts to dynamically
|
||||
import the function using the `key` as a full Python path
|
||||
(e.g., 'my_module.submodule.my_func').
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The key or Python path of the derivation function.
|
||||
import_derivation : bool, default=False
|
||||
If True, attempt dynamic import if key is not in the registry.
|
||||
registry : DerivationRegistry, optional
|
||||
The registry instance to check first. Defaults to the global
|
||||
`default_derivation_registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Derivation
|
||||
The requested derivation function.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If the key is not found in the registry and either
|
||||
`import_derivation` is False or the dynamic import fails.
|
||||
ImportError
|
||||
If dynamic import fails specifically due to module not found.
|
||||
AttributeError
|
||||
If dynamic import fails because the function name isn't in the module.
|
||||
"""
|
||||
registry = registry or default_derivation_registry
|
||||
|
||||
if not import_derivation or key in registry:
|
||||
return registry.get_derivation(key)
|
||||
|
||||
try:
|
||||
module_path, func_name = key.rsplit(".", 1)
|
||||
module = importlib.import_module(module_path)
|
||||
func = getattr(module, func_name)
|
||||
return func
|
||||
except ImportError as err:
|
||||
raise KeyError(
|
||||
f"Unable to load derivation '{key}'. Check the path and ensure "
|
||||
"it points to a valid callable function in an importable module."
|
||||
) from err
|
||||
|
||||
|
||||
TransformationRule = Annotated[
|
||||
Union[ReplaceRule, MapValueRule, DeriveTagRule],
|
||||
Field(discriminator="rule_type"),
|
||||
]
|
||||
|
||||
|
||||
def build_transform_from_rule(
|
||||
rule: TransformationRule,
|
||||
derivation_registry: Optional[DerivationRegistry] = None,
|
||||
term_registry: Optional[TermRegistry] = None,
|
||||
) -> SoundEventTransformation:
|
||||
"""Build a specific SoundEventTransformation function from a rule config.
|
||||
|
||||
Selects the appropriate transformation logic based on the rule's
|
||||
`rule_type`, fetches necessary terms and derivation functions, and
|
||||
returns a partially applied function ready to transform an annotation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rule : Union[ReplaceRule, MapValueRule, DeriveTagRule]
|
||||
The configuration object for a single transformation rule.
|
||||
derivation_registry : DerivationRegistry, optional
The derivation registry to use for `DeriveTagRule`. Defaults to the
global `default_derivation_registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventTransformation
|
||||
A callable that applies the specified rule to a SoundEventAnnotation.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If required term keys or derivation keys are not found.
|
||||
ValueError
|
||||
If the rule has an unknown `rule_type`.
|
||||
ImportError, AttributeError, TypeError
|
||||
If dynamic import of a derivation function fails.
|
||||
"""
|
||||
if rule.rule_type == "replace":
|
||||
source = get_tag_from_info(
|
||||
rule.original,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
target = get_tag_from_info(
|
||||
rule.replacement,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
return partial(replace_tag_transform, source=source, target=target)
|
||||
|
||||
if rule.rule_type == "derive_tag":
|
||||
source_term = get_term_from_key(
|
||||
rule.source_term_key,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
target_term = (
|
||||
get_term_from_key(
|
||||
rule.target_term_key,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
if rule.target_term_key
|
||||
else source_term
|
||||
)
|
||||
derivation = get_derivation(
|
||||
key=rule.derivation_function,
|
||||
import_derivation=rule.import_derivation,
|
||||
registry=derivation_registry,
|
||||
)
|
||||
return partial(
|
||||
derivation_tag_transform,
|
||||
source_term=source_term,
|
||||
target_term=target_term,
|
||||
derivation=derivation,
|
||||
keep_source=rule.keep_source,
|
||||
)
|
||||
|
||||
if rule.rule_type == "map_value":
|
||||
source_term = get_term_from_key(
|
||||
rule.source_term_key,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
target_term = (
|
||||
get_term_from_key(
|
||||
rule.target_term_key,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
if rule.target_term_key
|
||||
else source_term
|
||||
)
|
||||
return partial(
|
||||
map_value_transform,
|
||||
source_term=source_term,
|
||||
target_term=target_term,
|
||||
mapping=rule.value_mapping,
|
||||
)
|
||||
|
||||
# Handle unknown rule type
|
||||
valid_options = ["replace", "derive_tag", "map_value"]
|
||||
# Should be caught by Pydantic validation, but good practice
|
||||
raise ValueError(
|
||||
f"Invalid transform rule type '{getattr(rule, 'rule_type', 'N/A')}'. "
|
||||
f"Valid options are: {valid_options}"
|
||||
)
|
||||
|
||||
|
||||
def build_transformation_from_config(
|
||||
config: TransformConfig,
|
||||
derivation_registry: Optional[DerivationRegistry] = None,
|
||||
term_registry: Optional[TermRegistry] = None,
|
||||
) -> SoundEventTransformation:
|
||||
"""Build a composite transformation function from a TransformConfig.
|
||||
|
||||
Creates a sequence of individual transformation functions based on the
|
||||
rules defined in the configuration. Returns a single function that
|
||||
applies these transformations sequentially to an annotation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : TransformConfig
|
||||
The configuration object containing the list of transformation rules.
|
||||
derivation_registry : DerivationRegistry, optional
The derivation registry to use when building `DeriveTagRule`
transformations. Defaults to the global `default_derivation_registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventTransformation
|
||||
A single function that applies all configured transformations in order.
|
||||
"""
|
||||
|
||||
transforms = [
|
||||
build_transform_from_rule(
|
||||
rule,
|
||||
derivation_registry=derivation_registry,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
for rule in config.rules
|
||||
]
|
||||
|
||||
return partial(apply_sequence_of_transforms, transforms=transforms)
|
||||
|
||||
|
||||
def apply_sequence_of_transforms(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
transforms: list[SoundEventTransformation],
|
||||
) -> data.SoundEventAnnotation:
|
||||
for transform in transforms:
|
||||
sound_event_annotation = transform(sound_event_annotation)
|
||||
return sound_event_annotation
|
||||
|
||||
|
||||
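# Illustrative sketch (not part of the library): composing a transformation
# from config objects and applying it to a single annotation. The tag keys
# and values are assumptions chosen for demonstration only.
def _example_transformation(
    annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
    config = TransformConfig(
        rules=[
            ReplaceRule(
                original=TagInfo(key="event", value="Echo"),
                replacement=TagInfo(key="event", value="Echolocation"),
            ),
        ]
    )
    transform = build_transformation_from_config(config)
    return transform(annotation)

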
def load_transformation_config(
|
||||
path: data.PathLike, field: Optional[str] = None
|
||||
) -> TransformConfig:
|
||||
"""Load the transformation configuration from a file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (YAML).
|
||||
field : str, optional
|
||||
If the transformation configuration is nested under a specific key
|
||||
in the file, specify the key here. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
TransformConfig
|
||||
The loaded and validated transformation configuration object.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
pydantic.ValidationError
|
||||
If the config file structure does not match the TransformConfig schema.
|
||||
"""
|
||||
return load_config(path=path, schema=TransformConfig, field=field)
|
||||
|
||||
|
||||
def load_transformation_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
derivation_registry: Optional[DerivationRegistry] = None,
|
||||
term_registry: Optional[TermRegistry] = None,
|
||||
) -> SoundEventTransformation:
|
||||
"""Load transformation config from a file and build the final function.
|
||||
|
||||
This is a convenience function that combines loading the configuration
|
||||
and building the final callable transformation function that applies
|
||||
all rules sequentially.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (YAML).
|
||||
field : str, optional
|
||||
If the transformation configuration is nested under a specific key
|
||||
in the file, specify the key here. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventTransformation
|
||||
The final composite transformation function ready to be used.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
pydantic.ValidationError
|
||||
If the config file structure does not match the TransformConfig schema.
|
||||
KeyError
|
||||
If required term keys or derivation keys specified in the config
|
||||
are not found during the build process.
|
||||
ImportError, AttributeError, TypeError
|
||||
If dynamic import of a derivation function specified in the config
|
||||
fails.
|
||||
"""
|
||||
config = load_transformation_config(path=path, field=field)
|
||||
return build_transformation_from_config(
|
||||
config,
|
||||
derivation_registry=derivation_registry,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
|
||||
|
||||
def register_derivation(
|
||||
key: str,
|
||||
derivation: Derivation,
|
||||
derivation_registry: Optional[DerivationRegistry] = None,
|
||||
) -> None:
|
||||
"""Register a new derivation function in the global registry.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The unique key to associate with the derivation function.
|
||||
derivation : Derivation
|
||||
The callable derivation function (takes str, returns str).
|
||||
derivation_registry : DerivationRegistry, optional
|
||||
The registry instance to register the derivation function with.
|
||||
Defaults to the global `derivation_registry`.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If a derivation function with the same key is already registered.
|
||||
"""
|
||||
derivation_registry = derivation_registry or default_derivation_registry
|
||||
derivation_registry.register(key, derivation)
|
||||
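# Illustrative sketch (not part of the library): registering a custom
# derivation and referencing it from a DeriveTagRule. It assumes a "genus"
# term has been registered; the key and values are examples only.
def _example_register_genus_derivation() -> DeriveTagRule:
    def to_genus(value: str) -> str:
        # "Myotis mystacinus" -> "Myotis"
        return value.split(" ")[0]

    register_derivation("to_genus", to_genus)

    return DeriveTagRule(
        source_term_key="species",
        derivation_function="to_genus",
        target_term_key="genus",
        keep_source=True,
    )

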
@ -33,10 +33,6 @@ from batdetect2.train.losses import (
|
||||
SizeLossConfig,
|
||||
build_loss,
|
||||
)
|
||||
from batdetect2.train.preprocess import (
|
||||
generate_train_example,
|
||||
preprocess_annotations,
|
||||
)
|
||||
from batdetect2.train.train import (
|
||||
build_train_dataset,
|
||||
build_train_loader,
|
||||
@ -74,14 +70,12 @@ __all__ = [
|
||||
"build_trainer",
|
||||
"build_val_dataset",
|
||||
"build_val_loader",
|
||||
"generate_train_example",
|
||||
"load_full_training_config",
|
||||
"load_label_config",
|
||||
"load_train_config",
|
||||
"mask_frequency",
|
||||
"mask_time",
|
||||
"mix_audio",
|
||||
"preprocess_annotations",
|
||||
"scale_volume",
|
||||
"select_subclip",
|
||||
"train",
|
||||
|
||||
@ -44,7 +44,7 @@ AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]
|
||||
class MixAugmentationConfig(BaseConfig):
|
||||
"""Configuration for MixUp augmentation (mixing two examples)."""
|
||||
|
||||
augmentation_type: Literal["mix_audio"] = "mix_audio"
|
||||
name: Literal["mix_audio"] = "mix_audio"
|
||||
|
||||
probability: float = 0.2
|
||||
"""Probability of applying this augmentation to an example."""
|
||||
@ -140,7 +140,7 @@ def combine_clip_annotations(
|
||||
class EchoAugmentationConfig(BaseConfig):
|
||||
"""Configuration for adding synthetic echo/reverb."""
|
||||
|
||||
augmentation_type: Literal["add_echo"] = "add_echo"
|
||||
name: Literal["add_echo"] = "add_echo"
|
||||
probability: float = 0.2
|
||||
max_delay: float = 0.005
|
||||
min_weight: float = 0.0
|
||||
@ -187,7 +187,7 @@ def add_echo(
|
||||
class VolumeAugmentationConfig(BaseConfig):
|
||||
"""Configuration for random volume scaling of the spectrogram."""
|
||||
|
||||
augmentation_type: Literal["scale_volume"] = "scale_volume"
|
||||
name: Literal["scale_volume"] = "scale_volume"
|
||||
probability: float = 0.2
|
||||
min_scaling: float = 0.0
|
||||
max_scaling: float = 2.0
|
||||
@ -214,7 +214,7 @@ def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor:
|
||||
|
||||
|
||||
class WarpAugmentationConfig(BaseConfig):
|
||||
augmentation_type: Literal["warp"] = "warp"
|
||||
name: Literal["warp"] = "warp"
|
||||
probability: float = 0.2
|
||||
delta: float = 0.04
|
||||
|
||||
@ -296,7 +296,7 @@ def warp_spectrogram(
|
||||
|
||||
|
||||
class TimeMaskAugmentationConfig(BaseConfig):
|
||||
augmentation_type: Literal["mask_time"] = "mask_time"
|
||||
name: Literal["mask_time"] = "mask_time"
|
||||
probability: float = 0.2
|
||||
max_perc: float = 0.05
|
||||
max_masks: int = 3
|
||||
@ -353,7 +353,7 @@ def mask_time(
|
||||
|
||||
|
||||
class FrequencyMaskAugmentationConfig(BaseConfig):
|
||||
augmentation_type: Literal["mask_freq"] = "mask_freq"
|
||||
name: Literal["mask_freq"] = "mask_freq"
|
||||
probability: float = 0.2
|
||||
max_perc: float = 0.10
|
||||
max_masks: int = 3
|
||||
@ -414,7 +414,7 @@ AudioAugmentationConfig = Annotated[
|
||||
MixAugmentationConfig,
|
||||
EchoAugmentationConfig,
|
||||
],
|
||||
Field(discriminator="augmentation_type"),
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -425,7 +425,7 @@ SpectrogramAugmentationConfig = Annotated[
|
||||
FrequencyMaskAugmentationConfig,
|
||||
TimeMaskAugmentationConfig,
|
||||
],
|
||||
Field(discriminator="augmentation_type"),
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
AugmentationConfig = Annotated[
|
||||
@ -437,7 +437,7 @@ AugmentationConfig = Annotated[
|
||||
FrequencyMaskAugmentationConfig,
|
||||
TimeMaskAugmentationConfig,
|
||||
],
|
||||
Field(discriminator="augmentation_type"),
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""Type alias for the discriminated union of individual augmentation config."""
|
||||
|
||||
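# Illustrative note (not part of the library): with the discriminator renamed
# from `augmentation_type` to `name`, augmentation entries in a YAML config
# are selected by `name`, matching the other config unions, e.g.
#
#     augmentations:
#       - name: add_echo
#         probability: 0.3
#       - name: mask_freq
#         max_masks: 2
#
# The equivalent objects can be constructed directly (values are examples):
_example_augmentations = [
    EchoAugmentationConfig(probability=0.3),
    FrequencyMaskAugmentationConfig(max_masks=2),
]
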
@ -485,7 +485,7 @@ def build_augmentation_from_config(
|
||||
audio_source: Optional[AudioSource] = None,
|
||||
) -> Optional[Augmentation]:
|
||||
"""Factory function to build a single augmentation from its config."""
|
||||
if config.augmentation_type == "mix_audio":
|
||||
if config.name == "mix_audio":
|
||||
if audio_source is None:
|
||||
warnings.warn(
|
||||
"Mix audio augmentation ('mix_audio') requires an "
|
||||
@ -500,31 +500,31 @@ def build_augmentation_from_config(
|
||||
max_weight=config.max_weight,
|
||||
)
|
||||
|
||||
if config.augmentation_type == "add_echo":
|
||||
if config.name == "add_echo":
|
||||
return AddEcho(
|
||||
max_delay=int(config.max_delay * samplerate),
|
||||
min_weight=config.min_weight,
|
||||
max_weight=config.max_weight,
|
||||
)
|
||||
|
||||
if config.augmentation_type == "scale_volume":
|
||||
if config.name == "scale_volume":
|
||||
return ScaleVolume(
|
||||
max_scaling=config.max_scaling,
|
||||
min_scaling=config.min_scaling,
|
||||
)
|
||||
|
||||
if config.augmentation_type == "warp":
|
||||
if config.name == "warp":
|
||||
return WarpSpectrogram(
|
||||
delta=config.delta,
|
||||
)
|
||||
|
||||
if config.augmentation_type == "mask_time":
|
||||
if config.name == "mask_time":
|
||||
return MaskTime(
|
||||
max_perc=config.max_perc,
|
||||
max_masks=config.max_masks,
|
||||
)
|
||||
|
||||
if config.augmentation_type == "mask_freq":
|
||||
if config.name == "mask_freq":
|
||||
return MaskFrequency(
|
||||
max_perc=config.max_perc,
|
||||
max_masks=config.max_masks,
|
||||
|
||||
@ -6,7 +6,6 @@ from soundevent import data
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.evaluate import EvaluationConfig
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.targets import TargetConfig
|
||||
from batdetect2.train.augmentations import (
|
||||
DEFAULT_AUGMENTATION_CONFIG,
|
||||
AugmentationsConfig,
|
||||
@ -75,7 +74,6 @@ class TrainingConfig(BaseConfig):
|
||||
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
||||
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
||||
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||
|
||||
|
||||
|
||||
@ -1,9 +1,14 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import lightning as L
|
||||
import torch
|
||||
from soundevent.data import PathLike
|
||||
from torch.optim.adam import Adam
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.models import Model, build_model
|
||||
from batdetect2.train.config import FullTrainingConfig
|
||||
from batdetect2.train.losses import build_loss
|
||||
from batdetect2.typing import ModelOutput, TrainExample
|
||||
|
||||
__all__ = [
|
||||
@ -16,22 +21,28 @@ class TrainingModule(L.LightningModule):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
loss: torch.nn.Module,
|
||||
config: FullTrainingConfig,
|
||||
learning_rate: float = 0.001,
|
||||
t_max: int = 100,
|
||||
model: Optional[Model] = None,
|
||||
loss: Optional[torch.nn.Module] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.save_hyperparameters(logger=False)
|
||||
|
||||
self.config = config
|
||||
self.learning_rate = learning_rate
|
||||
self.t_max = t_max
|
||||
|
||||
if loss is None:
|
||||
loss = build_loss(self.config.train.loss)
|
||||
|
||||
if model is None:
|
||||
model = build_model(self.config)
|
||||
|
||||
self.loss = loss
|
||||
self.model = model
|
||||
self.save_hyperparameters(logger=False)
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||
return self.model(spec)
|
||||
|
||||
def training_step(self, batch: TrainExample):
|
||||
outputs = self.model.detector(batch.spec)
|
||||
@ -59,3 +70,10 @@ class TrainingModule(L.LightningModule):
|
||||
optimizer = Adam(self.parameters(), lr=self.learning_rate)
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
|
||||
return [optimizer], [scheduler]
|
||||
|
||||
|
||||
def load_model_from_checkpoint(
|
||||
path: PathLike,
|
||||
) -> Tuple[Model, FullTrainingConfig]:
|
||||
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
||||
return module.model, module.config
|
||||
|
||||
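# Illustrative usage sketch (not part of the library): recovering the trained
# model together with its full training configuration from a checkpoint. The
# path below is a placeholder.
def _example_load_checkpoint() -> Model:
    model, config = load_model_from_checkpoint("path/to/checkpoint.ckpt")
    assert isinstance(config, FullTrainingConfig)
    return model
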
@@ -5,10 +5,11 @@ import numpy as np
from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
from loguru import logger
from pydantic import Field
from soundevent import data

from batdetect2.configs import BaseConfig

DEFAULT_LOGS_DIR: str = "logs"
DEFAULT_LOGS_DIR: str = "outputs"


class DVCLiveConfig(BaseConfig):
@@ -31,7 +32,7 @@ class CSVLoggerConfig(BaseConfig):
class TensorBoardLoggerConfig(BaseConfig):
    logger_type: Literal["tensorboard"] = "tensorboard"
    save_dir: str = DEFAULT_LOGS_DIR
    name: Optional[str] = "default"
    name: Optional[str] = "logs"
    version: Optional[str] = None
    log_graph: bool = False

@@ -57,7 +58,10 @@ LoggerConfig = Annotated[
]


def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
def create_dvclive_logger(
    config: DVCLiveConfig,
    log_dir: Optional[data.PathLike] = None,
) -> Logger:
    try:
        from dvclive.lightning import DVCLiveLogger  # type: ignore
    except ImportError as error:
@@ -68,7 +72,7 @@ def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
        ) from error

    return DVCLiveLogger(
        dir=config.dir,
        dir=log_dir if log_dir is not None else config.dir,
        run_name=config.run_name,
        prefix=config.prefix,
        log_model=config.log_model,
@@ -76,29 +80,38 @@ def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
    )


def create_csv_logger(config: CSVLoggerConfig) -> Logger:
def create_csv_logger(
    config: CSVLoggerConfig,
    log_dir: Optional[data.PathLike] = None,
) -> Logger:
    from lightning.pytorch.loggers import CSVLogger

    return CSVLogger(
        save_dir=config.save_dir,
        save_dir=str(log_dir) if log_dir is not None else config.save_dir,
        name=config.name,
        version=config.version,
        flush_logs_every_n_steps=config.flush_logs_every_n_steps,
    )


def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
def create_tensorboard_logger(
    config: TensorBoardLoggerConfig,
    log_dir: Optional[data.PathLike] = None,
) -> Logger:
    from lightning.pytorch.loggers import TensorBoardLogger

    return TensorBoardLogger(
        save_dir=config.save_dir,
        save_dir=str(log_dir) if log_dir is not None else config.save_dir,
        name=config.name,
        version=config.version,
        log_graph=config.log_graph,
    )


def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
def create_mlflow_logger(
    config: MLFlowLoggerConfig,
    log_dir: Optional[data.PathLike] = None,
) -> Logger:
    try:
        from lightning.pytorch.loggers import MLFlowLogger
    except ImportError as error:
@@ -111,7 +124,7 @@ def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
    return MLFlowLogger(
        experiment_name=config.experiment_name,
        run_name=config.run_name,
        save_dir=config.save_dir,
        save_dir=str(log_dir) if log_dir is not None else config.save_dir,
        tracking_uri=config.tracking_uri,
        tags=config.tags,
        log_model=config.log_model,
@@ -126,7 +139,10 @@ LOGGER_FACTORY = {
}


def build_logger(config: LoggerConfig) -> Logger:
def build_logger(
    config: LoggerConfig,
    log_dir: Optional[data.PathLike] = None,
) -> Logger:
    """
    Creates a logger instance from a validated Pydantic config object.
    """
@@ -141,7 +157,7 @@ def build_logger(config: LoggerConfig) -> Logger:

    creation_func = LOGGER_FACTORY[logger_type]

    return creation_func(config)
    return creation_func(config, log_dir=log_dir)


def get_image_plotter(logger: Logger):
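Every logger factory above now accepts an optional `log_dir` that overrides the directory stored in its config. A short sketch of the override with the CSV logger (directory names are placeholders):

```python
from batdetect2.train.logging import CSVLoggerConfig, build_logger

config = CSVLoggerConfig()

# Uses the directory stored in the config (DEFAULT_LOGS_DIR, now "outputs").
logger = build_logger(config)

# An explicit log_dir takes precedence over the configured save_dir.
logger = build_logger(config, log_dir="runs/experiment_01")
```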
@@ -1,243 +0,0 @@
"""Preprocesses datasets for BatDetect2 model training."""

import os
from pathlib import Path
from typing import Callable, List, Optional, Sequence, TypedDict

import numpy as np
import torch
import torch.utils.data
from loguru import logger
from pydantic import Field
from soundevent import data
from tqdm import tqdm

from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.datasets import Dataset
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.train.labels import LabelConfig, build_clip_labeler
from batdetect2.typing import ClipLabeller, PreprocessorProtocol
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.typing.train import PreprocessedExample

__all__ = [
    "preprocess_annotations",
    "generate_train_example",
    "preprocess_dataset",
    "TrainPreprocessConfig",
    "load_train_preprocessing_config",
    "save_preprocessed_example",
    "load_preprocessed_example",
]

FilenameFn = Callable[[data.ClipAnnotation], str]
"""Type alias for a function that generates an output filename."""


class TrainPreprocessConfig(BaseConfig):
    preprocess: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
    targets: TargetConfig = Field(default_factory=TargetConfig)
    labels: LabelConfig = Field(default_factory=LabelConfig)


def load_train_preprocessing_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> TrainPreprocessConfig:
    return load_config(path=path, schema=TrainPreprocessConfig, field=field)


def preprocess_dataset(
    dataset: Dataset,
    config: TrainPreprocessConfig,
    output: Path,
    max_workers: Optional[int] = None,
) -> None:
    targets = build_targets(config=config.targets)
    preprocessor = build_preprocessor(config=config.preprocess)
    labeller = build_clip_labeler(
        targets,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
        config=config.labels,
    )
    audio_loader = build_audio_loader(config=config.preprocess.audio)

    if not output.exists():
        logger.debug("Creating directory {directory}", directory=output)
        output.mkdir(parents=True)

    preprocess_annotations(
        dataset,
        output_dir=output,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        labeller=labeller,
        max_workers=max_workers,
    )


class Example(TypedDict):
    audio: torch.Tensor
    spectrogram: torch.Tensor
    detection_heatmap: torch.Tensor
    class_heatmap: torch.Tensor
    size_heatmap: torch.Tensor


def generate_train_example(
    clip_annotation: data.ClipAnnotation,
    audio_loader: AudioLoader,
    preprocessor: PreprocessorProtocol,
    labeller: ClipLabeller,
) -> PreprocessedExample:
    """Generate a complete training example for one annotation."""
    wave = torch.tensor(
        audio_loader.load_clip(clip_annotation.clip)
    ).unsqueeze(0)
    spectrogram = preprocessor(wave.unsqueeze(0)).squeeze(0)
    heatmaps = labeller(clip_annotation, spectrogram)
    return PreprocessedExample(
        audio=wave,
        spectrogram=spectrogram,
        detection_heatmap=heatmaps.detection,
        class_heatmap=heatmaps.classes,
        size_heatmap=heatmaps.size,
    )


class PreprocessingDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        clips: Dataset,
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        labeller: ClipLabeller,
        filename_fn: FilenameFn,
        output_dir: Path,
        force: bool = False,
    ):
        self.clips = clips
        self.audio_loader = audio_loader
        self.preprocessor = preprocessor
        self.labeller = labeller
        self.filename_fn = filename_fn
        self.output_dir = output_dir
        self.force = force

    def __getitem__(self, idx) -> int:
        clip_annotation = self.clips[idx]

        filename = self.filename_fn(clip_annotation)

        path = self.output_dir / filename

        if path.exists() and not self.force:
            return idx

        if not path.parent.exists():
            path.parent.mkdir()

        example = generate_train_example(
            clip_annotation,
            audio_loader=self.audio_loader,
            preprocessor=self.preprocessor,
            labeller=self.labeller,
        )

        save_preprocessed_example(example, clip_annotation, path)

        return idx

    def __len__(self) -> int:
        return len(self.clips)


def save_preprocessed_example(
    example: PreprocessedExample,
    clip_annotation: data.ClipAnnotation,
    path: data.PathLike,
) -> None:
    np.savez_compressed(
        path,
        audio=example.audio.numpy(),
        spectrogram=example.spectrogram.numpy(),
        detection_heatmap=example.detection_heatmap.numpy(),
        class_heatmap=example.class_heatmap.numpy(),
        size_heatmap=example.size_heatmap.numpy(),
        clip_annotation=clip_annotation,
    )


def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
    item = np.load(path, mmap_mode="r+")
    return PreprocessedExample(
        audio=torch.tensor(item["audio"]),
        spectrogram=torch.tensor(item["spectrogram"]),
        size_heatmap=torch.tensor(item["size_heatmap"]),
        detection_heatmap=torch.tensor(item["detection_heatmap"]),
        class_heatmap=torch.tensor(item["class_heatmap"]),
    )


def list_preprocessed_files(
    directory: data.PathLike, extension: str = ".npz"
) -> List[Path]:
    return list(Path(directory).glob(f"*{extension}"))


def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
    """Generate a default output filename based on the annotation UUID."""
    return f"{clip_annotation.uuid}"


def preprocess_annotations(
    clip_annotations: Sequence[data.ClipAnnotation],
    output_dir: data.PathLike,
    preprocessor: PreprocessorProtocol,
    audio_loader: AudioLoader,
    labeller: ClipLabeller,
    filename_fn: FilenameFn = _get_filename,
    max_workers: Optional[int] = None,
) -> None:
    """Preprocess a sequence of ClipAnnotations and save results to disk."""
    output_dir = Path(output_dir)

    if not output_dir.is_dir():
        logger.info(
            "Creating output directory: {output_dir}", output_dir=output_dir
        )
        output_dir.mkdir(parents=True)

    logger.info(
        "Starting preprocessing of {num_annotations} annotations with {max_workers} workers.",
        num_annotations=len(clip_annotations),
        max_workers=max_workers or "all available",
    )

    if max_workers is None:
        max_workers = os.cpu_count() or 0

    dataset = PreprocessingDataset(
        clips=list(clip_annotations),
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        labeller=labeller,
        output_dir=Path(output_dir),
        filename_fn=filename_fn,
    )

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=max_workers,
        prefetch_factor=16,
    )

    for _ in tqdm(loader, total=len(dataset)):
        pass
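The removed module serialized each training example as a compressed `.npz` archive with one named array per heatmap. A stripped-down sketch of that round-trip in plain NumPy (array shapes are invented for illustration):

```python
import numpy as np

# Write one example the way save_preprocessed_example did: a single compressed
# archive per clip with named arrays for the spectrogram and heatmaps.
np.savez_compressed(
    "example.npz",
    audio=np.zeros(2560),              # illustrative shapes only
    spectrogram=np.zeros((128, 512)),
    detection_heatmap=np.zeros((128, 512)),
    class_heatmap=np.zeros((4, 128, 512)),
    size_heatmap=np.zeros((2, 128, 512)),
)

# Read it back by name, as load_preprocessed_example did.
item = np.load("example.npz")
spectrogram = item["spectrogram"]
```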
@ -14,9 +14,9 @@ from batdetect2.evaluate.metrics import (
|
||||
ClassificationMeanAveragePrecision,
|
||||
DetectionAveragePrecision,
|
||||
)
|
||||
from batdetect2.models import Model, build_model
|
||||
from batdetect2.plotting.clips import AudioLoader, build_audio_loader
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.train.augmentations import (
|
||||
RandomAudioSource,
|
||||
build_augmentations,
|
||||
@ -28,7 +28,6 @@ from batdetect2.train.dataset import TrainingDataset, ValidationDataset
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
from batdetect2.train.logging import build_logger
|
||||
from batdetect2.train.losses import build_loss
|
||||
from batdetect2.typing import (
|
||||
PreprocessorProtocol,
|
||||
TargetProtocol,
|
||||
@ -54,19 +53,21 @@ def train(
|
||||
model_path: Optional[data.PathLike] = None,
|
||||
train_workers: Optional[int] = None,
|
||||
val_workers: Optional[int] = None,
|
||||
checkpoint_dir: Optional[data.PathLike] = None,
|
||||
log_dir: Optional[data.PathLike] = None,
|
||||
):
|
||||
config = config or FullTrainingConfig()
|
||||
|
||||
model = build_model(config=config)
|
||||
targets = build_targets(config.targets)
|
||||
|
||||
trainer = build_trainer(config, targets=model.targets)
|
||||
preprocessor = build_preprocessor(config.preprocess)
|
||||
|
||||
audio_loader = build_audio_loader(config=config.preprocess.audio)
|
||||
|
||||
labeller = build_clip_labeler(
|
||||
model.targets,
|
||||
min_freq=model.preprocessor.min_freq,
|
||||
max_freq=model.preprocessor.max_freq,
|
||||
targets,
|
||||
min_freq=preprocessor.min_freq,
|
||||
max_freq=preprocessor.max_freq,
|
||||
config=config.train.labels,
|
||||
)
|
||||
|
||||
@ -74,7 +75,7 @@ def train(
|
||||
train_annotations,
|
||||
audio_loader=audio_loader,
|
||||
labeller=labeller,
|
||||
preprocessor=build_preprocessor(config.preprocess),
|
||||
preprocessor=preprocessor,
|
||||
config=config.train,
|
||||
num_workers=train_workers,
|
||||
)
|
||||
@ -84,7 +85,7 @@ def train(
|
||||
val_annotations,
|
||||
audio_loader=audio_loader,
|
||||
labeller=labeller,
|
||||
preprocessor=build_preprocessor(config.preprocess),
|
||||
preprocessor=preprocessor,
|
||||
config=config.train,
|
||||
num_workers=val_workers,
|
||||
)
|
||||
@ -97,11 +98,17 @@ def train(
|
||||
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
||||
else:
|
||||
module = build_training_module(
|
||||
model,
|
||||
config,
|
||||
batches_per_epoch=len(train_dataloader),
|
||||
t_max=config.train.t_max * len(train_dataloader),
|
||||
)
|
||||
|
||||
trainer = build_trainer(
|
||||
config,
|
||||
targets=targets,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
log_dir=log_dir,
|
||||
)
|
||||
|
||||
logger.info("Starting main training loop...")
|
||||
trainer.fit(
|
||||
module,
|
||||
@ -112,16 +119,14 @@ def train(
|
||||
|
||||
|
||||
def build_training_module(
|
||||
model: Model,
|
||||
config: FullTrainingConfig,
|
||||
batches_per_epoch: int,
|
||||
config: Optional[FullTrainingConfig] = None,
|
||||
t_max: int = 200,
|
||||
) -> TrainingModule:
|
||||
loss = build_loss(config=config.train.loss)
|
||||
config = config or FullTrainingConfig()
|
||||
return TrainingModule(
|
||||
model=model,
|
||||
loss=loss,
|
||||
config=config,
|
||||
learning_rate=config.train.learning_rate,
|
||||
t_max=config.train.t_max * batches_per_epoch,
|
||||
t_max=t_max,
|
||||
)
|
||||
|
||||
|
||||
@ -129,10 +134,14 @@ def build_trainer_callbacks(
|
||||
targets: TargetProtocol,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
config: EvaluationConfig,
|
||||
checkpoint_dir: Optional[data.PathLike] = None,
|
||||
) -> List[Callback]:
|
||||
if checkpoint_dir is None:
|
||||
checkpoint_dir = "outputs/checkpoints"
|
||||
|
||||
return [
|
||||
ModelCheckpoint(
|
||||
dirpath="outputs/checkpoints",
|
||||
dirpath=str(checkpoint_dir),
|
||||
save_top_k=1,
|
||||
monitor="total_loss/val",
|
||||
),
|
||||
@ -153,15 +162,22 @@ def build_trainer_callbacks(
|
||||
def build_trainer(
|
||||
conf: FullTrainingConfig,
|
||||
targets: TargetProtocol,
|
||||
checkpoint_dir: Optional[data.PathLike] = None,
|
||||
log_dir: Optional[data.PathLike] = None,
|
||||
) -> Trainer:
|
||||
trainer_conf = conf.train.trainer
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building trainer with config: \n{config}",
|
||||
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
||||
)
|
||||
train_logger = build_logger(conf.train.logger)
|
||||
train_logger = build_logger(conf.train.logger, log_dir=log_dir)
|
||||
|
||||
train_logger.log_hyperparams(conf.model_dump(mode="json"))
|
||||
train_logger.log_hyperparams(
|
||||
conf.model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
)
|
||||
)
|
||||
|
||||
return Trainer(
|
||||
**trainer_conf.model_dump(exclude_none=True),
|
||||
@ -170,6 +186,7 @@ def build_trainer(
|
||||
targets,
|
||||
config=conf.evaluation,
|
||||
preprocessor=build_preprocessor(conf.preprocess),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
),
|
||||
)
|
||||
|
||||
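Together these changes let callers redirect checkpoints and logs without editing the config. A hedged sketch of the resulting call; the import path and the exact positional signature of `train` are assumptions, since neither is fully visible in this hunk:

```python
from batdetect2.train.train import train  # import path assumed, not shown in the diff

train_annotations = []  # placeholder: ClipAnnotation objects would go here
val_annotations = []    # placeholder

train(
    train_annotations,
    val_annotations,
    checkpoint_dir="outputs/checkpoints",  # replaces the previously hard-coded dirpath
    log_dir="outputs/logs",                # forwarded to build_logger(..., log_dir=...)
)
```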
|
||||
@ -67,8 +67,7 @@ class TargetProtocol(Protocol):
|
||||
This protocol outlines the standard attributes and methods for an object
|
||||
that encapsulates the complete, configured process for handling sound event
|
||||
annotations (both tags and geometry). It defines how to:
|
||||
- Filter relevant annotations.
|
||||
- Transform annotation tags.
|
||||
- Select relevant annotations.
|
||||
- Encode an annotation into a specific target class name.
|
||||
- Decode a class name back into representative tags.
|
||||
- Extract a target reference position from an annotation's geometry (ROI).
|
||||
@ -121,26 +120,6 @@ class TargetProtocol(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def transform(
|
||||
self,
|
||||
sound_event: data.SoundEventAnnotation,
|
||||
) -> data.SoundEventAnnotation:
|
||||
"""Apply tag transformations to an annotation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEventAnnotation
|
||||
The annotation whose tags should be transformed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.SoundEventAnnotation
|
||||
A new annotation object with the transformed tags. Implementations
|
||||
should return the original annotation object if no transformations
|
||||
were configured.
|
||||
"""
|
||||
...
|
||||
|
||||
def encode_class(
|
||||
self,
|
||||
sound_event: data.SoundEventAnnotation,
|
||||
@ -248,3 +227,70 @@ class TargetProtocol(Protocol):
|
||||
if reconstruction fails based on the configured position type.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ROITargetMapper(Protocol):
|
||||
"""Protocol defining the interface for ROI-to-target mapping.
|
||||
|
||||
Specifies the `encode` and `decode` methods required for converting a
|
||||
`soundevent.data.SoundEvent` into a target representation (a reference
|
||||
position and a size vector) and for recovering an approximate ROI from that
|
||||
representation.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
dimension_names : List[str]
|
||||
A list containing the names of the dimensions in the `Size` array
|
||||
returned by `encode` and expected by `decode`.
|
||||
"""
|
||||
|
||||
dimension_names: List[str]
|
||||
|
||||
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
|
||||
"""Encode a SoundEvent's geometry into a position and size.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEvent
|
||||
The input sound event, which must have a geometry attribute.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[Position, Size]
|
||||
A tuple containing:
|
||||
- The reference position as (time, frequency) coordinates.
|
||||
- A NumPy array with the calculated size dimensions.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the sound event does not have a geometry.
|
||||
"""
|
||||
...
|
||||
|
||||
def decode(self, position: Position, size: Size) -> data.Geometry:
|
||||
"""Decode a position and size back into a geometric ROI.
|
||||
|
||||
Performs the inverse mapping: takes a reference position and size
|
||||
dimensions and reconstructs a geometric representation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
position : Position
|
||||
The reference position (time, frequency).
|
||||
size : Size
|
||||
NumPy array containing the size dimensions, matching the order
|
||||
and meaning specified by `dimension_names`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
soundevent.data.Geometry
|
||||
The reconstructed geometry, typically a `BoundingBox`.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the `size` array has an unexpected shape or if reconstruction
|
||||
fails.
|
||||
"""
|
||||
...
|
||||
|
||||
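The `encode`/`decode` contract documented above is easiest to see in a toy implementation. The following is a hedged sketch of a class satisfying `ROITargetMapper`, not the library's own `anchor_bbox` mapper; it assumes bounding-box geometries only:

```python
from typing import List, Tuple

import numpy as np
from soundevent import data

# Local aliases; the real module defines Position and Size elsewhere.
Position = Tuple[float, float]
Size = np.ndarray


class TopLeftBBoxMapper:
    """Toy mapper: top-left corner as position, (width, height) as size."""

    dimension_names: List[str] = ["width", "height"]

    def encode(self, sound_event: data.SoundEvent) -> Tuple[Position, Size]:
        if sound_event.geometry is None:
            raise ValueError("Sound event has no geometry")
        # Assumes a BoundingBox: [start_time, low_freq, end_time, high_freq].
        start_time, low_freq, end_time, high_freq = sound_event.geometry.coordinates
        position = (start_time, high_freq)
        size = np.array([end_time - start_time, high_freq - low_freq])
        return position, size

    def decode(self, position: Position, size: Size) -> data.Geometry:
        if size.shape != (2,):
            raise ValueError("Expected a (width, height) size array")
        time, high_freq = position
        width, height = size
        return data.BoundingBox(
            coordinates=[time, high_freq - height, time + width, high_freq]
        )
```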
@ -15,13 +15,10 @@ from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.preprocess.audio import build_audio_loader
|
||||
from batdetect2.targets import (
|
||||
TargetConfig,
|
||||
TermRegistry,
|
||||
build_targets,
|
||||
call_type,
|
||||
)
|
||||
from batdetect2.targets.classes import ClassesConfig, TargetClass
|
||||
from batdetect2.targets.filtering import FilterConfig, FilterRule
|
||||
from batdetect2.targets.terms import TagInfo
|
||||
from batdetect2.targets.classes import TargetClassConfig
|
||||
from batdetect2.train.clips import build_clipper
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
from batdetect2.typing import (
|
||||
@ -355,18 +352,6 @@ def create_annotation_project():
|
||||
return factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_term_registry() -> TermRegistry:
|
||||
"""Fixture for a sample TermRegistry."""
|
||||
registry = TermRegistry()
|
||||
registry.add_custom_term("class")
|
||||
registry.add_custom_term("order")
|
||||
registry.add_custom_term("species")
|
||||
registry.add_custom_term("call_type")
|
||||
registry.add_custom_term("quality")
|
||||
return registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_preprocessor() -> PreprocessorProtocol:
|
||||
return build_preprocessor()
|
||||
@ -378,56 +363,45 @@ def sample_audio_loader() -> AudioLoader:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bat_tag() -> TagInfo:
|
||||
return TagInfo(key="class", value="bat")
|
||||
def bat_tag() -> data.Tag:
|
||||
return data.Tag(key="class", value="bat")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def noise_tag() -> TagInfo:
|
||||
return TagInfo(key="class", value="noise")
|
||||
def noise_tag() -> data.Tag:
|
||||
return data.Tag(key="class", value="noise")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def myomyo_tag() -> TagInfo:
|
||||
return TagInfo(key="species", value="Myotis myotis")
|
||||
def myomyo_tag() -> data.Tag:
|
||||
return data.Tag(key="species", value="Myotis myotis")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pippip_tag() -> TagInfo:
|
||||
return TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||
def pippip_tag() -> data.Tag:
|
||||
return data.Tag(key="species", value="Pipistrellus pipistrellus")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_target_config(
|
||||
sample_term_registry: TermRegistry,
|
||||
bat_tag: TagInfo,
|
||||
noise_tag: TagInfo,
|
||||
myomyo_tag: TagInfo,
|
||||
pippip_tag: TagInfo,
|
||||
bat_tag: data.Tag,
|
||||
myomyo_tag: data.Tag,
|
||||
pippip_tag: data.Tag,
|
||||
) -> TargetConfig:
|
||||
return TargetConfig(
|
||||
filtering=FilterConfig(
|
||||
rules=[FilterRule(match_type="exclude", tags=[noise_tag])]
|
||||
),
|
||||
classes=ClassesConfig(
|
||||
classes=[
|
||||
TargetClass(name="pippip", tags=[pippip_tag]),
|
||||
TargetClass(name="myomyo", tags=[myomyo_tag]),
|
||||
],
|
||||
generic_class=[bat_tag],
|
||||
),
|
||||
detection_target=TargetClassConfig(name="bat", tags=[bat_tag]),
|
||||
classification_targets=[
|
||||
TargetClassConfig(name="pippip", tags=[pippip_tag]),
|
||||
TargetClassConfig(name="myomyo", tags=[myomyo_tag]),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_targets(
|
||||
sample_target_config: TargetConfig,
|
||||
sample_term_registry: TermRegistry,
|
||||
) -> TargetProtocol:
|
||||
return build_targets(
|
||||
sample_target_config,
|
||||
term_registry=sample_term_registry,
|
||||
)
|
||||
return build_targets(sample_target_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -443,10 +417,8 @@ def sample_labeller(
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_clipper(
|
||||
sample_preprocessor: PreprocessorProtocol,
|
||||
) -> ClipperProtocol:
|
||||
return build_clipper(preprocessor=sample_preprocessor)
|
||||
def sample_clipper() -> ClipperProtocol:
|
||||
return build_clipper()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
0    tests/test_data/test_transforms/__init__.py    Normal file
516  tests/test_data/test_transforms/test_conditions.py    Normal file
@ -0,0 +1,516 @@
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from pydantic import TypeAdapter
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.data.conditions import (
|
||||
SoundEventConditionConfig,
|
||||
build_sound_event_condition,
|
||||
)
|
||||
|
||||
|
||||
def build_condition_from_str(content):
|
||||
content = textwrap.dedent(content)
|
||||
content = yaml.safe_load(content)
|
||||
config = TypeAdapter(SoundEventConditionConfig).validate_python(content)
|
||||
return build_sound_event_condition(config)
|
||||
|
||||
|
||||
def test_has_tag(sound_event: data.SoundEvent):
|
||||
condition = build_condition_from_str("""
|
||||
name: has_tag
|
||||
tag:
|
||||
key: species
|
||||
value: Myotis myotis
|
||||
""")
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
||||
)
|
||||
assert not condition(sound_event_annotation)
|
||||
|
||||
|
||||
def test_has_all_tags(sound_event: data.SoundEvent):
|
||||
condition = build_condition_from_str("""
|
||||
name: has_all_tags
|
||||
tags:
|
||||
- key: species
|
||||
value: Myotis myotis
|
||||
- key: event
|
||||
value: Echolocation
|
||||
""")
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
)
|
||||
assert not condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
|
||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
||||
],
|
||||
)
|
||||
assert not condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
||||
],
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
||||
data.Tag(key="sex", value="Female"), # type: ignore
|
||||
],
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
|
||||
|
||||
def test_has_any_tags(sound_event: data.SoundEvent):
|
||||
condition = build_condition_from_str("""
|
||||
name: has_any_tag
|
||||
tags:
|
||||
- key: species
|
||||
value: Myotis myotis
|
||||
- key: event
|
||||
value: Echolocation
|
||||
""")
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
|
||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
||||
],
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
||||
],
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
|
||||
data.Tag(key="event", value="Social"), # type: ignore
|
||||
],
|
||||
)
|
||||
assert not condition(sound_event_annotation)
|
||||
|
||||
|
||||
def test_not(sound_event: data.SoundEvent):
|
||||
condition = build_condition_from_str("""
|
||||
name: not
|
||||
condition:
|
||||
name: has_tag
|
||||
tag:
|
||||
key: species
|
||||
value: Myotis myotis
|
||||
""")
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
)
|
||||
assert not condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
||||
],
|
||||
)
|
||||
assert not condition(sound_event_annotation)
|
||||
|
||||
|
||||
def test_duration(recording: data.Recording):
|
||||
se1 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
recording=recording, geometry=data.TimeInterval(coordinates=[0, 1])
|
||||
),
|
||||
)
|
||||
se2 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
recording=recording, geometry=data.TimeInterval(coordinates=[0, 2])
|
||||
),
|
||||
)
|
||||
se3 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
recording=recording, geometry=data.TimeInterval(coordinates=[0, 3])
|
||||
),
|
||||
)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: duration
|
||||
operator: lt
|
||||
seconds: 2
|
||||
""")
|
||||
assert condition(se1)
|
||||
assert not condition(se2)
|
||||
assert not condition(se3)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: duration
|
||||
operator: lte
|
||||
seconds: 2
|
||||
""")
|
||||
|
||||
assert condition(se1)
|
||||
assert condition(se2)
|
||||
assert not condition(se3)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: duration
|
||||
operator: gt
|
||||
seconds: 2
|
||||
""")
|
||||
|
||||
assert not condition(se1)
|
||||
assert not condition(se2)
|
||||
assert condition(se3)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: duration
|
||||
operator: gte
|
||||
seconds: 2
|
||||
""")
|
||||
|
||||
assert not condition(se1)
|
||||
assert condition(se2)
|
||||
assert condition(se3)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: duration
|
||||
operator: eq
|
||||
seconds: 2
|
||||
""")
|
||||
|
||||
assert not condition(se1)
|
||||
assert condition(se2)
|
||||
assert not condition(se3)
|
||||
|
||||
|
||||
def test_frequency(recording: data.Recording):
|
||||
se12 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
recording=recording,
|
||||
geometry=data.BoundingBox(coordinates=[0, 100, 1, 200]),
|
||||
),
|
||||
)
|
||||
se13 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
recording=recording,
|
||||
geometry=data.BoundingBox(coordinates=[0, 100, 2, 300]),
|
||||
),
|
||||
)
|
||||
se14 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
recording=recording,
|
||||
geometry=data.BoundingBox(coordinates=[0, 100, 3, 400]),
|
||||
),
|
||||
)
|
||||
se24 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
recording=recording,
|
||||
geometry=data.BoundingBox(coordinates=[0, 200, 3, 400]),
|
||||
),
|
||||
)
|
||||
se34 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
recording=recording,
|
||||
geometry=data.BoundingBox(coordinates=[0, 300, 3, 400]),
|
||||
),
|
||||
)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: high
|
||||
operator: lt
|
||||
hertz: 300
|
||||
""")
|
||||
assert condition(se12)
|
||||
assert not condition(se13)
|
||||
assert not condition(se14)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: high
|
||||
operator: lte
|
||||
hertz: 300
|
||||
""")
|
||||
|
||||
assert condition(se12)
|
||||
assert condition(se13)
|
||||
assert not condition(se14)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: high
|
||||
operator: gt
|
||||
hertz: 300
|
||||
""")
|
||||
|
||||
assert not condition(se12)
|
||||
assert not condition(se13)
|
||||
assert condition(se14)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: high
|
||||
operator: gte
|
||||
hertz: 300
|
||||
""")
|
||||
|
||||
assert not condition(se12)
|
||||
assert condition(se13)
|
||||
assert condition(se14)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: high
|
||||
operator: eq
|
||||
hertz: 300
|
||||
""")
|
||||
|
||||
assert not condition(se12)
|
||||
assert condition(se13)
|
||||
assert not condition(se14)
|
||||
|
||||
# LOW
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: low
|
||||
operator: lt
|
||||
hertz: 200
|
||||
""")
|
||||
assert condition(se14)
|
||||
assert not condition(se24)
|
||||
assert not condition(se34)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: low
|
||||
operator: lte
|
||||
hertz: 200
|
||||
""")
|
||||
|
||||
assert condition(se14)
|
||||
assert condition(se24)
|
||||
assert not condition(se34)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: low
|
||||
operator: gt
|
||||
hertz: 200
|
||||
""")
|
||||
|
||||
assert not condition(se14)
|
||||
assert not condition(se24)
|
||||
assert condition(se34)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: low
|
||||
operator: gte
|
||||
hertz: 200
|
||||
""")
|
||||
|
||||
assert not condition(se14)
|
||||
assert condition(se24)
|
||||
assert condition(se34)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: low
|
||||
operator: eq
|
||||
hertz: 200
|
||||
""")
|
||||
|
||||
assert not condition(se14)
|
||||
assert condition(se24)
|
||||
assert not condition(se34)
|
||||
|
||||
|
||||
def test_frequency_is_false_for_temporal_geometries(recording: data.Recording):
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: low
|
||||
operator: eq
|
||||
hertz: 200
|
||||
""")
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
geometry=data.TimeInterval(coordinates=[0, 3]),
|
||||
recording=recording,
|
||||
)
|
||||
)
|
||||
assert not condition(se)
|
||||
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
geometry=data.TimeStamp(coordinates=3),
|
||||
recording=recording,
|
||||
)
|
||||
)
|
||||
assert not condition(se)
|
||||
|
||||
|
||||
def test_has_tags_fails_if_empty():
|
||||
with pytest.raises(ValueError):
|
||||
build_condition_from_str("""
|
||||
name: has_tags
|
||||
tags: []
|
||||
""")
|
||||
|
||||
|
||||
def test_frequency_is_false_if_no_geometry(recording: data.Recording):
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: low
|
||||
operator: eq
|
||||
hertz: 200
|
||||
""")
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(geometry=None, recording=recording)
|
||||
)
|
||||
assert not condition(se)
|
||||
|
||||
|
||||
def test_duration_is_false_if_no_geometry(recording: data.Recording):
|
||||
condition = build_condition_from_str("""
|
||||
name: duration
|
||||
operator: eq
|
||||
seconds: 1
|
||||
""")
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(geometry=None, recording=recording)
|
||||
)
|
||||
assert not condition(se)
|
||||
|
||||
|
||||
def test_all_of(recording: data.Recording):
|
||||
condition = build_condition_from_str("""
|
||||
name: all_of
|
||||
conditions:
|
||||
- name: has_tag
|
||||
tag:
|
||||
key: species
|
||||
value: Myotis myotis
|
||||
- name: duration
|
||||
operator: lt
|
||||
seconds: 1
|
||||
""")
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
)
|
||||
assert condition(se)
|
||||
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
geometry=data.TimeInterval(coordinates=[0, 2]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
)
|
||||
assert not condition(se)
|
||||
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
||||
)
|
||||
assert not condition(se)
|
||||
|
||||
|
||||
def test_any_of(recording: data.Recording):
|
||||
condition = build_condition_from_str("""
|
||||
name: any_of
|
||||
conditions:
|
||||
- name: has_tag
|
||||
tag:
|
||||
key: species
|
||||
value: Myotis myotis
|
||||
- name: duration
|
||||
operator: lt
|
||||
seconds: 1
|
||||
""")
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
geometry=data.TimeInterval(coordinates=[0, 2]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
||||
)
|
||||
assert not condition(se)
|
||||
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
)
|
||||
assert condition(se)
|
||||
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
geometry=data.TimeInterval(coordinates=[0, 2]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
)
|
||||
assert condition(se)
|
||||
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
||||
)
|
||||
assert condition(se)
|
||||
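The tests above exercise each condition in isolation, and `test_not` already shows one condition nested inside another. Composite configurations follow the same pattern; a hedged sketch using the same `build_condition_from_str` helper defined at the top of this module (tag values and thresholds are invented):

```python
condition = build_condition_from_str("""
    name: all_of
    conditions:
      - name: has_tag
        tag:
          key: event
          value: Echolocation
      - name: any_of
        conditions:
          - name: duration
            operator: lt
            seconds: 0.05
          - name: frequency
            boundary: low
            operator: gte
            hertz: 15000
""")
```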
@ -3,26 +3,15 @@ from typing import Callable
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from soundevent import data
|
||||
from soundevent.terms import get_term
|
||||
|
||||
from batdetect2.targets.classes import (
|
||||
DEFAULT_SPECIES_LIST,
|
||||
ClassesConfig,
|
||||
TargetClass,
|
||||
_get_default_class_name,
|
||||
_get_default_classes,
|
||||
build_generic_class_tags,
|
||||
TargetClassConfig,
|
||||
build_sound_event_decoder,
|
||||
build_sound_event_encoder,
|
||||
get_class_names_from_config,
|
||||
is_target_class,
|
||||
load_classes_config,
|
||||
load_decoder_from_config,
|
||||
load_encoder_from_config,
|
||||
)
|
||||
from batdetect2.targets.terms import TagInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -33,8 +22,8 @@ def sample_annotation(
|
||||
return data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||
data.Tag(key="quality", value="Good"), # type: ignore
|
||||
data.Tag(key="species", value="Pipistrellus pipistrellus"),
|
||||
data.Tag(key="quality", value="Good"),
|
||||
],
|
||||
)
|
||||
|
||||
@ -51,291 +40,71 @@ def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
|
||||
return factory
|
||||
|
||||
|
||||
def test_target_class_creation():
|
||||
target_class = TargetClass(
|
||||
name="pippip",
|
||||
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
||||
)
|
||||
assert target_class.name == "pippip"
|
||||
assert target_class.tags[0].key == "species"
|
||||
assert target_class.tags[0].value == "Pipistrellus pipistrellus"
|
||||
assert target_class.match_type == "all"
|
||||
|
||||
|
||||
def test_classes_config_creation():
|
||||
target_class = TargetClass(
|
||||
name="pippip",
|
||||
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
||||
)
|
||||
config = ClassesConfig(classes=[target_class])
|
||||
assert len(config.classes) == 1
|
||||
assert config.classes[0].name == "pippip"
|
||||
|
||||
|
||||
def test_classes_config_unique_names():
|
||||
target_class1 = TargetClass(
|
||||
name="pippip",
|
||||
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
||||
)
|
||||
target_class2 = TargetClass(
|
||||
name="myodau",
|
||||
tags=[TagInfo(key="species", value="Myotis daubentonii")],
|
||||
)
|
||||
ClassesConfig(classes=[target_class1, target_class2]) # No error
|
||||
|
||||
|
||||
def test_classes_config_non_unique_names():
|
||||
target_class1 = TargetClass(
|
||||
name="pippip",
|
||||
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
||||
)
|
||||
target_class2 = TargetClass(
|
||||
name="pippip",
|
||||
tags=[TagInfo(key="species", value="Myotis daubentonii")],
|
||||
)
|
||||
with pytest.raises(ValidationError):
|
||||
ClassesConfig(classes=[target_class1, target_class2])
|
||||
|
||||
|
||||
def test_load_classes_config_valid(create_temp_yaml: Callable[[str], Path]):
|
||||
yaml_content = """
|
||||
classes:
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
"""
|
||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||
config = load_classes_config(temp_yaml_path)
|
||||
assert len(config.classes) == 1
|
||||
assert config.classes[0].name == "pippip"
|
||||
|
||||
|
||||
def test_load_classes_config_invalid(create_temp_yaml: Callable[[str], Path]):
|
||||
yaml_content = """
|
||||
classes:
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: species
|
||||
value: Myotis daubentonii
|
||||
"""
|
||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||
with pytest.raises(ValidationError):
|
||||
load_classes_config(temp_yaml_path)
|
||||
|
||||
|
||||
def test_is_target_class_match_all(
|
||||
sample_annotation: data.SoundEventAnnotation,
|
||||
):
|
||||
tags = {
|
||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||
data.Tag(key="quality", value="Good"), # type: ignore
|
||||
}
|
||||
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
||||
|
||||
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
|
||||
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
||||
|
||||
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
|
||||
assert is_target_class(sample_annotation, tags, match_all=True) is False
|
||||
|
||||
|
||||
def test_is_target_class_match_any(
|
||||
sample_annotation: data.SoundEventAnnotation,
|
||||
):
|
||||
tags = {
|
||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||
data.Tag(key="quality", value="Good"), # type: ignore
|
||||
}
|
||||
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
||||
|
||||
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
|
||||
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
||||
|
||||
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
|
||||
assert is_target_class(sample_annotation, tags, match_all=False) is False
|
||||
|
||||
|
||||
def test_get_class_names_from_config():
|
||||
target_class1 = TargetClass(
|
||||
target_class1 = TargetClassConfig(
|
||||
name="pippip",
|
||||
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
||||
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
|
||||
)
|
||||
target_class2 = TargetClass(
|
||||
target_class2 = TargetClassConfig(
|
||||
name="myodau",
|
||||
tags=[TagInfo(key="species", value="Myotis daubentonii")],
|
||||
tags=[data.Tag(key="species", value="Myotis daubentonii")],
|
||||
)
|
||||
config = ClassesConfig(classes=[target_class1, target_class2])
|
||||
names = get_class_names_from_config(config)
|
||||
names = get_class_names_from_config([target_class1, target_class2])
|
||||
assert names == ["pippip", "myodau"]
|
||||
|
||||
|
||||
def test_build_encoder_from_config(
|
||||
sample_annotation: data.SoundEventAnnotation,
|
||||
):
|
||||
config = ClassesConfig(
|
||||
classes=[
|
||||
TargetClass(
|
||||
name="pippip",
|
||||
tags=[
|
||||
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
encoder = build_sound_event_encoder(config)
|
||||
classes = [
|
||||
TargetClassConfig(
|
||||
name="pippip",
|
||||
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
|
||||
)
|
||||
]
|
||||
encoder = build_sound_event_encoder(classes)
|
||||
result = encoder(sample_annotation)
|
||||
assert result == "pippip"
|
||||
|
||||
config = ClassesConfig(classes=[])
|
||||
encoder = build_sound_event_encoder(config)
|
||||
classes = []
|
||||
encoder = build_sound_event_encoder(classes)
|
||||
result = encoder(sample_annotation)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_load_encoder_from_config_valid(
|
||||
sample_annotation: data.SoundEventAnnotation,
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
):
|
||||
yaml_content = """
|
||||
classes:
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
"""
|
||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||
encoder = load_encoder_from_config(temp_yaml_path)
|
||||
# We cannot directly compare the function, so we test it.
|
||||
result = encoder(sample_annotation) # type: ignore
|
||||
assert result == "pippip"
|
||||
|
||||
|
||||
def test_load_encoder_from_config_invalid(
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
):
|
||||
yaml_content = """
|
||||
classes:
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: invalid_key
|
||||
value: Pipistrellus pipistrellus
|
||||
"""
|
||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||
with pytest.raises(KeyError):
|
||||
load_encoder_from_config(temp_yaml_path)
|
||||
|
||||
|
||||
def test_get_default_class_name():
|
||||
assert _get_default_class_name("Myotis daubentonii") == "myodau"
|
||||
|
||||
|
||||
def test_get_default_classes():
|
||||
default_classes = _get_default_classes()
|
||||
assert len(default_classes) == len(DEFAULT_SPECIES_LIST)
|
||||
first_class = default_classes[0]
|
||||
assert isinstance(first_class, TargetClass)
|
||||
assert first_class.name == _get_default_class_name(DEFAULT_SPECIES_LIST[0])
|
||||
assert first_class.tags[0].key == "class"
|
||||
assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0]
|
||||
|
||||
|
||||
def test_build_decoder_from_config():
|
||||
config = ClassesConfig(
|
||||
classes=[
|
||||
TargetClass(
|
||||
name="pippip",
|
||||
tags=[
|
||||
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||
],
|
||||
output_tags=[TagInfo(key="call_type", value="Echolocation")],
|
||||
)
|
||||
],
|
||||
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
||||
)
|
||||
decoder = build_sound_event_decoder(config)
|
||||
classes = [
|
||||
TargetClassConfig(
|
||||
name="pippip",
|
||||
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
|
||||
assign_tags=[data.Tag(key="call_type", value="Echolocation")],
|
||||
)
|
||||
]
|
||||
decoder = build_sound_event_decoder(classes)
|
||||
tags = decoder("pippip")
|
||||
assert len(tags) == 1
|
||||
assert tags[0].term == get_term("event")
|
||||
assert tags[0].value == "Echolocation"
|
||||
|
||||
# Test when output_tags is None, should fall back to tags
|
||||
config = ClassesConfig(
|
||||
classes=[
|
||||
TargetClass(
|
||||
name="pippip",
|
||||
tags=[
|
||||
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||
],
|
||||
)
|
||||
],
|
||||
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
||||
)
|
||||
decoder = build_sound_event_decoder(config)
|
||||
classes = [
|
||||
TargetClassConfig(
|
||||
name="pippip",
|
||||
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
|
||||
)
|
||||
]
|
||||
decoder = build_sound_event_decoder(classes)
|
||||
tags = decoder("pippip")
|
||||
assert len(tags) == 1
|
||||
assert tags[0].term == get_term("species")
|
||||
assert tags[0].value == "Pipistrellus pipistrellus"
|
||||
|
||||
# Test raise_on_unmapped=True
|
||||
decoder = build_sound_event_decoder(config, raise_on_unmapped=True)
|
||||
decoder = build_sound_event_decoder(classes, raise_on_unmapped=True)
|
||||
with pytest.raises(ValueError):
|
||||
decoder("unknown_class")
|
||||
|
||||
# Test raise_on_unmapped=False
|
||||
decoder = build_sound_event_decoder(config, raise_on_unmapped=False)
|
||||
decoder = build_sound_event_decoder(classes, raise_on_unmapped=False)
|
||||
tags = decoder("unknown_class")
|
||||
assert len(tags) == 0
|
||||
|
||||
|
||||
def test_load_decoder_from_config_valid(
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
):
|
||||
yaml_content = """
|
||||
classes:
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
output_tags:
|
||||
- key: call_type
|
||||
value: Echolocation
|
||||
generic_class:
|
||||
- key: order
|
||||
value: Chiroptera
|
||||
"""
|
||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||
decoder = load_decoder_from_config(
|
||||
temp_yaml_path,
|
||||
)
|
||||
tags = decoder("pippip")
|
||||
assert len(tags) == 1
|
||||
assert tags[0].term == get_term("call_type")
|
||||
assert tags[0].value == "Echolocation"
|
||||
|
||||
|
||||
def test_build_generic_class_tags_from_config():
|
||||
config = ClassesConfig(
|
||||
classes=[
|
||||
TargetClass(
|
||||
name="pippip",
|
||||
tags=[
|
||||
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||
],
|
||||
)
|
||||
],
|
||||
generic_class=[
|
||||
TagInfo(key="order", value="Chiroptera"),
|
||||
TagInfo(key="call_type", value="Echolocation"),
|
||||
],
|
||||
)
|
||||
generic_tags = build_generic_class_tags(config)
|
||||
assert len(generic_tags) == 2
|
||||
assert generic_tags[0].term == get_term("order")
|
||||
assert generic_tags[0].value == "Chiroptera"
|
||||
assert generic_tags[1].term == get_term("call_type")
|
||||
assert generic_tags[1].value == "Echolocation"
|
||||
|
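The `assign_tags` field exercised above decides what a decoded class name maps back to. A hedged sketch of that behavior, reusing the class and tag values from the test (import path taken from the imports in this diff):

```python
from soundevent import data

from batdetect2.targets.classes import TargetClassConfig, build_sound_event_decoder

classes = [
    TargetClassConfig(
        name="pippip",
        tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
        assign_tags=[data.Tag(key="call_type", value="Echolocation")],
    )
]

decoder = build_sound_event_decoder(classes)

# With assign_tags set, decoding "pippip" yields the call_type tag;
# without assign_tags, the matching species tag itself would be returned.
print(decoder("pippip"))
```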
||||
@ -1,210 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Set
|
||||
|
||||
import pytest
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.targets.filtering import (
|
||||
FilterConfig,
|
||||
FilterRule,
|
||||
build_filter_from_rule,
|
||||
build_sound_event_filter,
|
||||
contains_tags,
|
||||
does_not_have_tags,
|
||||
equal_tags,
|
||||
has_any_tag,
|
||||
load_filter_config,
|
||||
load_filter_from_config,
|
||||
)
|
||||
from batdetect2.targets.terms import TagInfo, generic_class
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_annotation(
|
||||
sound_event: data.SoundEvent,
|
||||
) -> Callable[[List[str]], data.SoundEventAnnotation]:
|
||||
"""Helper function to create a SoundEventAnnotation with given tags."""
|
||||
|
||||
def factory(tags: List[str]) -> data.SoundEventAnnotation:
|
||||
return data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=generic_class,
|
||||
value=tag,
|
||||
)
|
||||
for tag in tags
|
||||
],
|
||||
)
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
def create_tag_set(tags: List[str]) -> Set[data.Tag]:
|
||||
"""Helper function to create a set of data.Tag objects from a list of strings."""
|
||||
return {
|
||||
data.Tag(
|
||||
term=generic_class,
|
||||
value=tag,
|
||||
)
|
||||
for tag in tags
|
||||
}
|
||||
|
||||
|
||||
def test_has_any_tag(create_annotation):
|
||||
annotation = create_annotation(["tag1", "tag2"])
|
||||
tags = create_tag_set(["tag1", "tag3"])
|
||||
assert has_any_tag(annotation, tags) is True
|
||||
|
||||
annotation = create_annotation(["tag2", "tag4"])
|
||||
tags = create_tag_set(["tag1", "tag3"])
|
||||
assert has_any_tag(annotation, tags) is False
|
||||
|
||||
|
||||
def test_contains_tags(create_annotation):
|
||||
annotation = create_annotation(["tag1", "tag2", "tag3"])
|
||||
tags = create_tag_set(["tag1", "tag2"])
|
||||
assert contains_tags(annotation, tags) is True
|
||||
|
||||
annotation = create_annotation(["tag1", "tag2"])
|
||||
tags = create_tag_set(["tag1", "tag2", "tag3"])
|
||||
assert contains_tags(annotation, tags) is False
|
||||
|
||||
|
||||
def test_does_not_have_tags(create_annotation):
    annotation = create_annotation(["tag1", "tag2"])
    tags = create_tag_set(["tag3", "tag4"])
    assert does_not_have_tags(annotation, tags) is True

    annotation = create_annotation(["tag1", "tag2"])
    tags = create_tag_set(["tag1", "tag3"])
    assert does_not_have_tags(annotation, tags) is False


def test_equal_tags(create_annotation):
    annotation = create_annotation(["tag1", "tag2"])
    tags = create_tag_set(["tag1", "tag2"])
    assert equal_tags(annotation, tags) is True

    annotation = create_annotation(["tag1", "tag2", "tag3"])
    tags = create_tag_set(["tag1", "tag2"])
    assert equal_tags(annotation, tags) is False


def test_build_filter_from_rule():
    rule_any = FilterRule(match_type="any", tags=[TagInfo(value="tag1")])
    build_filter_from_rule(rule_any)

    rule_all = FilterRule(match_type="all", tags=[TagInfo(value="tag1")])
    build_filter_from_rule(rule_all)

    rule_exclude = FilterRule(
        match_type="exclude", tags=[TagInfo(value="tag1")]
    )
    build_filter_from_rule(rule_exclude)

    rule_equal = FilterRule(match_type="equal", tags=[TagInfo(value="tag1")])
    build_filter_from_rule(rule_equal)

    with pytest.raises(ValueError):
        FilterRule(match_type="invalid", tags=[TagInfo(value="tag1")])  # type: ignore
        build_filter_from_rule(
            FilterRule(match_type="invalid", tags=[TagInfo(value="tag1")])  # type: ignore
        )


def test_build_filter_from_config(create_annotation):
    config = FilterConfig(
        rules=[
            FilterRule(match_type="any", tags=[TagInfo(value="tag1")]),
            FilterRule(match_type="any", tags=[TagInfo(value="tag2")]),
        ]
    )
    filter_from_config = build_sound_event_filter(config)

    annotation_pass = create_annotation(["tag1", "tag2"])
    assert filter_from_config(annotation_pass)

    annotation_fail = create_annotation(["tag1"])
    assert not filter_from_config(annotation_fail)


def test_load_filter_config(tmp_path: Path):
    test_config_path = tmp_path / "filtering.yaml"
    test_config_path.write_text(
        """
        rules:
          - match_type: any
            tags:
              - value: tag1
        """
    )
    config = load_filter_config(test_config_path)
    assert isinstance(config, FilterConfig)
    assert len(config.rules) == 1
    rule = config.rules[0]
    assert rule.match_type == "any"
    assert len(rule.tags) == 1
    assert rule.tags[0].value == "tag1"


def test_load_filter_from_config(tmp_path: Path, create_annotation):
    test_config_path = tmp_path / "filtering.yaml"
    test_config_path.write_text(
        """
        rules:
          - match_type: any
            tags:
              - value: tag1
        """
    )

    filter_result = load_filter_from_config(test_config_path)
    annotation = create_annotation(["tag1", "tag3"])
    assert filter_result(annotation)

    test_config_path = tmp_path / "filtering.yaml"
    test_config_path.write_text(
        """
        rules:
          - match_type: any
            tags:
              - value: tag2
        """
    )

    filter_result = load_filter_from_config(test_config_path)
    annotation = create_annotation(["tag1", "tag3"])
    assert filter_result(annotation) is False


def test_default_filtering_over_example_dataset(
    example_annotations: List[data.ClipAnnotation],
):
    targets = build_targets()

    clip1 = example_annotations[0]
    clip2 = example_annotations[1]
    clip3 = example_annotations[2]

    assert (
        sum(
            [targets.filter(sound_event) for sound_event in clip1.sound_events]
        )
        == 9
    )

    assert (
        sum(
            [targets.filter(sound_event) for sound_event in clip2.sound_events]
        )
        == 15
    )

    assert (
        sum(
            [targets.filter(sound_event) for sound_event in clip3.sound_events]
        )
        == 20
    )
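The filtering tests above exercise rule construction, config loading, and the assembled sound-event filter. As a rough usage sketch (not part of this changeset; the import path `batdetect2.targets.filtering` and the `annotations` list are assumptions, since the test file's import block is not shown in this excerpt), the pieces combine roughly like this:

```python
# Hypothetical usage sketch; the filtering module path is assumed, not confirmed by this diff.
from batdetect2.targets import TagInfo
from batdetect2.targets.filtering import (  # assumed import location
    FilterConfig,
    FilterRule,
    build_sound_event_filter,
    load_filter_from_config,
)

# Keep annotations that carry at least one of the listed tags.
config = FilterConfig(
    rules=[
        FilterRule(
            match_type="any",
            tags=[TagInfo(key="event", value="Echolocation")],
        )
    ]
)
keep = build_sound_event_filter(config)

# Equivalent filter loaded from a YAML file with a top-level `rules:` section.
keep_from_yaml = load_filter_from_config("filtering.yaml")

# Both are predicates over sound event annotations, as used in the tests above.
filtered = [a for a in annotations if keep(a)]  # `annotations` is a placeholder list
```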
@ -8,6 +8,11 @@ from batdetect2.preprocess import (
    build_preprocessor,
)
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.preprocess.spectrogram import (
    ScaleAmplitudeConfig,
    SpectralMeanSubstractionConfig,
    SpectrogramConfig,
)
from batdetect2.targets.rois import (
    DEFAULT_ANCHOR,
    DEFAULT_FREQUENCY_SCALE,
@ -548,14 +553,7 @@ def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle):

    # Instantiate the mapper.
    preprocessor = build_preprocessor(
        PreprocessingConfig.model_validate(
            {
                "spectrogram": {
                    "pcen": None,
                    "spectral_mean_substraction": False,
                }
            }
        )
        PreprocessingConfig(spectrogram=SpectrogramConfig(transforms=[]))
    )
    audio_loader = build_audio_loader()
    mapper = PeakEnergyBBoxMapper(
@ -597,14 +595,13 @@ def test_build_roi_mapper_for_anchor_bbox():

def test_build_roi_mapper_for_peak_energy_bbox():
    # Given
    preproc_config = PreprocessingConfig.model_validate(
        {
            "spectrogram": {
                "pcen": None,
                "spectral_mean_substraction": True,
                "scale": "dB",
            }
        }
    preproc_config = PreprocessingConfig(
        spectrogram=SpectrogramConfig(
            transforms=[
                ScaleAmplitudeConfig(scale="db"),
                SpectralMeanSubstractionConfig(),
            ]
        ),
    )
    config = PeakEnergyBBoxMapperConfig(
        loading_buffer=0.99,

@ -1,46 +1,55 @@
from collections.abc import Callable
from pathlib import Path

from soundevent import data
from soundevent import data, terms

from batdetect2.targets import build_targets, load_target_config
from batdetect2.targets.terms import get_term_from_key


def test_can_override_default_roi_mapper_per_class(
    create_temp_yaml: Callable[..., Path],
    recording: data.Recording,
    sample_term_registry,
):
    yaml_content = """
    detection_target:
      name: bat
      match_if:
        name: has_tag
        tag:
          key: order
          value: Chiroptera
      assign_tags:
        - key: order
          value: Chiroptera

    classification_targets:
      - name: pippip
        tags:
          - key: species
            value: Pipistrellus pipistrellus

      - name: myomyo
        tags:
          - key: species
            value: Myotis myotis
        roi:
          name: anchor_bbox
          anchor: top-left

    roi:
      name: anchor_bbox
      anchor: bottom-left
    classes:
      classes:
        - name: pippip
          tags:
            - key: species
              value: Pipistrellus pipistrellus
        - name: myomyo
          tags:
            - key: species
              value: Myotis myotis
          roi:
            name: anchor_bbox
            anchor: top-left
      generic_class:
        - key: order
          value: Chiroptera
    """
    config_path = create_temp_yaml(yaml_content)

    config = load_target_config(config_path)
    targets = build_targets(config, term_registry=sample_term_registry)
    targets = build_targets(config)

    geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])

    species = get_term_from_key("species", term_registry=sample_term_registry)
    species = terms.get_term("species")
    assert species is not None

    se1 = data.SoundEventAnnotation(
        sound_event=data.SoundEvent(recording=recording, geometry=geometry),
        tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
@ -62,38 +71,47 @@ def test_can_override_default_roi_mapper_per_class(
# TODO: rename this test function
def test_roi_is_recovered_roundtrip_even_with_overriders(
    create_temp_yaml,
    sample_term_registry,
    recording,
):
    yaml_content = """
    detection_target:
      name: bat
      match_if:
        name: has_tag
        tag:
          key: order
          value: Chiroptera
      assign_tags:
        - key: order
          value: Chiroptera

    classification_targets:
      - name: pippip
        tags:
          - key: species
            value: Pipistrellus pipistrellus

      - name: myomyo
        tags:
          - key: species
            value: Myotis myotis
        roi:
          name: anchor_bbox
          anchor: top-left

    roi:
      name: anchor_bbox
      anchor: bottom-left
    classes:
      classes:
        - name: pippip
          tags:
            - key: species
              value: Pipistrellus pipistrellus
        - name: myomyo
          tags:
            - key: species
              value: Myotis myotis
          roi:
            name: anchor_bbox
            anchor: top-left
      generic_class:
        - key: order
          value: Chiroptera
    """
    config_path = create_temp_yaml(yaml_content)

    config = load_target_config(config_path)
    targets = build_targets(config, term_registry=sample_term_registry)
    targets = build_targets(config)

    geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])

    species = get_term_from_key("species", term_registry=sample_term_registry)
    species = terms.get_term("species")
    assert species is not None
    se1 = data.SoundEventAnnotation(
        sound_event=data.SoundEvent(recording=recording, geometry=geometry),
        tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],

@ -1,179 +0,0 @@
import pytest
import yaml
from soundevent import data

from batdetect2.targets import terms
from batdetect2.targets.terms import (
    TagInfo,
    TermRegistry,
    load_terms_from_config,
)


def test_term_registry_initialization():
    registry = TermRegistry()
    assert registry._terms == {}

    initial_terms = {
        "test_term": data.Term(name="test", label="Test", definition="test")
    }
    registry = TermRegistry(terms=initial_terms)
    assert registry._terms == initial_terms


def test_term_registry_add_term():
    registry = TermRegistry()
    term = data.Term(name="test", label="Test", definition="test")
    registry.add_term("test_key", term)
    assert registry._terms["test_key"] == term


def test_term_registry_get_term():
    registry = TermRegistry()
    term = data.Term(name="test", label="Test", definition="test")
    registry.add_term("test_key", term)
    retrieved_term = registry.get_term("test_key")
    assert retrieved_term == term


def test_term_registry_add_custom_term():
    registry = TermRegistry()
    term = registry.add_custom_term(
        "custom_key", name="custom", label="Custom", definition="A custom term"
    )
    assert registry._terms["custom_key"] == term
    assert term.name == "custom"
    assert term.label == "Custom"
    assert term.definition == "A custom term"


def test_term_registry_add_duplicate_term():
    registry = TermRegistry()
    term = data.Term(name="test", label="Test", definition="test")
    registry.add_term("test_key", term)
    with pytest.raises(KeyError):
        registry.add_term("test_key", term)


def test_term_registry_get_term_not_found():
    registry = TermRegistry()
    with pytest.raises(KeyError):
        registry.get_term("non_existent_key")


def test_term_registry_get_keys():
    registry = TermRegistry()
    term1 = data.Term(name="test1", label="Test1", definition="test")
    term2 = data.Term(name="test2", label="Test2", definition="test")
    registry.add_term("key1", term1)
    registry.add_term("key2", term2)
    keys = registry.get_keys()
    assert set(keys) == {"key1", "key2"}


def test_get_term_from_key():
    term = terms.get_term_from_key("event")
    assert term == terms.call_type

    custom_registry = TermRegistry()
    custom_term = data.Term(name="custom", label="Custom", definition="test")
    custom_registry.add_term("custom_key", custom_term)
    term = terms.get_term_from_key("custom_key", term_registry=custom_registry)
    assert term == custom_term


def test_get_term_keys():
    keys = terms.get_term_keys()
    assert "event" in keys
    assert "individual" in keys
    assert terms.GENERIC_CLASS_KEY in keys

    custom_registry = TermRegistry()
    custom_term = data.Term(name="custom", label="Custom", definition="test")
    custom_registry.add_term("custom_key", custom_term)
    keys = terms.get_term_keys(term_registry=custom_registry)
    assert "custom_key" in keys


def test_tag_info_and_get_tag_from_info():
    tag_info = TagInfo(value="Myotis myotis", key="event")
    tag = terms.get_tag_from_info(tag_info)
    assert tag.value == "Myotis myotis"
    assert tag.term == terms.call_type


def test_get_tag_from_info_key_not_found():
    tag_info = TagInfo(value="test", key="non_existent_key")
    with pytest.raises(KeyError):
        terms.get_tag_from_info(tag_info)


def test_load_terms_from_config(tmp_path):
    term_registry = TermRegistry()
    config_data = {
        "terms": [
            {
                "key": "species",
                "name": "dwc:scientificName",
                "label": "Scientific Name",
            },
            {
                "key": "my_custom_term",
                "name": "soundevent:custom_term",
                "definition": "Describes a specific project attribute",
            },
        ]
    }
    config_file = tmp_path / "config.yaml"
    with open(config_file, "w") as f:
        yaml.dump(config_data, f)

    loaded_terms = load_terms_from_config(
        config_file,
        term_registry=term_registry,
    )
    assert "species" in loaded_terms
    assert "my_custom_term" in loaded_terms
    assert loaded_terms["species"].name == "dwc:scientificName"
    assert loaded_terms["my_custom_term"].name == "soundevent:custom_term"


def test_load_terms_from_config_file_not_found():
    with pytest.raises(FileNotFoundError):
        load_terms_from_config("non_existent_file.yaml")


def test_load_terms_from_config_validation_error(tmp_path):
    config_data = {
        "terms": [
            {
                "key": "species",
                "uri": "dwc:scientificName",
                "label": 123,
            },  # Invalid label type
        ]
    }
    config_file = tmp_path / "config.yaml"
    with open(config_file, "w") as f:
        yaml.dump(config_data, f)

    with pytest.raises(ValueError):
        load_terms_from_config(config_file)


def test_load_terms_from_config_key_already_exists(tmp_path):
    config_data = {
        "terms": [
            {
                "key": "event",
                "uri": "dwc:scientificName",
                "label": "Scientific Name",
            },  # Duplicate key
        ]
    }
    config_file = tmp_path / "config.yaml"
    with open(config_file, "w") as f:
        yaml.dump(config_data, f)

    with pytest.raises(KeyError):
        load_terms_from_config(config_file)

@ -1,363 +0,0 @@
from pathlib import Path

import pytest
from soundevent import data

from batdetect2.targets import (
    DeriveTagRule,
    MapValueRule,
    ReplaceRule,
    TagInfo,
    TransformConfig,
    build_transformation_from_config,
)
from batdetect2.targets.terms import TermRegistry
from batdetect2.targets.transform import (
    DerivationRegistry,
    build_transform_from_rule,
)


@pytest.fixture
def term_registry():
    return TermRegistry()


@pytest.fixture
def derivation_registry():
    return DerivationRegistry()


@pytest.fixture
def term1(term_registry: TermRegistry) -> data.Term:
    return term_registry.add_custom_term(key="term1")


@pytest.fixture
def term2(term_registry: TermRegistry) -> data.Term:
    return term_registry.add_custom_term(key="term2")


@pytest.fixture
def annotation(
    sound_event: data.SoundEvent,
    term1: data.Term,
) -> data.SoundEventAnnotation:
    return data.SoundEventAnnotation(
        sound_event=sound_event, tags=[data.Tag(term=term1, value="value1")]
    )


def test_map_value_rule(
    annotation: data.SoundEventAnnotation,
    term_registry: TermRegistry,
):
    rule = MapValueRule(
        rule_type="map_value",
        source_term_key="term1",
        value_mapping={"value1": "value2"},
    )
    transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
    transformed_annotation = transform_fn(annotation)
    assert transformed_annotation.tags[0].value == "value2"


def test_map_value_rule_no_match(
    annotation: data.SoundEventAnnotation,
    term_registry: TermRegistry,
):
    rule = MapValueRule(
        rule_type="map_value",
        source_term_key="term1",
        value_mapping={"other_value": "value2"},
    )
    transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
    transformed_annotation = transform_fn(annotation)
    assert transformed_annotation.tags[0].value == "value1"


def test_replace_rule(
    annotation: data.SoundEventAnnotation,
    term2: data.Term,
    term_registry: TermRegistry,
):
    rule = ReplaceRule(
        rule_type="replace",
        original=TagInfo(key="term1", value="value1"),
        replacement=TagInfo(key="term2", value="value2"),
    )
    transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
    transformed_annotation = transform_fn(annotation)
    assert transformed_annotation.tags[0].term == term2
    assert transformed_annotation.tags[0].value == "value2"


def test_replace_rule_no_match(
    annotation: data.SoundEventAnnotation,
    term_registry: TermRegistry,
    term2: data.Term,
):
    rule = ReplaceRule(
        rule_type="replace",
        original=TagInfo(key="term1", value="wrong_value"),
        replacement=TagInfo(key="term2", value="value2"),
    )
    transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
    transformed_annotation = transform_fn(annotation)
    assert transformed_annotation.tags[0].key == "term1"
    assert transformed_annotation.tags[0].term != term2
    assert transformed_annotation.tags[0].value == "value1"


def test_build_transformation_from_config(
    annotation: data.SoundEventAnnotation,
    term_registry: TermRegistry,
):
    config = TransformConfig(
        rules=[
            MapValueRule(
                rule_type="map_value",
                source_term_key="term1",
                value_mapping={"value1": "value2"},
            ),
            ReplaceRule(
                rule_type="replace",
                original=TagInfo(key="term2", value="value2"),
                replacement=TagInfo(key="term3", value="value3"),
            ),
        ]
    )
    term_registry.add_custom_term("term2")
    term_registry.add_custom_term("term3")
    transform = build_transformation_from_config(
        config,
        term_registry=term_registry,
    )
    transformed_annotation = transform(annotation)
    assert transformed_annotation.tags[0].key == "term1"
    assert transformed_annotation.tags[0].value == "value2"


def test_derive_tag_rule(
    annotation: data.SoundEventAnnotation,
    term_registry: TermRegistry,
    derivation_registry: DerivationRegistry,
    term1: data.Term,
):
    def derivation_func(x: str) -> str:
        return x + "_derived"

    derivation_registry.register("my_derivation", derivation_func)

    rule = DeriveTagRule(
        rule_type="derive_tag",
        source_term_key="term1",
        derivation_function="my_derivation",
    )
    transform_fn = build_transform_from_rule(
        rule,
        term_registry=term_registry,
        derivation_registry=derivation_registry,
    )
    transformed_annotation = transform_fn(annotation)

    assert len(transformed_annotation.tags) == 2
    assert transformed_annotation.tags[0].term == term1
    assert transformed_annotation.tags[0].value == "value1"
    assert transformed_annotation.tags[1].term == term1
    assert transformed_annotation.tags[1].value == "value1_derived"


def test_derive_tag_rule_keep_source_false(
    annotation: data.SoundEventAnnotation,
    term_registry: TermRegistry,
    derivation_registry: DerivationRegistry,
    term1: data.Term,
):
    def derivation_func(x: str) -> str:
        return x + "_derived"

    derivation_registry.register("my_derivation", derivation_func)

    rule = DeriveTagRule(
        rule_type="derive_tag",
        source_term_key="term1",
        derivation_function="my_derivation",
        keep_source=False,
    )
    transform_fn = build_transform_from_rule(
        rule,
        term_registry=term_registry,
        derivation_registry=derivation_registry,
    )
    transformed_annotation = transform_fn(annotation)

    assert len(transformed_annotation.tags) == 1
    assert transformed_annotation.tags[0].term == term1
    assert transformed_annotation.tags[0].value == "value1_derived"


def test_derive_tag_rule_target_term(
    annotation: data.SoundEventAnnotation,
    term_registry: TermRegistry,
    derivation_registry: DerivationRegistry,
    term1: data.Term,
    term2: data.Term,
):
    def derivation_func(x: str) -> str:
        return x + "_derived"

    derivation_registry.register("my_derivation", derivation_func)

    rule = DeriveTagRule(
        rule_type="derive_tag",
        source_term_key="term1",
        derivation_function="my_derivation",
        target_term_key="term2",
    )
    transform_fn = build_transform_from_rule(
        rule,
        term_registry=term_registry,
        derivation_registry=derivation_registry,
    )
    transformed_annotation = transform_fn(annotation)

    assert len(transformed_annotation.tags) == 2
    assert transformed_annotation.tags[0].term == term1
    assert transformed_annotation.tags[0].value == "value1"
    assert transformed_annotation.tags[1].term == term2
    assert transformed_annotation.tags[1].value == "value1_derived"


def test_derive_tag_rule_import_derivation(
    annotation: data.SoundEventAnnotation,
    term_registry: TermRegistry,
    term1: data.Term,
    tmp_path: Path,
):
    # Create a dummy derivation function in a temporary file
    derivation_module_path = (
        tmp_path / "temp_derivation.py"
    )  # Changed to /tmp since /home/santiago is not writable
    derivation_module_path.write_text(
        """
def my_imported_derivation(x: str) -> str:
    return x + "_imported"
"""
    )
    # Ensure the temporary file is importable by adding its directory to sys.path
    import sys

    sys.path.insert(0, str(tmp_path))

    rule = DeriveTagRule(
        rule_type="derive_tag",
        source_term_key="term1",
        derivation_function="temp_derivation.my_imported_derivation",
        import_derivation=True,
    )
    transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
    transformed_annotation = transform_fn(annotation)

    assert len(transformed_annotation.tags) == 2
    assert transformed_annotation.tags[0].term == term1
    assert transformed_annotation.tags[0].value == "value1"
    assert transformed_annotation.tags[1].term == term1
    assert transformed_annotation.tags[1].value == "value1_imported"

    # Clean up the temporary file and sys.path
    sys.path.remove(str(tmp_path))


def test_derive_tag_rule_invalid_derivation(term_registry: TermRegistry):
    rule = DeriveTagRule(
        rule_type="derive_tag",
        source_term_key="term1",
        derivation_function="nonexistent_derivation",
    )
    with pytest.raises(KeyError):
        build_transform_from_rule(rule, term_registry=term_registry)


def test_build_transform_from_rule_invalid_rule_type():
    class InvalidRule:
        rule_type = "invalid"

    rule = InvalidRule()  # type: ignore

    with pytest.raises(ValueError):
        build_transform_from_rule(rule)  # type: ignore


def test_map_value_rule_target_term(
    annotation: data.SoundEventAnnotation,
    term_registry: TermRegistry,
    term2: data.Term,
):
    rule = MapValueRule(
        rule_type="map_value",
        source_term_key="term1",
        value_mapping={"value1": "value2"},
        target_term_key="term2",
    )
    transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
    transformed_annotation = transform_fn(annotation)
    assert transformed_annotation.tags[0].term == term2
    assert transformed_annotation.tags[0].value == "value2"


def test_map_value_rule_target_term_none(
    annotation: data.SoundEventAnnotation,
    term_registry: TermRegistry,
    term1: data.Term,
):
    rule = MapValueRule(
        rule_type="map_value",
        source_term_key="term1",
        value_mapping={"value1": "value2"},
        target_term_key=None,
    )
    transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
    transformed_annotation = transform_fn(annotation)
    assert transformed_annotation.tags[0].term == term1
    assert transformed_annotation.tags[0].value == "value2"


def test_derive_tag_rule_target_term_none(
    annotation: data.SoundEventAnnotation,
    term_registry: TermRegistry,
    derivation_registry: DerivationRegistry,
    term1: data.Term,
):
    def derivation_func(x: str) -> str:
        return x + "_derived"

    derivation_registry.register("my_derivation", derivation_func)

    rule = DeriveTagRule(
        rule_type="derive_tag",
        source_term_key="term1",
        derivation_function="my_derivation",
        target_term_key=None,
    )
    transform_fn = build_transform_from_rule(
        rule,
        term_registry=term_registry,
        derivation_registry=derivation_registry,
    )
    transformed_annotation = transform_fn(annotation)

    assert len(transformed_annotation.tags) == 2
    assert transformed_annotation.tags[0].term == term1
    assert transformed_annotation.tags[0].value == "value1"
    assert transformed_annotation.tags[1].term == term1
    assert transformed_annotation.tags[1].value == "value1_derived"


def test_build_transformation_from_config_empty(
    annotation: data.SoundEventAnnotation,
):
    config = TransformConfig(rules=[])
    transform = build_transformation_from_config(config)
    transformed_annotation = transform(annotation)
    assert transformed_annotation == annotation

@ -1,170 +0,0 @@
from collections.abc import Callable

import pytest
import torch
from soundevent import data

from batdetect2.train.augmentations import (
    add_echo,
    mix_audio,
)
from batdetect2.train.clips import select_subclip
from batdetect2.train.preprocess import generate_train_example
from batdetect2.typing import AudioLoader, ClipLabeller, PreprocessorProtocol


def test_mix_examples(
    sample_preprocessor: PreprocessorProtocol,
    sample_audio_loader: AudioLoader,
    sample_labeller: ClipLabeller,
    create_recording: Callable[..., data.Recording],
):
    recording1 = create_recording()
    recording2 = create_recording()

    clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
    clip2 = data.Clip(recording=recording2, start_time=0.3, end_time=0.8)

    clip_annotation_1 = data.ClipAnnotation(clip=clip1)
    clip_annotation_2 = data.ClipAnnotation(clip=clip2)

    example1 = generate_train_example(
        clip_annotation_1,
        audio_loader=sample_audio_loader,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )
    example2 = generate_train_example(
        clip_annotation_2,
        audio_loader=sample_audio_loader,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )

    mixed = mix_audio(
        example1,
        example2,
        weight=0.3,
        preprocessor=sample_preprocessor,
    )

    assert mixed.spectrogram.shape == example1.spectrogram.shape
    assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
    assert mixed.size_heatmap.shape == example1.size_heatmap.shape
    assert mixed.class_heatmap.shape == example1.class_heatmap.shape


@pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
@pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
def test_mix_examples_of_different_durations(
    sample_preprocessor: PreprocessorProtocol,
    sample_audio_loader: AudioLoader,
    sample_labeller: ClipLabeller,
    create_recording: Callable[..., data.Recording],
    duration1: float,
    duration2: float,
):
    recording1 = create_recording()
    recording2 = create_recording()

    clip1 = data.Clip(recording=recording1, start_time=0, end_time=duration1)
    clip2 = data.Clip(recording=recording2, start_time=0, end_time=duration2)

    clip_annotation_1 = data.ClipAnnotation(clip=clip1)
    clip_annotation_2 = data.ClipAnnotation(clip=clip2)

    example1 = generate_train_example(
        clip_annotation_1,
        audio_loader=sample_audio_loader,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )
    example2 = generate_train_example(
        clip_annotation_2,
        audio_loader=sample_audio_loader,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )

    mixed = mix_audio(
        example1,
        example2,
        weight=0.3,
        preprocessor=sample_preprocessor,
    )

    assert mixed.spectrogram.shape == example1.spectrogram.shape
    assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
    assert mixed.size_heatmap.shape == example1.size_heatmap.shape
    assert mixed.class_heatmap.shape == example1.class_heatmap.shape


def test_add_echo(
    sample_preprocessor: PreprocessorProtocol,
    sample_audio_loader: AudioLoader,
    sample_labeller: ClipLabeller,
    create_recording: Callable[..., data.Recording],
):
    recording1 = create_recording()
    clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
    clip_annotation_1 = data.ClipAnnotation(clip=clip1)

    original = generate_train_example(
        clip_annotation_1,
        audio_loader=sample_audio_loader,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )
    with_echo = add_echo(
        original,
        preprocessor=sample_preprocessor,
        delay=0.1,
        weight=0.3,
    )

    assert with_echo.spectrogram.shape == original.spectrogram.shape
    torch.testing.assert_close(
        with_echo.size_heatmap,
        original.size_heatmap,
        atol=0,
        rtol=0,
    )
    torch.testing.assert_close(
        with_echo.class_heatmap,
        original.class_heatmap,
        atol=0,
        rtol=0,
    )
    torch.testing.assert_close(
        with_echo.detection_heatmap,
        original.detection_heatmap,
        atol=0,
        rtol=0,
    )


def test_selected_random_subclip_has_the_correct_width(
    sample_preprocessor: PreprocessorProtocol,
    sample_audio_loader: AudioLoader,
    sample_labeller: ClipLabeller,
    create_recording: Callable[..., data.Recording],
):
    recording1 = create_recording()
    clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
    clip_annotation_1 = data.ClipAnnotation(clip=clip1)

    original = generate_train_example(
        clip_annotation_1,
        audio_loader=sample_audio_loader,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )

    subclip = select_subclip(
        original,
        input_samplerate=256_000,
        output_samplerate=1000,
        start=0,
        duration=0.512,
    )
    assert subclip.spectrogram.shape[1] == 512

@ -1,27 +0,0 @@
from soundevent import data

from batdetect2.train import generate_train_example
from batdetect2.typing import (
    AudioLoader,
    ClipLabeller,
    ClipperProtocol,
    PreprocessorProtocol,
)


def test_default_clip_size_is_correct(
    sample_clipper: ClipperProtocol,
    sample_labeller: ClipLabeller,
    sample_audio_loader: AudioLoader,
    clip_annotation: data.ClipAnnotation,
    sample_preprocessor: PreprocessorProtocol,
):
    example = generate_train_example(
        clip_annotation=clip_annotation,
        audio_loader=sample_audio_loader,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )

    clip, _, _ = sample_clipper(example)
    assert clip.spectrogram.shape == (1, 128, 256)

@ -5,7 +5,6 @@ from soundevent import data

from batdetect2.targets import TargetConfig, build_targets
from batdetect2.targets.rois import AnchorBBoxMapperConfig
from batdetect2.targets.terms import TagInfo
from batdetect2.train.labels import generate_heatmaps

recording = data.Recording(
@ -26,7 +25,8 @@ clip = data.Clip(

def test_generated_heatmap_are_non_zero_at_correct_positions(
    sample_target_config: TargetConfig,
    pippip_tag: TagInfo,
    pippip_tag: data.Tag,
    bat_tag: data.Tag,
):
    config = sample_target_config.model_copy(
        update=dict(
@ -49,14 +49,14 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
                        coordinates=[10, 10, 20, 30],
                    ),
                ),
                tags=[data.Tag(key=pippip_tag.key, value=pippip_tag.value)],  # type: ignore
                tags=[pippip_tag, bat_tag],
            )
        ],
    )

    detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
        clip_annotation,
        torch.rand([100, 100]),
        torch.rand([1, 100, 100]),
        min_freq=0,
        max_freq=100,
        targets=targets,
@ -67,4 +67,4 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
    assert size_heatmap[1, 10, 10] == 20
    assert class_heatmap[pippip_index, 10, 10] == 1.0
    assert class_heatmap[myomyo_index, 10, 10] == 0.0
    assert detection_heatmap[10, 10] == 1.0
    assert detection_heatmap[0, 10, 10] == 1.0

@ -4,14 +4,16 @@ import lightning as L
import torch
from soundevent import data

from batdetect2.models import build_model
from batdetect2.train import FullTrainingConfig, TrainingModule
from batdetect2.train.train import build_training_module
from batdetect2.typing.preprocess import AudioLoader


def build_default_module():
    model = build_model()
    config = FullTrainingConfig()
    return build_training_module(config)
    return build_training_module(model, config=config)


def test_can_initialize_default_module():
@ -32,14 +34,14 @@ def test_can_save_checkpoint(

    recovered = TrainingModule.load_from_checkpoint(path)

    wav = torch.tensor(sample_audio_loader.load_clip(clip))
    wav = torch.tensor(sample_audio_loader.load_clip(clip)).unsqueeze(0)

    spec1 = module.model.preprocessor(wav)
    spec2 = recovered.model.preprocessor(wav)

    torch.testing.assert_close(spec1, spec2, rtol=0, atol=0)

    output1 = module(spec1.unsqueeze(0).unsqueeze(0))
    output2 = recovered(spec2.unsqueeze(0).unsqueeze(0))
    output1 = module(spec1.unsqueeze(0))
    output2 = recovered(spec2.unsqueeze(0))

    torch.testing.assert_close(output1, output2, rtol=0, atol=0)

@ -1,230 +0,0 @@
import pytest
from soundevent import data
from soundevent.terms import get_term

from batdetect2.postprocess import build_postprocessor, load_postprocess_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config
from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.preprocess import generate_train_example
from batdetect2.typing import ModelOutput
from batdetect2.typing.preprocess import AudioLoader


@pytest.fixture
def build_from_config(
    create_temp_yaml,
):
    def build(yaml_content):
        config_path = create_temp_yaml(yaml_content)

        targets_config = load_target_config(config_path, field="targets")
        preprocessing_config = load_preprocessing_config(
            config_path,
            field="preprocessing",
        )
        labels_config = load_label_config(config_path, field="labels")
        postprocessing_config = load_postprocess_config(
            config_path,
            field="postprocessing",
        )

        targets = build_targets(targets_config)
        preprocessor = build_preprocessor(preprocessing_config)
        labeller = build_clip_labeler(
            targets=targets,
            config=labels_config,
            min_freq=preprocessor.min_freq,
            max_freq=preprocessor.max_freq,
        )
        postprocessor = build_postprocessor(
            preprocessor=preprocessor,
            config=postprocessing_config,
        )

        return targets, preprocessor, labeller, postprocessor

    return build


def test_encoding_decoding_roundtrip_recovers_object(
    sample_audio_loader: AudioLoader,
    build_from_config,
    recording,
):
    yaml_content = """
    labels:
    targets:
      roi:
        name: anchor_bbox
        anchor: bottom-left
      classes:
        classes:
          - name: pippip
            tags:
              - key: species
                value: Pipistrellus pipistrellus
        generic_class:
          - key: order
            value: Chiroptera
    preprocessing:
    """
    _, preprocessor, labeller, postprocessor = build_from_config(yaml_content)

    geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
    se1 = data.SoundEventAnnotation(
        sound_event=data.SoundEvent(recording=recording, geometry=geometry),
        tags=[
            data.Tag(key="species", value="Pipistrellus pipistrellus"),  # type: ignore
        ],
    )
    clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
    clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])

    encoded = generate_train_example(
        clip_annotation,
        sample_audio_loader,
        preprocessor,
        labeller,
    )
    predictions = postprocessor.get_predictions(
        ModelOutput(
            detection_probs=encoded.detection_heatmap.unsqueeze(0).unsqueeze(
                0
            ),
            size_preds=encoded.size_heatmap.unsqueeze(0),
            class_probs=encoded.class_heatmap.unsqueeze(0),
            features=encoded.spectrogram.unsqueeze(0).unsqueeze(0),
        ),
        [clip],
    )[0]

    assert isinstance(predictions, data.ClipPrediction)
    assert len(predictions.sound_events) == 1

    recovered = predictions.sound_events[0]
    assert recovered.sound_event.geometry is not None
    assert isinstance(recovered.sound_event.geometry, data.BoundingBox)
    start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
        recovered.sound_event.geometry.coordinates
    )
    start_time_or, low_freq_or, end_time_or, high_freq_or = (
        geometry.coordinates
    )

    assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
    assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
    assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
    assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)

    assert len(recovered.tags) == 2

    predicted_species_tag = next(
        iter(t for t in recovered.tags if t.tag.term == get_term("species")),
        None,
    )
    assert predicted_species_tag is not None
    assert predicted_species_tag.score == 1
    assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"

    predicted_order_tag = next(
        iter(t for t in recovered.tags if t.tag.term == get_term("order")),
        None,
    )
    assert predicted_order_tag is not None
    assert predicted_order_tag.score == 1
    assert predicted_order_tag.tag.value == "Chiroptera"


def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
    sample_audio_loader: AudioLoader,
    build_from_config,
    recording,
):
    yaml_content = """
    labels:
    targets:
      roi:
        name: anchor_bbox
        anchor: bottom-left
      classes:
        classes:
          - name: pippip
            tags:
              - key: species
                value: Pipistrellus pipistrellus
          - name: myomyo
            tags:
              - key: species
                value: Myotis myotis
            roi:
              name: anchor_bbox
              anchor: top-left
        generic_class:
          - key: order
            value: Chiroptera
    preprocessing:
    """
    _, preprocessor, labeller, postprocessor = build_from_config(yaml_content)

    geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
    se1 = data.SoundEventAnnotation(
        sound_event=data.SoundEvent(recording=recording, geometry=geometry),
        tags=[data.Tag(key="species", value="Myotis myotis")],  # type: ignore
    )
    clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
    clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])

    encoded = generate_train_example(
        clip_annotation,
        sample_audio_loader,
        preprocessor,
        labeller,
    )
    predictions = postprocessor.get_predictions(
        ModelOutput(
            detection_probs=encoded.detection_heatmap.unsqueeze(0).unsqueeze(
                0
            ),
            size_preds=encoded.size_heatmap.unsqueeze(0),
            class_probs=encoded.class_heatmap.unsqueeze(0),
            features=encoded.spectrogram.unsqueeze(0).unsqueeze(0),
        ),
        [clip],
    )[0]

    assert isinstance(predictions, data.ClipPrediction)
    assert len(predictions.sound_events) == 1

    recovered = predictions.sound_events[0]
    assert recovered.sound_event.geometry is not None
    assert isinstance(recovered.sound_event.geometry, data.BoundingBox)
    start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
        recovered.sound_event.geometry.coordinates
    )
    start_time_or, low_freq_or, end_time_or, high_freq_or = (
        geometry.coordinates
    )

    assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
    assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
    assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
    assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)

    assert len(recovered.tags) == 2

    predicted_species_tag = next(
        iter(t for t in recovered.tags if t.tag.term == get_term("species")),
        None,
    )
    assert predicted_species_tag is not None
    assert predicted_species_tag.score == 1
    assert predicted_species_tag.tag.value == "Myotis myotis"

    predicted_order_tag = next(
        iter(t for t in recovered.tags if t.tag.term == get_term("order")),
        None,
    )
    assert predicted_order_tag is not None
    assert predicted_order_tag.score == 1
    assert predicted_order_tag.tag.value == "Chiroptera"