Big changes in data module

This commit is contained in:
mbsantiago 2025-09-08 17:50:25 +01:00
parent cf6d0d1ccc
commit b7ae526071
32 changed files with 1678 additions and 2968 deletions

View File

@ -133,7 +133,7 @@ When you need to specify a tag, you typically use a structure with two fields:
**It defaults to `class`** if you omit it, which is common when defining the main target classes.
- `value`: The specific _value_ of the tag (e.g., `Myotis daubentonii`, `Good`, `Rain`).
**Example YAML Configuration using TagInfo (e.g., inside a filter rule):**
**Example YAML Configuration (e.g., inside a filter rule):**
```yaml
# ... inside a filtering configuration section ...

View File

@ -1,44 +1,47 @@
targets:
classes:
classes:
- name: myomys
tags:
- value: Myotis mystacinus
- name: pippip
tags:
- value: Pipistrellus pipistrellus
- name: eptser
tags:
- value: Eptesicus serotinus
- name: rhifer
tags:
- value: Rhinolophus ferrumequinum
roi:
name: anchor_bbox
anchor: top-left
generic_class:
detection_target:
name: bat
match_if:
name: all_of
conditions:
- name: has_tag
tag: { key: event, value: Echolocation }
- name: not
condition:
name: has_tag
tag: { key: class, value: Unknown }
assign_tags:
- key: class
value: Bat
filtering:
rules:
- match_type: all
tags:
- key: event
value: Echolocation
- match_type: exclude
tags:
- key: class
value: Unknown
classification_targets:
- name: myomys
tags:
- key: class
value: Myotis mystacinus
- name: pippip
tags:
- key: class
value: Pipistrellus pipistrellus
- name: eptser
tags:
- key: class
value: Eptesicus serotinus
- name: rhifer
tags:
- key: class
value: Rhinolophus ferrumequinum
roi:
name: anchor_bbox
anchor: top-left
preprocess:
audio:
samplerate: 256000
resample:
samplerate: 256000
enabled: True
method: "poly"
scale: false
center: true
duration: null
spectrogram:
stft:
@ -48,66 +51,63 @@ preprocess:
frequencies:
max_freq: 120000
min_freq: 10000
pcen:
time_constant: 0.1
gain: 0.98
bias: 2
power: 0.5
scale: "amplitude"
size:
height: 128
resize_factor: 0.5
spectral_mean_substraction: true
peak_normalize: false
transforms:
- name: pcen
time_constant: 0.1
gain: 0.98
bias: 2
power: 0.5
- name: spectral_mean_substraction
postprocess:
nms_kernel_size: 9
detection_threshold: 0.01
min_freq: 10000
max_freq: 120000
top_k_per_sec: 200
labels:
sigma: 3
model:
input_height: 128
in_channels: 1
out_channels: 32
encoder:
layers:
- block_type: FreqCoordConvDown
- name: FreqCoordConvDown
out_channels: 32
- block_type: FreqCoordConvDown
- name: FreqCoordConvDown
out_channels: 64
- block_type: LayerGroup
- name: LayerGroup
layers:
- block_type: FreqCoordConvDown
- name: FreqCoordConvDown
out_channels: 128
- block_type: ConvBlock
- name: ConvBlock
out_channels: 256
bottleneck:
channels: 256
layers:
- block_type: SelfAttention
- name: SelfAttention
attention_channels: 256
decoder:
layers:
- block_type: FreqCoordConvUp
- name: FreqCoordConvUp
out_channels: 64
- block_type: FreqCoordConvUp
- name: FreqCoordConvUp
out_channels: 32
- block_type: LayerGroup
- name: LayerGroup
layers:
- block_type: FreqCoordConvUp
- name: FreqCoordConvUp
out_channels: 32
- block_type: ConvBlock
- name: ConvBlock
out_channels: 32
train:
learning_rate: 0.001
t_max: 100
labels:
sigma: 3
dataloaders:
train:
batch_size: 8
@ -133,33 +133,35 @@ train:
weight: 0.1
logger:
logger_type: csv
save_dir: outputs/log/
name: logs
logger_type: tensorboard
# save_dir: outputs/log/
# name: logs
augmentations:
steps:
- augmentation_type: mix_audio
enabled: true
audio:
- name: mix_audio
probability: 0.2
min_weight: 0.3
max_weight: 0.7
- augmentation_type: add_echo
- name: add_echo
probability: 0.2
max_delay: 0.005
min_weight: 0.0
max_weight: 1.0
- augmentation_type: scale_volume
spectrogram:
- name: scale_volume
probability: 0.2
min_scaling: 0.0
max_scaling: 2.0
- augmentation_type: warp
- name: warp
probability: 0.2
delta: 0.04
- augmentation_type: mask_time
- name: mask_time
probability: 0.2
max_perc: 0.05
max_masks: 3
- augmentation_type: mask_freq
- name: mask_freq
probability: 0.2
max_perc: 0.10
max_masks: 3

View File

@ -6,3 +6,33 @@ sources:
description: Examples included for testing batdetect2
annotations_dir: example_data/anns
audio_dir: example_data/audio
classes:
# Each class has a name
- name: rhihip
# Can be specified by some tags
tags:
- key: species
value: Rhinolophus hipposideros
- name: myotis
# Or if needed by more complex conditions
condition:
name: has_any_tag
tag:
- key: species
value: Myotis myotis
- key: species
value: Myotis nattereri
- key: species
value: Myotis daubentonii
# Specifies how to translate back the "class" into
# the tag system
# If tags are provided and output_tags are not then
# just use the provided tags
output_tags:
- key: genus
value: Myotis

View File

@ -17,7 +17,7 @@ dependencies = [
"torch>=1.13.1,<2.5.0",
"torchaudio>=1.13.1,<2.5.0",
"torchvision>=0.14.0",
"soundevent[audio,geometry,plot]>=2.8.1",
"soundevent[audio,geometry,plot]>=2.9.1",
"click>=8.1.7",
"netcdf4>=1.6.5",
"tqdm>=4.66.2",

View File

@ -27,7 +27,7 @@ class BaseConfig(BaseModel):
and serialization capabilities.
"""
model_config = ConfigDict(extra="ignore")
model_config = ConfigDict(extra="forbid")
def to_yaml_string(
self,

View File

@ -0,0 +1,61 @@
from typing import Generic, Protocol, Type, TypeVar
from pydantic import BaseModel
__all__ = [
"Registry",
]
T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
T_Type = TypeVar("T_Type", covariant=True)
class LogicProtocol(Generic[T_Config, T_Type], Protocol):
"""A generic protocol for the logic classes (conditions or transforms)."""
@classmethod
def from_config(cls, config: T_Config) -> T_Type: ...
T_Proto = TypeVar("T_Proto", bound=LogicProtocol)
class Registry(Generic[T_Type]):
"""A generic class to create and manage a registry of items."""
def __init__(self, name: str):
self._name = name
self._registry = {}
def register(self, config_cls: Type[T_Config]):
"""A decorator factory to register a new item."""
fields = config_cls.model_fields
if "name" not in fields:
raise ValueError("Configuration object must have a 'name' field.")
name = fields["name"].default
if not isinstance(name, str):
raise ValueError("'name' field must be a string literal.")
def decorator(logic_cls: Type[T_Proto]) -> Type[T_Proto]:
self._registry[name] = logic_cls
return logic_cls
return decorator
def build(self, config: BaseModel) -> T_Type:
"""Builds a logic instance from a config object."""
name = getattr(config, "name") # noqa: B009
if name is None:
raise ValueError("Config does not have a name field")
if name not in self._registry:
raise NotImplementedError(
f"No {self._name} with name '{name}' is registered."
)
return self._registry[name].from_config(config)

View File

@ -14,8 +14,9 @@ format-specific loading function to retrieve the annotations as a standard
"""
from pathlib import Path
from typing import Optional, Union
from typing import Annotated, Optional, Union
from pydantic import Field
from soundevent import data
from batdetect2.data.annotations.aoef import (
@ -42,10 +43,13 @@ __all__ = [
]
AnnotationFormats = Union[
BatDetect2MergedAnnotations,
BatDetect2FilesAnnotations,
AOEFAnnotations,
AnnotationFormats = Annotated[
Union[
BatDetect2MergedAnnotations,
BatDetect2FilesAnnotations,
AOEFAnnotations,
],
Field(discriminator="format"),
]
"""Type Alias representing all supported data source configurations.

View File

@ -0,0 +1,287 @@
from collections.abc import Callable
from typing import Annotated, List, Literal, Sequence, Union
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
_conditions: Registry[SoundEventCondition] = Registry("condition")
class HasTagConfig(BaseConfig):
name: Literal["has_tag"] = "has_tag"
tag: data.Tag
@_conditions.register(HasTagConfig)
class HasTag:
def __init__(self, tag: data.Tag):
self.tag = tag
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> bool:
return self.tag in sound_event_annotation.tags
@classmethod
def from_config(cls, config: HasTagConfig):
return cls(tag=config.tag)
class HasAllTagsConfig(BaseConfig):
name: Literal["has_all_tags"] = "has_all_tags"
tags: List[data.Tag]
@_conditions.register(HasAllTagsConfig)
class HasAllTags:
def __init__(self, tags: List[data.Tag]):
if not tags:
raise ValueError("Need to specify at least one tag")
self.tags = set(tags)
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> bool:
return self.tags.issubset(sound_event_annotation.tags)
@classmethod
def from_config(cls, config: HasAllTagsConfig):
return cls(tags=config.tags)
class HasAnyTagConfig(BaseConfig):
name: Literal["has_any_tag"] = "has_any_tag"
tags: List[data.Tag]
@_conditions.register(HasAnyTagConfig)
class HasAnyTag:
def __init__(self, tags: List[data.Tag]):
if not tags:
raise ValueError("Need to specify at least one tag")
self.tags = set(tags)
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> bool:
return bool(self.tags.intersection(sound_event_annotation.tags))
@classmethod
def from_config(cls, config: HasAnyTagConfig):
return cls(tags=config.tags)
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
class DurationConfig(BaseConfig):
name: Literal["duration"] = "duration"
operator: Operator
seconds: float
def _build_comparator(
operator: Operator, value: float
) -> Callable[[float], bool]:
if operator == "gt":
return lambda x: x > value
if operator == "gte":
return lambda x: x >= value
if operator == "lt":
return lambda x: x < value
if operator == "lte":
return lambda x: x <= value
if operator == "eq":
return lambda x: x == value
raise ValueError(f"Invalid operator {operator}")
@_conditions.register(DurationConfig)
class Duration:
def __init__(self, operator: Operator, seconds: float):
self.operator = operator
self.seconds = seconds
self._comparator = _build_comparator(self.operator, self.seconds)
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> bool:
geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
return False
start_time, _, end_time, _ = compute_bounds(geometry)
duration = end_time - start_time
return self._comparator(duration)
@classmethod
def from_config(cls, config: DurationConfig):
return cls(operator=config.operator, seconds=config.seconds)
class FrequencyConfig(BaseConfig):
name: Literal["frequency"] = "frequency"
boundary: Literal["low", "high"]
operator: Operator
hertz: float
@_conditions.register(FrequencyConfig)
class Frequency:
def __init__(
self,
operator: Operator,
boundary: Literal["low", "high"],
hertz: float,
):
self.operator = operator
self.hertz = hertz
self.boundary = boundary
self._comparator = _build_comparator(self.operator, self.hertz)
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> bool:
geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
return False
# Automatically false if geometry does not have a frequency range
if isinstance(geometry, (data.TimeInterval, data.TimeStamp)):
return False
_, low_freq, _, high_freq = compute_bounds(geometry)
if self.boundary == "low":
return self._comparator(low_freq)
return self._comparator(high_freq)
@classmethod
def from_config(cls, config: FrequencyConfig):
return cls(
operator=config.operator,
boundary=config.boundary,
hertz=config.hertz,
)
class AllOfConfig(BaseConfig):
name: Literal["all_of"] = "all_of"
conditions: Sequence["SoundEventConditionConfig"]
@_conditions.register(AllOfConfig)
class AllOf:
def __init__(self, conditions: List[SoundEventCondition]):
self.conditions = conditions
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> bool:
return all(c(sound_event_annotation) for c in self.conditions)
@classmethod
def from_config(cls, config: AllOfConfig):
conditions = [
build_sound_event_condition(cond) for cond in config.conditions
]
return cls(conditions)
class AnyOfConfig(BaseConfig):
name: Literal["any_of"] = "any_of"
conditions: List["SoundEventConditionConfig"]
@_conditions.register(AnyOfConfig)
class AnyOf:
def __init__(self, conditions: List[SoundEventCondition]):
self.conditions = conditions
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> bool:
return any(c(sound_event_annotation) for c in self.conditions)
@classmethod
def from_config(cls, config: AnyOfConfig):
conditions = [
build_sound_event_condition(cond) for cond in config.conditions
]
return cls(conditions)
class NotConfig(BaseConfig):
name: Literal["not"] = "not"
condition: "SoundEventConditionConfig"
@_conditions.register(NotConfig)
class Not:
def __init__(self, condition: SoundEventCondition):
self.condition = condition
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> bool:
return not self.condition(sound_event_annotation)
@classmethod
def from_config(cls, config: NotConfig):
condition = build_sound_event_condition(config.condition)
return cls(condition)
SoundEventConditionConfig = Annotated[
Union[
HasTagConfig,
HasAllTagsConfig,
HasAnyTagConfig,
DurationConfig,
FrequencyConfig,
AllOfConfig,
AnyOfConfig,
NotConfig,
],
Field(discriminator="name"),
]
def build_sound_event_condition(
config: SoundEventConditionConfig,
) -> SoundEventCondition:
return _conditions.build(config)
def filter_clip_annotation(
clip_annotation: data.ClipAnnotation,
condition: SoundEventCondition,
) -> data.ClipAnnotation:
return clip_annotation.model_copy(
update=dict(
sound_events=[
sound_event
for sound_event in clip_annotation.sound_events
if condition(sound_event)
]
)
)

View File

@ -19,7 +19,7 @@ The core components are:
"""
from pathlib import Path
from typing import Annotated, List, Optional
from typing import List, Optional
from loguru import logger
from pydantic import Field
@ -31,6 +31,17 @@ from batdetect2.data.annotations import (
AnnotationFormats,
load_annotated_dataset,
)
from batdetect2.data.conditions import (
SoundEventConditionConfig,
build_sound_event_condition,
filter_clip_annotation,
)
from batdetect2.data.transforms import (
ApplyAll,
SoundEventTransformConfig,
build_sound_event_transform,
transform_clip_annotation,
)
from batdetect2.targets.terms import data_source
__all__ = [
@ -52,79 +63,68 @@ sources.
class DatasetConfig(BaseConfig):
"""Configuration model defining the structure of a BatDetect2 dataset.
This class is typically loaded from a YAML file and describes the components
of the dataset, including metadata and a list of data sources.
Attributes
----------
name : str
A descriptive name for the dataset (e.g., "UK_Bats_Project_2024").
description : str
A longer description of the dataset's contents, origin, purpose, etc.
sources : List[AnnotationFormats]
A list defining the different data sources contributing to this
dataset. Each item in the list must conform to one of the Pydantic
models defined in the `AnnotationFormats` type union. The specific
model used for each source is determined by the mandatory `format`
field within the source's configuration, allowing BatDetect2 to use the
correct parser for different annotation styles.
"""
"""Configuration model defining the structure of a BatDetect2 dataset."""
name: str
description: str
sources: List[
Annotated[AnnotationFormats, Field(..., discriminator="format")]
]
sources: List[AnnotationFormats]
sound_event_filter: Optional[SoundEventConditionConfig] = None
sound_event_transforms: List[SoundEventTransformConfig] = Field(
default_factory=list
)
def load_dataset(
dataset: DatasetConfig,
config: DatasetConfig,
base_dir: Optional[Path] = None,
) -> Dataset:
"""Load all clip annotations from the sources defined in a DatasetConfig.
Iterates through each data source specified in the `dataset_config`,
delegates the loading and parsing of that source's annotations to
`batdetect2.data.annotations.load_annotated_dataset` (which handles
different data formats), and aggregates all resulting `ClipAnnotation`
objects into a single flat list.
Parameters
----------
dataset_config : DatasetConfig
The configuration object describing the dataset and its sources.
base_dir : Path, optional
An optional base directory path. If provided, relative paths for
metadata files or data directories within the `dataset_config`'s
sources might be resolved relative to this directory. Defaults to None.
Returns
-------
Dataset (List[data.ClipAnnotation])
A flat list containing all loaded `ClipAnnotation` metadata objects
from all specified sources.
Raises
------
Exception
Can raise various exceptions during the delegated loading process
(`load_annotated_dataset`) if files are not found, cannot be parsed
according to the specified format, or other I/O errors occur.
"""
"""Load all clip annotations from the sources defined in a DatasetConfig."""
clip_annotations = []
for source in dataset.sources:
condition = (
build_sound_event_condition(config.sound_event_filter)
if config.sound_event_filter is not None
else None
)
transform = (
ApplyAll(
[
build_sound_event_transform(step)
for step in config.sound_event_transforms
]
)
if config.sound_event_transforms
else None
)
for source in config.sources:
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
logger.debug(
"Loaded {num_examples} from dataset source '{source_name}'",
num_examples=len(annotated_source.clip_annotations),
source_name=source.name,
)
clip_annotations.extend(
insert_source_tag(clip_annotation, source)
for clip_annotation in annotated_source.clip_annotations
)
for clip_annotation in annotated_source.clip_annotations:
clip_annotation = insert_source_tag(clip_annotation, source)
if condition is not None:
clip_annotation = filter_clip_annotation(
clip_annotation,
condition,
)
if transform is not None:
clip_annotation = transform_clip_annotation(
clip_annotation,
transform,
)
clip_annotations.append(clip_annotation)
return clip_annotations
@ -161,7 +161,6 @@ def insert_source_tag(
)
# TODO: add documentation
def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
return load_config(path=path, schema=DatasetConfig, field=field)

View File

@ -0,0 +1,250 @@
from collections.abc import Callable
from typing import Annotated, Dict, List, Literal, Optional, Union
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.data.conditions import (
SoundEventCondition,
SoundEventConditionConfig,
build_sound_event_condition,
)
SoundEventTransform = Callable[
[data.SoundEventAnnotation],
data.SoundEventAnnotation,
]
_transforms: Registry[SoundEventTransform] = Registry("transform")
class SetFrequencyBoundConfig(BaseConfig):
name: Literal["set_frequency"] = "set_frequency"
boundary: Literal["low", "high"] = "low"
hertz: float
@_transforms.register(SetFrequencyBoundConfig)
class SetFrequencyBound:
def __init__(self, hertz: float, boundary: Literal["low", "high"] = "low"):
self.hertz = hertz
self.boundary = boundary
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
sound_event = sound_event_annotation.sound_event
geometry = sound_event.geometry
if geometry is None:
return sound_event_annotation
if not isinstance(geometry, data.BoundingBox):
return sound_event_annotation
start_time, low_freq, end_time, high_freq = geometry.coordinates
if self.boundary == "low":
low_freq = self.hertz
high_freq = max(high_freq, low_freq)
elif self.boundary == "high":
high_freq = self.hertz
low_freq = min(high_freq, low_freq)
geometry = data.BoundingBox(
coordinates=[start_time, low_freq, end_time, high_freq],
)
sound_event = sound_event.model_copy(update=dict(geometry=geometry))
return sound_event_annotation.model_copy(
update=dict(sound_event=sound_event)
)
@classmethod
def from_config(cls, config: SetFrequencyBoundConfig):
return cls(hertz=config.hertz, boundary=config.boundary)
class ApplyIfConfig(BaseConfig):
name: Literal["apply_if"] = "apply_if"
transform: "SoundEventTransformConfig"
condition: SoundEventConditionConfig
@_transforms.register(ApplyIfConfig)
class ApplyIf:
def __init__(
self,
condition: SoundEventCondition,
transform: SoundEventTransform,
):
self.condition = condition
self.transform = transform
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
if not self.condition(sound_event_annotation):
return sound_event_annotation
return self.transform(sound_event_annotation)
@classmethod
def from_config(cls, config: ApplyIfConfig):
transform = build_sound_event_transform(config.transform)
condition = build_sound_event_condition(config.condition)
return cls(condition=condition, transform=transform)
class ReplaceTagConfig(BaseConfig):
name: Literal["replace_tag"] = "replace_tag"
original: data.Tag
replacement: data.Tag
@_transforms.register(ReplaceTagConfig)
class ReplaceTag:
def __init__(
self,
original: data.Tag,
replacement: data.Tag,
):
self.original = original
self.replacement = replacement
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
tags = []
for tag in sound_event_annotation.tags:
if tag == self.original:
tags.append(self.replacement)
else:
tags.append(tag)
return sound_event_annotation.model_copy(update=dict(tags=tags))
@classmethod
def from_config(cls, config: ReplaceTagConfig):
return cls(original=config.original, replacement=config.replacement)
class MapTagValueConfig(BaseConfig):
name: Literal["map_tag_value"] = "map_tag_value"
tag_key: str
value_mapping: Dict[str, str]
target_key: Optional[str] = None
@_transforms.register(MapTagValueConfig)
class MapTagValue:
def __init__(
self,
tag_key: str,
value_mapping: Dict[str, str],
target_key: Optional[str] = None,
):
self.tag_key = tag_key
self.value_mapping = value_mapping
self.target_key = target_key
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
tags = []
for tag in sound_event_annotation.tags:
if tag.key != self.tag_key:
tags.append(tag)
continue
value = self.value_mapping.get(tag.value)
if value is None:
tags.append(tag)
continue
if self.target_key is None:
tags.append(tag.model_copy(update=dict(value=value)))
else:
tags.append(
data.Tag(
key=self.target_key, # type: ignore
value=value,
)
)
return sound_event_annotation.model_copy(update=dict(tags=tags))
@classmethod
def from_config(cls, config: MapTagValueConfig):
return cls(
tag_key=config.tag_key,
value_mapping=config.value_mapping,
target_key=config.target_key,
)
class ApplyAllConfig(BaseConfig):
name: Literal["apply_all"] = "apply_all"
steps: List["SoundEventTransformConfig"] = Field(default_factory=list)
@_transforms.register(ApplyAllConfig)
class ApplyAll:
def __init__(self, steps: List[SoundEventTransform]):
self.steps = steps
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
for step in self.steps:
sound_event_annotation = step(sound_event_annotation)
return sound_event_annotation
@classmethod
def from_config(cls, config: ApplyAllConfig):
steps = [build_sound_event_transform(step) for step in config.steps]
return cls(steps)
SoundEventTransformConfig = Annotated[
Union[
SetFrequencyBoundConfig,
ReplaceTagConfig,
MapTagValueConfig,
ApplyIfConfig,
ApplyAllConfig,
],
Field(discriminator="name"),
]
def build_sound_event_transform(
config: SoundEventTransformConfig,
) -> SoundEventTransform:
return _transforms.build(config)
def transform_clip_annotation(
clip_annotation: data.ClipAnnotation,
transform: SoundEventTransform,
) -> data.ClipAnnotation:
return clip_annotation.model_copy(
update=dict(
sound_events=[
transform(sound_event)
for sound_event in clip_annotation.sound_events
]
)
)

View File

@ -56,7 +56,7 @@ __all__ = [
class SelfAttentionConfig(BaseConfig):
block_type: Literal["SelfAttention"] = "SelfAttention"
name: Literal["SelfAttention"] = "SelfAttention"
attention_channels: int
temperature: float = 1
@ -178,7 +178,7 @@ class SelfAttention(nn.Module):
class ConvConfig(BaseConfig):
"""Configuration for a basic ConvBlock."""
block_type: Literal["ConvBlock"] = "ConvBlock"
name: Literal["ConvBlock"] = "ConvBlock"
"""Discriminator field indicating the block type."""
out_channels: int
@ -300,7 +300,7 @@ class VerticalConv(nn.Module):
class FreqCoordConvDownConfig(BaseConfig):
"""Configuration for a FreqCoordConvDownBlock."""
block_type: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
name: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
"""Discriminator field indicating the block type."""
out_channels: int
@ -390,7 +390,7 @@ class FreqCoordConvDownBlock(nn.Module):
class StandardConvDownConfig(BaseConfig):
"""Configuration for a StandardConvDownBlock."""
block_type: Literal["StandardConvDown"] = "StandardConvDown"
name: Literal["StandardConvDown"] = "StandardConvDown"
"""Discriminator field indicating the block type."""
out_channels: int
@ -460,7 +460,7 @@ class StandardConvDownBlock(nn.Module):
class FreqCoordConvUpConfig(BaseConfig):
"""Configuration for a FreqCoordConvUpBlock."""
block_type: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
name: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
"""Discriminator field indicating the block type."""
out_channels: int
@ -569,7 +569,7 @@ class FreqCoordConvUpBlock(nn.Module):
class StandardConvUpConfig(BaseConfig):
"""Configuration for a StandardConvUpBlock."""
block_type: Literal["StandardConvUp"] = "StandardConvUp"
name: Literal["StandardConvUp"] = "StandardConvUp"
"""Discriminator field indicating the block type."""
out_channels: int
@ -664,13 +664,13 @@ LayerConfig = Annotated[
SelfAttentionConfig,
"LayerGroupConfig",
],
Field(discriminator="block_type"),
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configuration models."""
class LayerGroupConfig(BaseConfig):
block_type: Literal["LayerGroup"] = "LayerGroup"
name: Literal["LayerGroup"] = "LayerGroup"
layers: List[LayerConfig]
@ -686,7 +686,7 @@ def build_layer_from_config(
parameters derived from the config and the current pipeline state
(`input_height`, `in_channels`).
It uses the `block_type` field within the `config` object to determine
It uses the `name` field within the `config` object to determine
which block class to instantiate.
Parameters
@ -698,7 +698,7 @@ def build_layer_from_config(
config : LayerConfig
A Pydantic configuration object for the desired block (e.g., an
instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
by its `block_type` field.
by its `name` field.
Returns
-------
@ -711,11 +711,11 @@ def build_layer_from_config(
Raises
------
NotImplementedError
If the `config.block_type` does not correspond to a known block type.
If the `config.name` does not correspond to a known block type.
ValueError
If parameters derived from the config are invalid for the block.
"""
if config.block_type == "ConvBlock":
if config.name == "ConvBlock":
return (
ConvBlock(
in_channels=in_channels,
@ -727,7 +727,7 @@ def build_layer_from_config(
input_height,
)
if config.block_type == "FreqCoordConvDown":
if config.name == "FreqCoordConvDown":
return (
FreqCoordConvDownBlock(
in_channels=in_channels,
@ -740,7 +740,7 @@ def build_layer_from_config(
input_height // 2,
)
if config.block_type == "StandardConvDown":
if config.name == "StandardConvDown":
return (
StandardConvDownBlock(
in_channels=in_channels,
@ -752,7 +752,7 @@ def build_layer_from_config(
input_height // 2,
)
if config.block_type == "FreqCoordConvUp":
if config.name == "FreqCoordConvUp":
return (
FreqCoordConvUpBlock(
in_channels=in_channels,
@ -765,7 +765,7 @@ def build_layer_from_config(
input_height * 2,
)
if config.block_type == "StandardConvUp":
if config.name == "StandardConvUp":
return (
StandardConvUpBlock(
in_channels=in_channels,
@ -777,7 +777,7 @@ def build_layer_from_config(
input_height * 2,
)
if config.block_type == "SelfAttention":
if config.name == "SelfAttention":
return (
SelfAttention(
in_channels=in_channels,
@ -788,7 +788,7 @@ def build_layer_from_config(
input_height,
)
if config.block_type == "LayerGroup":
if config.name == "LayerGroup":
current_channels = in_channels
current_height = input_height
@ -804,4 +804,4 @@ def build_layer_from_config(
return nn.Sequential(*blocks), current_channels, current_height
raise NotImplementedError(f"Unknown block type {config.block_type}")
raise NotImplementedError(f"Unknown block type {config.name}")

View File

@ -128,7 +128,7 @@ class Bottleneck(nn.Module):
BottleneckLayerConfig = Annotated[
Union[SelfAttentionConfig,],
Field(discriminator="block_type"),
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""

View File

@ -47,7 +47,7 @@ DecoderLayerConfig = Annotated[
StandardConvUpConfig,
LayerGroupConfig,
],
Field(discriminator="block_type"),
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""
@ -63,7 +63,7 @@ class DecoderConfig(BaseConfig):
layers : List[DecoderLayerConfig]
An ordered list of configuration objects, each defining one layer or
block in the decoder sequence. Each item must be a valid block
config including a `block_type` field and necessary parameters like
config including a `name` field and necessary parameters like
`out_channels`. Input channels for each layer are inferred sequentially.
The list must contain at least one layer.
"""
@ -249,9 +249,9 @@ def build_decoder(
------
ValueError
If `in_channels` or `input_height` are not positive, or if the layer
configuration is invalid (e.g., empty list, unknown `block_type`).
configuration is invalid (e.g., empty list, unknown `name`).
NotImplementedError
If `build_layer_from_config` encounters an unknown `block_type`.
If `build_layer_from_config` encounters an unknown `name`.
"""
config = config or DEFAULT_DECODER_CONFIG

View File

@ -49,7 +49,7 @@ EncoderLayerConfig = Annotated[
StandardConvDownConfig,
LayerGroupConfig,
],
Field(discriminator="block_type"),
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Encoder."""
@ -66,7 +66,7 @@ class EncoderConfig(BaseConfig):
An ordered list of configuration objects, each defining one layer or
block in the encoder sequence. Each item must be a valid block config
(e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
`StandardConvDownConfig`) including a `block_type` field and necessary
`StandardConvDownConfig`) including a `name` field and necessary
parameters like `out_channels`. Input channels for each layer are
inferred sequentially. The list must contain at least one layer.
"""
@ -287,9 +287,9 @@ def build_encoder(
------
ValueError
If `in_channels` or `input_height` are not positive, or if the layer
configuration is invalid (e.g., empty list, unknown `block_type`).
configuration is invalid (e.g., empty list, unknown `name`).
NotImplementedError
If `build_layer_from_config` encounters an unknown `block_type`.
If `build_layer_from_config` encounters an unknown `name`.
"""
if in_channels <= 0 or input_height <= 0:
raise ValueError("in_channels and input_height must be positive.")

View File

@ -1,6 +1,14 @@
"""Computes spectrograms from audio waveforms with configurable parameters."""
from typing import Annotated, Callable, List, Literal, Optional, Union
from typing import (
Annotated,
Callable,
List,
Literal,
Optional,
Sequence,
Union,
)
import numpy as np
import torch
@ -306,7 +314,7 @@ class SpectrogramConfig(BaseConfig):
stft: STFTConfig = Field(default_factory=STFTConfig)
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
size: ResizeConfig = Field(default_factory=ResizeConfig)
transforms: List[SpectrogramTransform] = Field(
transforms: Sequence[SpectrogramTransform] = Field(
default_factory=lambda: [
PcenConfig(),
SpectralMeanSubstractionConfig(),

View File

@ -1,53 +1,26 @@
"""Main entry point for the BatDetect2 Target Definition subsystem.
This package (`batdetect2.targets`) provides the tools and configurations
necessary to define precisely what the BatDetect2 model should learn to detect,
classify, and localize from audio data. It involves several conceptual steps,
managed through configuration files and culminating in an executable pipeline:
1. **Terms (`.terms`)**: Defining vocabulary for annotation tags.
2. **Filtering (`.filtering`)**: Selecting relevant sound event annotations.
3. **Transformation (`.transform`)**: Modifying tags (standardization,
derivation).
4. **ROI Mapping (`.roi`)**: Defining how annotation geometry (ROIs) maps to
target position and size representations, and back.
5. **Class Definition (`.classes`)**: Mapping tags to target class names
(encoding) and mapping predicted names back to tags (decoding).
This module exposes the key components for users to configure and utilize this
target definition pipeline, primarily through the `TargetConfig` data structure
and the `Targets` class (implementing `TargetProtocol`), which encapsulates the
configured processing steps. The main way to create a functional `Targets`
object is via the `build_targets` or `load_targets` functions.
"""
"""BatDetect2 Target Definition system."""
from collections import Counter
from typing import Iterable, List, Optional, Tuple
from loguru import logger
from pydantic import Field
from pydantic import Field, field_validator
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.conditions import (
SoundEventCondition,
build_sound_event_condition,
)
from batdetect2.targets.classes import (
ClassesConfig,
DEFAULT_CLASSES,
DEFAULT_GENERIC_CLASS,
SoundEventDecoder,
SoundEventEncoder,
TargetClass,
build_generic_class_tags,
TargetClassConfig,
build_sound_event_decoder,
build_sound_event_encoder,
get_class_names_from_config,
load_classes_config,
load_decoder_from_config,
load_encoder_from_config,
)
from batdetect2.targets.filtering import (
FilterConfig,
FilterRule,
SoundEventFilter,
build_sound_event_filter,
load_filter_config,
load_filter_from_config,
)
from batdetect2.targets.rois import (
AnchorBBoxMapperConfig,
@ -55,106 +28,53 @@ from batdetect2.targets.rois import (
ROITargetMapper,
build_roi_mapper,
)
from batdetect2.targets.terms import (
TagInfo,
call_type,
get_tag_from_info,
individual,
)
from batdetect2.targets.transform import (
DerivationRegistry,
DeriveTagRule,
MapValueRule,
ReplaceRule,
SoundEventTransformation,
TransformConfig,
build_transformation_from_config,
default_derivation_registry,
get_derivation,
load_transformation_config,
load_transformation_from_config,
register_derivation,
)
from batdetect2.targets.terms import call_type, individual
from batdetect2.typing.targets import Position, Size, TargetProtocol
__all__ = [
"ClassesConfig",
"DEFAULT_TARGET_CONFIG",
"DeriveTagRule",
"FilterConfig",
"FilterRule",
"MapValueRule",
"AnchorBBoxMapperConfig",
"ROITargetMapper",
"ReplaceRule",
"SoundEventDecoder",
"SoundEventEncoder",
"SoundEventFilter",
"SoundEventTransformation",
"TagInfo",
"TargetClass",
"TargetClassConfig",
"TargetConfig",
"Targets",
"TransformConfig",
"build_generic_class_tags",
"build_roi_mapper",
"build_sound_event_decoder",
"build_sound_event_encoder",
"build_sound_event_filter",
"build_transformation_from_config",
"call_type",
"get_class_names_from_config",
"get_derivation",
"get_tag_from_info",
"individual",
"load_classes_config",
"load_decoder_from_config",
"load_encoder_from_config",
"load_filter_config",
"load_filter_from_config",
"load_target_config",
"load_transformation_config",
"load_transformation_from_config",
"register_derivation",
]
class TargetConfig(BaseConfig):
"""Unified configuration for the entire target definition pipeline.
detection_target: TargetClassConfig = Field(default=DEFAULT_GENERIC_CLASS)
This model aggregates the configurations for semantic processing (filtering,
transformation, class definition) and geometric processing (ROI mapping).
It serves as the primary input for building a complete `Targets` object
via `build_targets` or `load_targets`.
Attributes
----------
filtering : FilterConfig, optional
Configuration for filtering sound event annotations based on tags.
If None or omitted, no filtering is applied.
transforms : TransformConfig, optional
Configuration for transforming annotation tags
(mapping, derivation, etc.). If None or omitted, no tag transformations
are applied.
classes : ClassesConfig
Configuration defining the specific target classes, their tag matching
rules for encoding, their representative tags for decoding
(`output_tags`), and the definition of the generic class tags.
This section is mandatory.
roi : ROIConfig, optional
Configuration defining how geometric ROIs (e.g., bounding boxes) are
mapped to target representations (reference point, scaled size).
Controls `position`, `time_scale`, `frequency_scale`. If None or
omitted, default ROI mapping settings are used.
"""
filtering: FilterConfig = Field(default_factory=FilterConfig)
transforms: TransformConfig = Field(default_factory=TransformConfig)
classes: ClassesConfig = Field(
default_factory=lambda: DEFAULT_CLASSES_CONFIG
classification_targets: List[TargetClassConfig] = Field(
default_factory=lambda: DEFAULT_CLASSES
)
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
@field_validator("classification_targets")
def check_unique_class_names(cls, v: List[TargetClassConfig]):
"""Ensure all defined class names are unique."""
names = [c.name for c in v]
if len(names) != len(set(names)):
name_counts = Counter(names)
duplicates = [
name for name, count in name_counts.items() if count > 1
]
raise ValueError(
"Class names must be unique. Found duplicates: "
f"{', '.join(duplicates)}"
)
return v
def load_target_config(
path: data.PathLike,
@ -230,8 +150,7 @@ class Targets(TargetProtocol):
roi_mapper: ROITargetMapper,
class_names: list[str],
generic_class_tags: List[data.Tag],
filter_fn: Optional[SoundEventFilter] = None,
transform_fn: Optional[SoundEventTransformation] = None,
filter_fn: Optional[SoundEventCondition] = None,
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
):
"""Initialize the Targets object.
@ -264,7 +183,6 @@ class Targets(TargetProtocol):
self._filter_fn = filter_fn
self._encode_fn = encode_fn
self._decode_fn = decode_fn
self._transform_fn = transform_fn
self._roi_mapper_overrides = roi_mapper_overrides or {}
for class_name in self._roi_mapper_overrides:
@ -336,27 +254,6 @@ class Targets(TargetProtocol):
"""
return self._decode_fn(class_label)
def transform(
self, sound_event: data.SoundEventAnnotation
) -> data.SoundEventAnnotation:
"""Apply the configured tag transformations to an annotation.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation whose tags should be transformed.
Returns
-------
data.SoundEventAnnotation
A new annotation object with the transformed tags. If no
transformations were configured, the original annotation object is
returned.
"""
if self._transform_fn:
return self._transform_fn(sound_event)
return sound_event
def encode_roi(
self, sound_event: data.SoundEventAnnotation
) -> tuple[Position, Size]:
@ -422,112 +319,14 @@ class Targets(TargetProtocol):
return self._roi_mapper.decode(position, size)
DEFAULT_CLASSES = [
TargetClass(
tags=[TagInfo(value="Myotis mystacinus")],
name="myomys",
),
TargetClass(
tags=[TagInfo(value="Myotis alcathoe")],
name="myoalc",
),
TargetClass(
tags=[TagInfo(value="Eptesicus serotinus")],
name="eptser",
),
TargetClass(
tags=[TagInfo(value="Pipistrellus nathusii")],
name="pipnat",
),
TargetClass(
tags=[TagInfo(value="Barbastellus barbastellus")],
name="barbar",
),
TargetClass(
tags=[TagInfo(value="Myotis nattereri")],
name="myonat",
),
TargetClass(
tags=[TagInfo(value="Myotis daubentonii")],
name="myodau",
),
TargetClass(
tags=[TagInfo(value="Myotis brandtii")],
name="myobra",
),
TargetClass(
tags=[TagInfo(value="Pipistrellus pipistrellus")],
name="pippip",
),
TargetClass(
tags=[TagInfo(value="Myotis bechsteinii")],
name="myobec",
),
TargetClass(
tags=[TagInfo(value="Pipistrellus pygmaeus")],
name="pippyg",
),
TargetClass(
tags=[TagInfo(value="Rhinolophus hipposideros")],
name="rhihip",
),
TargetClass(
tags=[TagInfo(value="Nyctalus leisleri")],
name="nyclei",
roi=AnchorBBoxMapperConfig(anchor="top-left"),
),
TargetClass(
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
name="rhifer",
roi=AnchorBBoxMapperConfig(anchor="top-left"),
),
TargetClass(
tags=[TagInfo(value="Plecotus auritus")],
name="pleaur",
),
TargetClass(
tags=[TagInfo(value="Nyctalus noctula")],
name="nycnoc",
),
TargetClass(
tags=[TagInfo(value="Plecotus austriacus")],
name="pleaus",
),
]
DEFAULT_CLASSES_CONFIG: ClassesConfig = ClassesConfig(
classes=DEFAULT_CLASSES,
generic_class=[TagInfo(value="Bat")],
)
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
filtering=FilterConfig(
rules=[
FilterRule(
match_type="all",
tags=[TagInfo(key="event", value="Echolocation")],
),
FilterRule(
match_type="exclude",
tags=[
TagInfo(key="event", value="Feeding"),
TagInfo(key="event", value="Unknown"),
TagInfo(key="event", value="Not Bat"),
],
),
]
),
classes=DEFAULT_CLASSES_CONFIG,
classification_targets=DEFAULT_CLASSES,
detection_target=DEFAULT_GENERIC_CLASS,
roi=AnchorBBoxMapperConfig(),
)
def build_targets(
config: Optional[TargetConfig] = None,
derivation_registry: DerivationRegistry = default_derivation_registry,
) -> Targets:
def build_targets(config: Optional[TargetConfig] = None) -> Targets:
"""Build a Targets object from a loaded TargetConfig.
This factory function takes the unified `TargetConfig` and constructs all
@ -541,10 +340,6 @@ def build_targets(
----------
config : TargetConfig
The loaded and validated unified target configuration object.
derivation_registry : DerivationRegistry, optional
The DerivationRegistry instance to use for resolving derivation
function names. Defaults to the global
`batdetect2.targets.transform.derivation_registry`.
Returns
-------
@ -565,27 +360,18 @@ def build_targets(
lambda: config.to_yaml_string(),
)
filter_fn = (
build_sound_event_filter(config.filtering)
if config.filtering
else None
)
encode_fn = build_sound_event_encoder(config.classes)
decode_fn = build_sound_event_decoder(config.classes)
transform_fn = (
build_transformation_from_config(
config.transforms,
derivation_registry=derivation_registry,
)
if config.transforms
else None
)
filter_fn = build_sound_event_condition(config.detection_target.match_if)
encode_fn = build_sound_event_encoder(config.classification_targets)
decode_fn = build_sound_event_decoder(config.classification_targets)
roi_mapper = build_roi_mapper(config.roi)
class_names = get_class_names_from_config(config.classes)
generic_class_tags = build_generic_class_tags(config.classes)
class_names = get_class_names_from_config(config.classification_targets)
generic_class_tags = config.detection_target.assign_tags
roi_overrides = {
class_config.name: build_roi_mapper(class_config.roi)
for class_config in config.classes.classes
for class_config in config.classification_targets
if class_config.roi is not None
}
@ -596,7 +382,6 @@ def build_targets(
class_names=class_names,
roi_mapper=roi_mapper,
generic_class_tags=generic_class_tags,
transform_fn=transform_fn,
roi_mapper_overrides=roi_overrides,
)
@ -604,7 +389,6 @@ def build_targets(
def load_targets(
config_path: data.PathLike,
field: Optional[str] = None,
derivation_registry: DerivationRegistry = default_derivation_registry,
) -> Targets:
"""Load a Targets object directly from a configuration file.
@ -619,9 +403,6 @@ def load_targets(
field : str, optional
Dot-separated path to a nested section within the file containing
the target configuration. If None, the entire file content is used.
derivation_registry : DerivationRegistry, optional
The DerivationRegistry instance to use. Defaults to the global
default.
Returns
-------
@ -642,7 +423,7 @@ def load_targets(
config_path,
field=field,
)
return build_targets(config, derivation_registry=derivation_registry)
return build_targets(config)
def iterate_encoded_sound_events(
@ -658,8 +439,6 @@ def iterate_encoded_sound_events(
if geometry is None:
continue
sound_event = targets.transform(sound_event)
class_name = targets.encode_class(sound_event)
position, size = targets.encode_roi(sound_event)

View File

@ -1,251 +1,167 @@
from collections import Counter
from functools import partial
from typing import Callable, Dict, List, Literal, Optional, Set, Tuple
from typing import Dict, List, Optional
from pydantic import Field, field_validator
from pydantic import Field, PrivateAttr, computed_field, model_validator
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.rois import ROIMapperConfig
from batdetect2.targets.terms import (
GENERIC_CLASS_KEY,
TagInfo,
get_tag_from_info,
from batdetect2.configs import BaseConfig
from batdetect2.data.conditions import (
AllOfConfig,
HasAllTagsConfig,
HasAnyTagConfig,
HasTagConfig,
NotConfig,
SoundEventCondition,
SoundEventConditionConfig,
build_sound_event_condition,
)
from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
__all__ = [
"DEFAULT_SPECIES_LIST",
"build_generic_class_tags",
"build_sound_event_decoder",
"build_sound_event_encoder",
"get_class_names_from_config",
"load_classes_config",
"load_decoder_from_config",
"load_encoder_from_config",
]
DEFAULT_SPECIES_LIST = [
"Barbastella barbastellus",
"Eptesicus serotinus",
"Myotis alcathoe",
"Myotis bechsteinii",
"Myotis brandtii",
"Myotis daubentonii",
"Myotis mystacinus",
"Myotis nattereri",
"Nyctalus leisleri",
"Nyctalus noctula",
"Pipistrellus nathusii",
"Pipistrellus pipistrellus",
"Pipistrellus pygmaeus",
"Plecotus auritus",
"Plecotus austriacus",
"Rhinolophus ferrumequinum",
"Rhinolophus hipposideros",
]
"""A default list of common bat species names found in the UK."""
class TargetClass(BaseConfig):
"""Defines criteria for encoding annotations and decoding predictions.
Each instance represents one potential output class for the classification
model. It specifies:
1. A unique `name` for the class.
2. The tag conditions (`tags` and `match_type`) an annotation must meet to
be assigned this class name during training data preparation (encoding).
3. An optional, alternative set of tags (`output_tags`) to be used when
converting a model's prediction of this class name back into annotation
tags (decoding).
Attributes
----------
name : str
The unique name assigned to this target class (e.g., 'pippip',
'myodau', 'noise'). This name is used as the label during model
training and is the expected output from the model's prediction.
Should be unique across all TargetClass definitions in a configuration.
tags : List[TagInfo]
A list of one or more tags (defined using `TagInfo`) used to identify
if an existing annotation belongs to this class during encoding (data
preparation for training). The `match_type` attribute determines how
these tags are evaluated.
match_type : Literal["all", "any"], default="all"
Determines how the `tags` list is evaluated during encoding:
- "all": The annotation must have *all* the tags listed to match.
- "any": The annotation must have *at least one* of the tags listed
to match.
output_tags: Optional[List[TagInfo]], default=None
An optional list of tags (defined using `TagInfo`) to be assigned to a
new annotation when the model predicts this class `name`. If `None`
(default), the tags listed in the `tags` field will be used for
decoding. If provided, this list overrides the `tags` field for the
purpose of decoding predictions back into meaningful annotation tags.
This allows, for example, training on broader categories but decoding
to more specific representative tags.
"""
class TargetClassConfig(BaseConfig):
"""Defines a target class of sound events."""
name: str
tags: List[TagInfo] = Field(min_length=1)
match_type: Literal["all", "any"] = Field(default="all")
output_tags: Optional[List[TagInfo]] = None
condition_input: Optional[SoundEventConditionConfig] = Field(
alias="match_if",
default=None,
)
tags: Optional[List[data.Tag]] = None
assign_tags: List[data.Tag] = Field(default_factory=list)
roi: Optional[ROIMapperConfig] = None
_match_if: SoundEventConditionConfig = PrivateAttr()
def _get_default_classes() -> List[TargetClass]:
"""Generate a list of default target classes.
@model_validator(mode="after")
def _process_shorthands(self) -> "TargetClassConfig":
if self.tags and self.condition_input:
raise ValueError("Use either 'tags' or 'match_if', not both.")
Returns
-------
List[TargetClass]
A list of TargetClass objects, one for each species in
DEFAULT_SPECIES_LIST. The class names are simplified versions of the
species names.
"""
return [
TargetClass(
name=_get_default_class_name(value),
tags=[TagInfo(key=GENERIC_CLASS_KEY, value=value)],
)
for value in DEFAULT_SPECIES_LIST
]
def _get_default_class_name(species: str) -> str:
"""Generate a default class name from a species name.
Parameters
----------
species : str
The species name (e.g., "Myotis daubentonii").
Returns
-------
str
A simplified class name (e.g., "myodau").
The genus and species names are converted to lowercase,
the first three letters of each are taken, and concatenated.
"""
genus, species = species.strip().split(" ")
return f"{genus.lower()[:3]}{species.lower()[:3]}"
def _get_default_generic_class() -> List[TagInfo]:
"""Generate the default list of TagInfo objects for the generic class.
Provides a default set of tags used to represent the generic "Bat" category
when decoding predictions that didn't match a specific class.
Returns
-------
List[TagInfo]
A list containing default TagInfo objects, typically representing
`call_type: Echolocation` and `order: Chiroptera`.
"""
return [
TagInfo(key="call_type", value="Echolocation"),
TagInfo(key="order", value="Chiroptera"),
]
class ClassesConfig(BaseConfig):
"""Configuration defining target classes and the generic fallback category.
Holds the ordered list of specific target class definitions (`TargetClass`)
and defines the tags representing the generic category for sounds that pass
filtering but do not match any specific class.
The order of `TargetClass` objects in the `classes` list defines the
priority for classification during encoding. The system checks annotations
against these definitions sequentially and assigns the name of the *first*
matching class.
Attributes
----------
classes : List[TargetClass]
An ordered list of specific target class definitions. The order
determines matching priority (first match wins). Defaults to a
standard set of classes via `get_default_classes`.
generic_class : List[TagInfo]
A list of tags defining the "generic" or "unclassified but relevant"
category (e.g., representing a generic 'Bat' call that wasn't
assigned to a specific species). These tags are typically assigned
during decoding when a sound event was detected and passed filtering
but did not match any specific class rule defined in the `classes` list.
Defaults to a standard set of tags via `get_default_generic_class`.
Raises
------
ValueError
If validation fails (e.g., non-unique class names in the `classes`
list).
Notes
-----
- It is crucial that the `name` attribute of each `TargetClass` in the
`classes` list is unique. This configuration includes a validator to
enforce this uniqueness.
- The `generic_class` tags provide a baseline identity for relevant sounds
that don't fit into more specific defined categories.
"""
classes: List[TargetClass] = Field(default_factory=_get_default_classes)
generic_class: List[TagInfo] = Field(
default_factory=_get_default_generic_class
)
@field_validator("classes")
def check_unique_class_names(cls, v: List[TargetClass]):
"""Ensure all defined class names are unique."""
names = [c.name for c in v]
if len(names) != len(set(names)):
name_counts = Counter(names)
duplicates = [
name for name, count in name_counts.items() if count > 1
]
if self.condition_input:
final_condition = self.condition_input
elif self.tags:
final_condition = HasAllTagsConfig(tags=self.tags)
else:
raise ValueError(
"Class names must be unique. Found duplicates: "
f"{', '.join(duplicates)}"
f"Class '{self.name}' must have a 'tags' or 'match_if' rule."
)
return v
self._match_if = final_condition
return self
@computed_field
@property
def match_if(self) -> SoundEventConditionConfig:
return self._match_if
def is_target_class(
sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag],
match_all: bool = True,
) -> bool:
"""Check if a sound event annotation matches a set of required tags.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
required_tags : Set[data.Tag]
A set of `soundevent.data.Tag` objects that define the class criteria.
match_all : bool, default=True
If True, checks if *all* `required_tags` are present in the
annotation's tags (subset check). If False, checks if *at least one*
of the `required_tags` is present (intersection check).
Returns
-------
bool
True if the annotation meets the tag criteria, False otherwise.
"""
annotation_tags = set(sound_event_annotation.tags)
if match_all:
return tags <= annotation_tags
return bool(tags & annotation_tags)
DEFAULT_GENERIC_CLASS = TargetClassConfig(
name="bat",
match_if=AllOfConfig(
conditions=[
HasTagConfig(tag=data.Tag(key="event", value="Echolocation")),
NotConfig(
condition=HasAnyTagConfig(
tags=[
data.Tag(key="event", value="Feeding"),
data.Tag(key="event", value="Unknown"),
data.Tag(key="event", value="Not Bat"),
]
)
),
]
),
assign_tags=[
data.Tag(key="call_type", value="Echolocation"),
data.Tag(key="order", value="Chiroptera"),
],
)
def get_class_names_from_config(config: ClassesConfig) -> List[str]:
DEFAULT_CLASSES = [
TargetClassConfig(
name="myomys",
tags=[data.Tag(key="class", value="Myotis mystacinus")],
),
TargetClassConfig(
name="myoalc",
tags=[data.Tag(key="class", value="Myotis alcathoe")],
),
TargetClassConfig(
name="eptser",
tags=[data.Tag(key="class", value="Eptesicus serotinus")],
),
TargetClassConfig(
name="pipnat",
tags=[data.Tag(key="class", value="Pipistrellus nathusii")],
),
TargetClassConfig(
name="barbar",
tags=[data.Tag(key="class", value="Barbastellus barbastellus")],
),
TargetClassConfig(
name="myonat",
tags=[data.Tag(key="class", value="Myotis nattereri")],
),
TargetClassConfig(
name="myodau",
tags=[data.Tag(key="class", value="Myotis daubentonii")],
),
TargetClassConfig(
name="myobra",
tags=[data.Tag(key="class", value="Myotis brandtii")],
),
TargetClassConfig(
name="pippip",
tags=[data.Tag(key="class", value="Pipistrellus pipistrellus")],
),
TargetClassConfig(
name="myobec",
tags=[data.Tag(key="class", value="Myotis bechsteinii")],
),
TargetClassConfig(
name="pippyg",
tags=[data.Tag(key="class", value="Pipistrellus pygmaeus")],
),
TargetClassConfig(
name="rhihip",
tags=[data.Tag(key="class", value="Rhinolophus hipposideros")],
roi=AnchorBBoxMapperConfig(anchor="top-left"),
),
TargetClassConfig(
name="nyclei",
tags=[data.Tag(key="class", value="Nyctalus leisleri")],
),
TargetClassConfig(
name="rhifer",
tags=[data.Tag(key="class", value="Rhinolophus ferrumequinum")],
roi=AnchorBBoxMapperConfig(anchor="top-left"),
),
TargetClassConfig(
name="pleaur",
tags=[data.Tag(key="class", value="Plecotus auritus")],
),
TargetClassConfig(
name="nycnoc",
tags=[data.Tag(key="class", value="Nyctalus noctula")],
),
TargetClassConfig(
name="pleaus",
tags=[data.Tag(key="class", value="Plecotus austriacus")],
),
]
def get_class_names_from_config(configs: List[TargetClassConfig]) -> List[str]:
"""Extract the list of class names from a ClassesConfig object.
Parameters
@ -258,324 +174,60 @@ def get_class_names_from_config(config: ClassesConfig) -> List[str]:
List[str]
An ordered list of unique class names defined in the configuration.
"""
return [class_info.name for class_info in config.classes]
return [class_info.name for class_info in configs]
def _encode_with_multiple_classifiers(
sound_event_annotation: data.SoundEventAnnotation,
classifiers: List[Tuple[str, Callable[[data.SoundEventAnnotation], bool]]],
) -> Optional[str]:
"""Encode an annotation by checking against a list of classifiers.
def build_sound_event_encoder(
configs: List[TargetClassConfig],
) -> SoundEventEncoder:
"""Build a sound event encoder function from the classes configuration."""
conditions = {
class_config.name: build_sound_event_condition(class_config.match_if)
for class_config in configs
}
Internal helper function used by the `SoundEventEncoder`. It iterates
through the provided list of (class_name, classifier_function) pairs.
Returns the name associated with the first classifier function that
returns True for the given annotation.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to encode.
classifiers : List[Tuple[str, Callable[[data.SoundEventAnnotation], bool]]]
An ordered list where each tuple contains a class name and a function
that returns True if the annotation matches that class. The order
determines priority.
Returns
-------
str or None
The name of the first matching class, or None if no classifier matches.
"""
for class_name, classifier in classifiers:
if classifier(sound_event_annotation):
return class_name
return None
return SoundEventClassifier(conditions)
def build_sound_event_encoder(config: ClassesConfig) -> SoundEventEncoder:
"""Build a sound event encoder function from the classes configuration.
class SoundEventClassifier:
def __init__(self, mapping: Dict[str, SoundEventCondition]):
self.mapping = mapping
The returned encoder function iterates through the class definitions in the
order specified in the config. It assigns an annotation the name of the
first class definition it matches.
Parameters
----------
config : ClassesConfig
The loaded and validated classes configuration object.
term_registry : TermRegistry, optional
The TermRegistry instance used to look up term keys specified in the
`TagInfo` objects within the configuration. Defaults to the global
`batdetect2.targets.terms.registry`.
Returns
-------
SoundEventEncoder
A callable function that takes a `SoundEventAnnotation` and returns
an optional string representing the matched class name, or None if no
class matches.
Raises
------
KeyError
If a term key specified in the configuration is not found in the
provided `term_registry`.
"""
binary_classifiers = [
(
class_info.name,
partial(
is_target_class,
tags={
get_tag_from_info(tag_info) for tag_info in class_info.tags
},
match_all=class_info.match_type == "all",
),
)
for class_info in config.classes
]
return partial(
_encode_with_multiple_classifiers,
classifiers=binary_classifiers,
)
def _decode_class(
name: str,
mapping: Dict[str, List[data.Tag]],
raise_on_error: bool = True,
) -> List[data.Tag]:
"""Decode a class name into a list of representative tags using a mapping.
Internal helper function used by the `SoundEventDecoder`. Looks up the
provided class `name` in the `mapping` dictionary.
Parameters
----------
name : str
The class name to decode.
mapping : Dict[str, List[data.Tag]]
A dictionary mapping class names to lists of `soundevent.data.Tag`
objects.
raise_on_error : bool, default=True
If True, raises a ValueError if the `name` is not found in the
`mapping`. If False, returns an empty list if the `name` is not found.
Returns
-------
List[data.Tag]
The list of tags associated with the class name, or an empty list if
not found and `raise_on_error` is False.
Raises
------
ValueError
If `name` is not found in `mapping` and `raise_on_error` is True.
"""
if name not in mapping and raise_on_error:
raise ValueError(f"Class {name} not found in mapping.")
if name not in mapping:
return []
return mapping[name]
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> Optional[str]:
for name, condition in self.mapping.items():
if condition(sound_event_annotation):
return name
def build_sound_event_decoder(
config: ClassesConfig,
configs: List[TargetClassConfig],
raise_on_unmapped: bool = False,
) -> SoundEventDecoder:
"""Build a sound event decoder function from the classes configuration.
Creates a callable `SoundEventDecoder` that maps a class name string
back to a list of representative `soundevent.data.Tag` objects based on
the `ClassesConfig`. It uses the `output_tags` field if provided in a
`TargetClass`, otherwise falls back to the `tags` field.
Parameters
----------
config : ClassesConfig
The loaded and validated classes configuration object.
term_registry : TermRegistry, optional
The TermRegistry instance used to look up term keys. Defaults to the
global `batdetect2.targets.terms.registry`.
raise_on_unmapped : bool, default=False
If True, the returned decoder function will raise a ValueError if asked
to decode a class name that is not in the configuration. If False, it
will return an empty list for unmapped names.
Returns
-------
SoundEventDecoder
A callable function that takes a class name string and returns a list
of `soundevent.data.Tag` objects.
Raises
------
KeyError
If a term key specified in the configuration (`output_tags`, `tags`, or
`generic_class`) is not found in the provided `term_registry`.
"""
mapping = {}
for class_info in config.classes:
tags_to_use = (
class_info.output_tags
if class_info.output_tags is not None
else class_info.tags
)
mapping[class_info.name] = [
get_tag_from_info(tag_info) for tag_info in tags_to_use
]
return partial(
_decode_class,
mapping=mapping,
raise_on_error=raise_on_unmapped,
)
"""Build a sound event decoder function from the classes configuration."""
mapping = {
class_config.name: class_config.assign_tags for class_config in configs
}
return TagDecoder(mapping, raise_on_unknown=raise_on_unmapped)
def build_generic_class_tags(config: ClassesConfig) -> List[data.Tag]:
"""Extract and build the list of tags for the generic class from config.
class TagDecoder:
def __init__(
self,
mapping: Dict[str, List[data.Tag]],
raise_on_unknown: bool = True,
):
self.mapping = mapping
self.raise_on_unknown = raise_on_unknown
Converts the list of `TagInfo` objects defined in `config.generic_class`
into a list of `soundevent.data.Tag` objects using the term registry.
def __call__(self, class_name: str) -> List[data.Tag]:
tags = self.mapping.get(class_name)
Parameters
----------
config : ClassesConfig
The loaded classes configuration object.
term_registry : TermRegistry, optional
The TermRegistry instance for term lookups. Defaults to the global
`batdetect2.targets.terms.registry`.
if tags is None:
if self.raise_on_unknown:
raise ValueError("Invalid class name")
Returns
-------
List[data.Tag]
The list of fully constructed tags representing the generic class.
tags = []
Raises
------
KeyError
If a term key specified in `config.generic_class` is not found in the
provided `term_registry`.
"""
return [get_tag_from_info(tag_info) for tag_info in config.generic_class]
def load_classes_config(
path: data.PathLike,
field: Optional[str] = None,
) -> ClassesConfig:
"""Load the target classes configuration from a file.
Parameters
----------
path : data.PathLike
Path to the configuration file (YAML).
field : str, optional
If the classes configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
Returns
-------
ClassesConfig
The loaded and validated classes configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the ClassesConfig schema
or if class names are not unique.
"""
return load_config(path, schema=ClassesConfig, field=field)
def load_encoder_from_config(
path: data.PathLike, field: Optional[str] = None
) -> SoundEventEncoder:
"""Load a class encoder function directly from a configuration file.
This is a convenience function that combines loading the `ClassesConfig`
from a file and building the final `SoundEventEncoder` function.
Parameters
----------
path : data.PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
If the classes configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
term_registry : TermRegistry, optional
The TermRegistry instance used for term lookups. Defaults to the
global `batdetect2.targets.terms.registry`.
Returns
-------
SoundEventEncoder
The final encoder function ready to classify annotations.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the ClassesConfig schema
or if class names are not unique.
KeyError
If a term key specified in the configuration is not found in the
provided `term_registry` during the build process.
"""
config = load_classes_config(path, field=field)
return build_sound_event_encoder(config)
def load_decoder_from_config(
path: data.PathLike,
field: Optional[str] = None,
raise_on_unmapped: bool = False,
) -> SoundEventDecoder:
"""Load a class decoder function directly from a configuration file.
This is a convenience function that combines loading the `ClassesConfig`
from a file and building the final `SoundEventDecoder` function.
Parameters
----------
path : data.PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
If the classes configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
term_registry : TermRegistry, optional
The TermRegistry instance used for term lookups. Defaults to the
global `batdetect2.targets.terms.registry`.
raise_on_unmapped : bool, default=False
If True, the returned decoder function will raise a ValueError if asked
to decode a class name that is not in the configuration. If False, it
will return an empty list for unmapped names.
Returns
-------
SoundEventDecoder
The final decoder function ready to convert class names back into tags.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the ClassesConfig schema
or if class names are not unique.
KeyError
If a term key specified in the configuration is not found in the
provided `term_registry` during the build process.
"""
config = load_classes_config(path, field=field)
return build_sound_event_decoder(
config,
raise_on_unmapped=raise_on_unmapped,
)
return tags

View File

@ -1,293 +0,0 @@
import logging
from functools import partial
from typing import List, Literal, Optional, Set
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.terms import (
TagInfo,
get_tag_from_info,
)
from batdetect2.typing.targets import SoundEventFilter
__all__ = [
"FilterConfig",
"FilterRule",
"build_sound_event_filter",
"build_filter_from_rule",
"load_filter_config",
"load_filter_from_config",
]
logger = logging.getLogger(__name__)
class FilterRule(BaseConfig):
"""Defines a single rule for filtering sound event annotations.
Based on the `match_type`, this rule checks if the tags associated with a
sound event annotation meet certain criteria relative to the `tags` list
defined in this rule.
Attributes
----------
match_type : Literal["any", "all", "exclude", "equal"]
Determines how the `tags` list is used:
- "any": Pass if the annotation has at least one tag from the list.
- "all": Pass if the annotation has all tags from the list (it can
have others too).
- "exclude": Pass if the annotation has none of the tags from the list.
- "equal": Pass if the annotation's tags are exactly the same set as
provided in the list.
tags : List[TagInfo]
A list of tags (defined using TagInfo for configuration) that this
rule operates on.
"""
match_type: Literal["any", "all", "exclude", "equal"]
tags: List[TagInfo]
def has_any_tag(
sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag],
) -> bool:
"""Check if the annotation has at least one of the specified tags.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
tags : Set[data.Tag]
The set of tags to look for.
Returns
-------
bool
True if the annotation has one or more tags from the specified set,
False otherwise.
"""
sound_event_tags = set(sound_event_annotation.tags)
return bool(tags & sound_event_tags)
def contains_tags(
sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag],
) -> bool:
"""Check if the annotation contains all of the specified tags.
The annotation may have additional tags beyond those specified.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
tags : Set[data.Tag]
The set of tags that must all be present in the annotation.
Returns
-------
bool
True if the annotation's tags are a superset of the specified tags,
False otherwise.
"""
sound_event_tags = set(sound_event_annotation.tags)
return tags <= sound_event_tags
def does_not_have_tags(
sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag],
):
"""Check if the annotation has none of the specified tags.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
tags : Set[data.Tag]
The set of tags that must *not* be present in the annotation.
Returns
-------
bool
True if the annotation has zero tags in common with the specified set,
False otherwise.
"""
return not has_any_tag(sound_event_annotation, tags)
def equal_tags(
sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag],
) -> bool:
"""Check if the annotation's tags are exactly equal to the specified set.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
tags : Set[data.Tag]
The exact set of tags the annotation must have.
Returns
-------
bool
True if the annotation's tags set is identical to the specified set,
False otherwise.
"""
sound_event_tags = set(sound_event_annotation.tags)
return tags == sound_event_tags
def build_filter_from_rule(rule: FilterRule) -> SoundEventFilter:
"""Creates a callable filter function from a single FilterRule.
Parameters
----------
rule : FilterRule
The filter rule configuration object.
Returns
-------
SoundEventFilter
A function that takes a SoundEventAnnotation and returns True if it
passes the rule, False otherwise.
Raises
------
ValueError
If the rule contains an invalid `match_type`.
"""
tag_set = {get_tag_from_info(tag_info) for tag_info in rule.tags}
if rule.match_type == "any":
return partial(has_any_tag, tags=tag_set)
if rule.match_type == "all":
return partial(contains_tags, tags=tag_set)
if rule.match_type == "exclude":
return partial(does_not_have_tags, tags=tag_set)
if rule.match_type == "equal":
return partial(equal_tags, tags=tag_set)
raise ValueError(
f"Invalid match type {rule.match_type}. Valid types "
"are: 'any', 'all', 'exclude' and 'equal'"
)
def _passes_all_filters(
sound_event_annotation: data.SoundEventAnnotation,
filters: List[SoundEventFilter],
) -> bool:
"""Check if the annotation passes all provided filters.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
filters : List[SoundEventFilter]
A list of filter functions to apply.
Returns
-------
bool
True if the annotation passes all filters, False otherwise.
"""
for filter_fn in filters:
if not filter_fn(sound_event_annotation):
logging.debug(
f"Sound event annotation {sound_event_annotation.uuid} "
f"excluded due to rule {filter_fn}",
)
return False
return True
class FilterConfig(BaseConfig):
"""Configuration model for defining a list of filter rules.
Attributes
----------
rules : List[FilterRule]
A list of FilterRule objects. An annotation must pass all rules in
this list to be considered valid by the filter built from this config.
"""
rules: List[FilterRule] = Field(default_factory=list)
def build_sound_event_filter(
config: FilterConfig,
) -> SoundEventFilter:
"""Builds a merged filter function from a FilterConfig object.
Creates individual filter functions for each rule in the configuration
and merges them using AND logic.
Parameters
----------
config : FilterConfig
The configuration object containing the list of filter rules.
Returns
-------
SoundEventFilter
A single callable filter function that applies all defined rules.
"""
filters = [build_filter_from_rule(rule) for rule in config.rules]
return partial(_passes_all_filters, filters=filters)
def load_filter_config(
path: data.PathLike, field: Optional[str] = None
) -> FilterConfig:
"""Loads the filter configuration from a file.
Parameters
----------
path : data.PathLike
Path to the configuration file (YAML).
field : Optional[str], optional
If the filter configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
Returns
-------
FilterConfig
The loaded and validated filter configuration object.
"""
return load_config(path, schema=FilterConfig, field=field)
def load_filter_from_config(
path: data.PathLike, field: Optional[str] = None,
) -> SoundEventFilter:
"""Loads filter configuration from a file and builds the filter function.
This is a convenience function that combines loading the configuration
and building the final callable filter function.
Parameters
----------
path : data.PathLike
Path to the configuration file (YAML).
field : Optional[str], optional
If the filter configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
Returns
-------
SoundEventFilter
The final merged filter function ready to be used.
"""
config = load_filter_config(path=path, field=field)
return build_sound_event_filter(config)

View File

@ -1,23 +1,11 @@
"""Manages the vocabulary for defining training targets.
"""Manages the vocabulary for defining training targets."""
This module provides the necessary tools to declare, register, and manage the
set of `soundevent.data.Term` objects used throughout the `batdetect2.targets`
sub-package. It establishes a consistent vocabulary for filtering,
transforming, and classifying sound events based on their annotations (Tags).
Terms can be pre-defined, loaded from the `soundevent.terms` library or defined
programmatically.
"""
from pydantic import BaseModel
from soundevent import data, terms
__all__ = [
"call_type",
"individual",
"data_source",
"get_tag_from_info",
"TagInfo",
]
# The default key used to reference the 'generic_class' term.
@ -85,54 +73,3 @@ terms.register_term_set(
),
override_existing=True,
)
class TagInfo(BaseModel):
"""Represents information needed to define a specific Tag.
This model is typically used in configuration files (e.g., YAML) to
specify tags used for filtering, target class definition, or associating
tags with output classes. It links a tag value to a term definition
via the term's registry key.
Attributes
----------
value : str
The value of the tag (e.g., "Myotis myotis", "Echolocation").
key : str, default="class"
The key (alias) of the term associated with this tag. Defaults to
"class", implying it represents a classification target label by
default.
"""
value: str
key: str = GENERIC_CLASS_KEY
def get_tag_from_info(tag_info: TagInfo) -> data.Tag:
"""Creates a soundevent.data.Tag object from TagInfo data.
Looks up the term using the key in the provided `tag_info` and constructs a
Tag object.
Parameters
----------
tag_info : TagInfo
The TagInfo object containing the value and term key.
Returns
-------
soundevent.data.Tag
A soundevent.data.Tag object corresponding to the input info.
Raises
------
KeyError
If the term key specified in `tag_info.key` is not found.
"""
term = terms.get_term(tag_info.key)
if not term:
raise KeyError(f"Key {tag_info.key} not found")
return data.Tag(term=term, value=tag_info.value)

View File

@ -1,689 +0,0 @@
import importlib
from functools import partial
from typing import (
Annotated,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Union,
)
from pydantic import Field
from soundevent import data, terms
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.terms import TagInfo, get_tag_from_info
__all__ = [
"DerivationRegistry",
"DeriveTagRule",
"MapValueRule",
"ReplaceRule",
"SoundEventTransformation",
"TransformConfig",
"build_transform_from_rule",
"build_transformation_from_config",
"default_derivation_registry",
"get_derivation",
"load_transformation_config",
"load_transformation_from_config",
"register_derivation",
]
SoundEventTransformation = Callable[
[data.SoundEventAnnotation], data.SoundEventAnnotation
]
"""Type alias for a sound event transformation function.
A function that accepts a sound event annotation object and returns a
(potentially) modified sound event annotation object. Transformations
should generally return a copy of the annotation rather than modifying
it in place.
"""
Derivation = Callable[[str], str]
"""Type alias for a derivation function.
A function that accepts a single string (typically a tag value) and returns
a new string (the derived value).
"""
class MapValueRule(BaseConfig):
"""Configuration for mapping specific values of a source term.
This rule replaces tags matching a specific term and one of the
original values with a new tag (potentially having a different term)
containing the corresponding replacement value. Useful for standardizing
or grouping tag values.
Attributes
----------
rule_type : Literal["map_value"]
Discriminator field identifying this rule type.
source_term_key : str
The key (registered in `TermRegistry`) of the term whose tags' values
should be checked against the `value_mapping`.
value_mapping : Dict[str, str]
A dictionary mapping original string values to replacement string
values. Only tags whose value is a key in this dictionary will be
affected.
target_term_key : str, optional
The key (registered in `TermRegistry`) for the term of the *output*
tag. If None (default), the output tag uses the same term as the
source (`source_term_key`). If provided, the term of the affected
tag is changed to this target term upon replacement.
"""
rule_type: Literal["map_value"] = "map_value"
source_term_key: str
value_mapping: Dict[str, str]
target_term_key: Optional[str] = None
def map_value_transform(
sound_event_annotation: data.SoundEventAnnotation,
source_term: data.Term,
target_term: data.Term,
mapping: Dict[str, str],
) -> data.SoundEventAnnotation:
"""Apply a value mapping transformation to an annotation's tags.
Iterates through the annotation's tags. If a tag matches the `source_term`
and its value is found in the `mapping`, it is replaced by a new tag with
the `target_term` and the mapped value. Other tags are kept unchanged.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to transform.
source_term : data.Term
The term of tags whose values should be mapped.
target_term : data.Term
The term to use for the newly created tags after mapping.
mapping : Dict[str, str]
The dictionary mapping original values to new values.
Returns
-------
data.SoundEventAnnotation
A new annotation object with the transformed tags.
"""
tags = []
for tag in sound_event_annotation.tags:
if tag.term != source_term or tag.value not in mapping:
tags.append(tag)
continue
new_value = mapping[tag.value]
tags.append(data.Tag(term=target_term, value=new_value))
return sound_event_annotation.model_copy(update=dict(tags=tags))
class DeriveTagRule(BaseConfig):
"""Configuration for deriving a new tag from an existing tag's value.
This rule applies a specified function (`derivation_function`) to the
value of tags matching the `source_term_key`. It then adds a new tag
with the `target_term_key` and the derived value.
Attributes
----------
rule_type : Literal["derive_tag"]
Discriminator field identifying this rule type.
source_term_key : str
The key (registered in `TermRegistry`) of the term whose tag values
will be used as input to the derivation function.
derivation_function : str
The name/key identifying the derivation function to use. This can be
a key registered in the `DerivationRegistry` or, if
`import_derivation` is True, a full Python path like
`'my_module.my_submodule.my_function'`.
target_term_key : str, optional
The key (registered in `TermRegistry`) for the term of the new tag
that will be created with the derived value. If None (default), the
derived tag uses the same term as the source (`source_term_key`),
effectively performing an in-place value transformation.
import_derivation : bool, default=False
If True, treat `derivation_function` as a Python import path and
attempt to dynamically import it if not found in the registry.
Requires the function to be accessible in the Python environment.
keep_source : bool, default=True
If True, the original source tag (whose value was used for derivation)
is kept in the annotation's tag list alongside the newly derived tag.
If False, the original source tag is removed.
"""
rule_type: Literal["derive_tag"] = "derive_tag"
source_term_key: str
derivation_function: str
target_term_key: Optional[str] = None
import_derivation: bool = False
keep_source: bool = True
def derivation_tag_transform(
sound_event_annotation: data.SoundEventAnnotation,
source_term: data.Term,
target_term: data.Term,
derivation: Derivation,
keep_source: bool = True,
) -> data.SoundEventAnnotation:
"""Apply a derivation transformation to an annotation's tags.
Iterates through the annotation's tags. For each tag matching the
`source_term`, its value is passed to the `derivation` function.
A new tag is created with the `target_term` and the derived value,
and added to the output tag list. The original source tag is kept
or discarded based on `keep_source`. Other tags are kept unchanged.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to transform.
source_term : data.Term
The term of tags whose values serve as input for the derivation.
target_term : data.Term
The term to use for the newly created derived tags.
derivation : Derivation
The function to apply to the source tag's value.
keep_source : bool, default=True
Whether to keep the original source tag in the output.
Returns
-------
data.SoundEventAnnotation
A new annotation object with the transformed tags (including derived
ones).
"""
tags = []
for tag in sound_event_annotation.tags:
if tag.term != source_term:
tags.append(tag)
continue
if keep_source:
tags.append(tag)
new_value = derivation(tag.value)
tags.append(data.Tag(term=target_term, value=new_value))
return sound_event_annotation.model_copy(update=dict(tags=tags))
class ReplaceRule(BaseConfig):
"""Configuration for exactly replacing one specific tag with another.
This rule looks for an exact match of the `original` tag (both term and
value) and replaces it with the specified `replacement` tag.
Attributes
----------
rule_type : Literal["replace"]
Discriminator field identifying this rule type.
original : TagInfo
The exact tag to search for, defined using its value and term key.
replacement : TagInfo
The tag to substitute in place of the original tag, defined using
its value and term key.
"""
rule_type: Literal["replace"] = "replace"
original: TagInfo
replacement: TagInfo
def replace_tag_transform(
sound_event_annotation: data.SoundEventAnnotation,
source: data.Tag,
target: data.Tag,
) -> data.SoundEventAnnotation:
"""Apply an exact tag replacement transformation.
Iterates through the annotation's tags. If a tag exactly matches the
`source` tag, it is replaced by the `target` tag. Other tags are kept
unchanged.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to transform.
source : data.Tag
The exact tag to find and replace.
target : data.Tag
The tag to replace the source tag with.
Returns
-------
data.SoundEventAnnotation
A new annotation object with the replaced tag (if found).
"""
tags = []
for tag in sound_event_annotation.tags:
if tag == source:
tags.append(target)
else:
tags.append(tag)
return sound_event_annotation.model_copy(update=dict(tags=tags))
class TransformConfig(BaseConfig):
"""Configuration model for defining a sequence of transformation rules.
Attributes
----------
rules : List[Union[ReplaceRule, MapValueRule, DeriveTagRule]]
A list of transformation rules to apply. The rules are applied
sequentially in the order they appear in the list. The output of
one rule becomes the input for the next. The `rule_type` field
discriminates between the different rule models.
"""
rules: List[
Annotated[
Union[ReplaceRule, MapValueRule, DeriveTagRule],
Field(discriminator="rule_type"),
]
] = Field(
default_factory=list,
)
class DerivationRegistry(Mapping[str, Derivation]):
"""A registry for managing named derivation functions.
Derivation functions are callables that take a string value and return
a transformed string value, used by `DeriveTagRule`. This registry
allows functions to be registered with a key and retrieved later.
"""
def __init__(self):
"""Initialize an empty DerivationRegistry."""
self._derivations: Dict[str, Derivation] = {}
def __getitem__(self, key: str) -> Derivation:
"""Retrieve a derivation function by key."""
return self._derivations[key]
def __len__(self) -> int:
"""Return the number of registered derivation functions."""
return len(self._derivations)
def __iter__(self):
"""Return an iterator over the keys of registered functions."""
return iter(self._derivations)
def register(self, key: str, derivation: Derivation) -> None:
"""Register a derivation function with a unique key.
Parameters
----------
key : str
The unique key to associate with the derivation function.
derivation : Derivation
The callable derivation function (takes str, returns str).
Raises
------
KeyError
If a derivation function with the same key is already registered.
"""
if key in self._derivations:
raise KeyError(
f"A derivation with the provided key {key} already exists"
)
self._derivations[key] = derivation
def get_derivation(self, key: str) -> Derivation:
"""Retrieve a derivation function by its registered key.
Parameters
----------
key : str
The key of the derivation function to retrieve.
Returns
-------
Derivation
The requested derivation function.
Raises
------
KeyError
If no derivation function with the specified key is registered.
"""
try:
return self._derivations[key]
except KeyError as err:
raise KeyError(
f"No derivation with key {key} is registered."
) from err
def get_keys(self) -> List[str]:
"""Get a list of all registered derivation function keys.
Returns
-------
List[str]
The keys of all registered functions.
"""
return list(self._derivations.keys())
def get_derivations(self) -> List[Derivation]:
"""Get a list of all registered derivation functions.
Returns
-------
List[Derivation]
The registered derivation function objects.
"""
return list(self._derivations.values())
default_derivation_registry = DerivationRegistry()
"""Global instance of the DerivationRegistry.
Register custom derivation functions here to make them available by key
in `DeriveTagRule` configuration.
"""
def get_derivation(
key: str,
import_derivation: bool = False,
registry: Optional[DerivationRegistry] = None,
):
"""Retrieve a derivation function by key, optionally importing it.
First attempts to find the function in the provided `registry`.
If not found and `import_derivation` is True, attempts to dynamically
import the function using the `key` as a full Python path
(e.g., 'my_module.submodule.my_func').
Parameters
----------
key : str
The key or Python path of the derivation function.
import_derivation : bool, default=False
If True, attempt dynamic import if key is not in the registry.
registry : DerivationRegistry, optional
The registry instance to check first. Defaults to the global
`derivation_registry`.
Returns
-------
Derivation
The requested derivation function.
Raises
------
KeyError
If the key is not found in the registry and either
`import_derivation` is False or the dynamic import fails.
ImportError
If dynamic import fails specifically due to module not found.
AttributeError
If dynamic import fails because the function name isn't in the module.
"""
registry = registry or default_derivation_registry
if not import_derivation or key in registry:
return registry.get_derivation(key)
try:
module_path, func_name = key.rsplit(".", 1)
module = importlib.import_module(module_path)
func = getattr(module, func_name)
return func
except ImportError as err:
raise KeyError(
f"Unable to load derivation '{key}'. Check the path and ensure "
"it points to a valid callable function in an importable module."
) from err
TranformationRule = Annotated[
Union[ReplaceRule, MapValueRule, DeriveTagRule],
Field(discriminator="rule_type"),
]
def build_transform_from_rule(
rule: TranformationRule,
derivation_registry: Optional[DerivationRegistry] = None,
) -> SoundEventTransformation:
"""Build a specific SoundEventTransformation function from a rule config.
Selects the appropriate transformation logic based on the rule's
`rule_type`, fetches necessary terms and derivation functions, and
returns a partially applied function ready to transform an annotation.
Parameters
----------
rule : Union[ReplaceRule, MapValueRule, DeriveTagRule]
The configuration object for a single transformation rule.
registry : DerivationRegistry, optional
The derivation registry to use for `DeriveTagRule`. Defaults to the
global `derivation_registry`.
Returns
-------
SoundEventTransformation
A callable that applies the specified rule to a SoundEventAnnotation.
Raises
------
KeyError
If required term keys or derivation keys are not found.
ValueError
If the rule has an unknown `rule_type`.
ImportError, AttributeError, TypeError
If dynamic import of a derivation function fails.
"""
if rule.rule_type == "replace":
source = get_tag_from_info(rule.original)
target = get_tag_from_info(rule.replacement)
return partial(replace_tag_transform, source=source, target=target)
if rule.rule_type == "derive_tag":
source_term = terms.get_term(rule.source_term_key)
target_term = (
terms.get_term(rule.target_term_key)
if rule.target_term_key
else source_term
)
if source_term is None or target_term is None:
raise KeyError("Terms not found")
derivation = get_derivation(
key=rule.derivation_function,
import_derivation=rule.import_derivation,
registry=derivation_registry,
)
return partial(
derivation_tag_transform,
source_term=source_term,
target_term=target_term,
derivation=derivation,
keep_source=rule.keep_source,
)
if rule.rule_type == "map_value":
source_term = terms.get_term(rule.source_term_key)
target_term = (
terms.get_term(rule.target_term_key)
if rule.target_term_key
else source_term
)
if source_term is None or target_term is None:
raise KeyError("Terms not found")
return partial(
map_value_transform,
source_term=source_term,
target_term=target_term,
mapping=rule.value_mapping,
)
# Handle unknown rule type
valid_options = ["replace", "derive_tag", "map_value"]
# Should be caught by Pydantic validation, but good practice
raise ValueError(
f"Invalid transform rule type '{getattr(rule, 'rule_type', 'N/A')}'. "
f"Valid options are: {valid_options}"
)
def build_transformation_from_config(
config: TransformConfig,
derivation_registry: Optional[DerivationRegistry] = None,
) -> SoundEventTransformation:
"""Build a composite transformation function from a TransformConfig.
Creates a sequence of individual transformation functions based on the
rules defined in the configuration. Returns a single function that
applies these transformations sequentially to an annotation.
Parameters
----------
config : TransformConfig
The configuration object containing the list of transformation rules.
derivation_reg : DerivationRegistry, optional
The derivation registry to use when building `DeriveTagRule`
transformations. Defaults to the global `derivation_registry`.
Returns
-------
SoundEventTransformation
A single function that applies all configured transformations in order.
"""
transforms = [
build_transform_from_rule(
rule,
derivation_registry=derivation_registry,
)
for rule in config.rules
]
return partial(apply_sequence_of_transforms, transforms=transforms)
def apply_sequence_of_transforms(
sound_event_annotation: data.SoundEventAnnotation,
transforms: list[SoundEventTransformation],
) -> data.SoundEventAnnotation:
for transform in transforms:
sound_event_annotation = transform(sound_event_annotation)
return sound_event_annotation
def load_transformation_config(
path: data.PathLike, field: Optional[str] = None
) -> TransformConfig:
"""Load the transformation configuration from a file.
Parameters
----------
path : data.PathLike
Path to the configuration file (YAML).
field : str, optional
If the transformation configuration is nested under a specific key
in the file, specify the key here. Defaults to None.
Returns
-------
TransformConfig
The loaded and validated transformation configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the TransformConfig schema.
"""
return load_config(path=path, schema=TransformConfig, field=field)
def load_transformation_from_config(
path: data.PathLike,
field: Optional[str] = None,
derivation_registry: Optional[DerivationRegistry] = None,
) -> SoundEventTransformation:
"""Load transformation config from a file and build the final function.
This is a convenience function that combines loading the configuration
and building the final callable transformation function that applies
all rules sequentially.
Parameters
----------
path : data.PathLike
Path to the configuration file (YAML).
field : str, optional
If the transformation configuration is nested under a specific key
in the file, specify the key here. Defaults to None.
Returns
-------
SoundEventTransformation
The final composite transformation function ready to be used.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the TransformConfig schema.
KeyError
If required term keys or derivation keys specified in the config
are not found during the build process.
ImportError, AttributeError, TypeError
If dynamic import of a derivation function specified in the config
fails.
"""
config = load_transformation_config(path=path, field=field)
return build_transformation_from_config(
config,
derivation_registry=derivation_registry,
)
def register_derivation(
key: str,
derivation: Derivation,
derivation_registry: Optional[DerivationRegistry] = None,
) -> None:
"""Register a new derivation function in the global registry.
Parameters
----------
key : str
The unique key to associate with the derivation function.
derivation : Derivation
The callable derivation function (takes str, returns str).
derivation_registry : DerivationRegistry, optional
The registry instance to register the derivation function with.
Defaults to the global `derivation_registry`.
Raises
------
KeyError
If a derivation function with the same key is already registered.
"""
derivation_registry = derivation_registry or default_derivation_registry
derivation_registry.register(key, derivation)

View File

@ -44,7 +44,7 @@ AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]
class MixAugmentationConfig(BaseConfig):
"""Configuration for MixUp augmentation (mixing two examples)."""
augmentation_type: Literal["mix_audio"] = "mix_audio"
name: Literal["mix_audio"] = "mix_audio"
probability: float = 0.2
"""Probability of applying this augmentation to an example."""
@ -140,7 +140,7 @@ def combine_clip_annotations(
class EchoAugmentationConfig(BaseConfig):
"""Configuration for adding synthetic echo/reverb."""
augmentation_type: Literal["add_echo"] = "add_echo"
name: Literal["add_echo"] = "add_echo"
probability: float = 0.2
max_delay: float = 0.005
min_weight: float = 0.0
@ -187,7 +187,7 @@ def add_echo(
class VolumeAugmentationConfig(BaseConfig):
"""Configuration for random volume scaling of the spectrogram."""
augmentation_type: Literal["scale_volume"] = "scale_volume"
name: Literal["scale_volume"] = "scale_volume"
probability: float = 0.2
min_scaling: float = 0.0
max_scaling: float = 2.0
@ -214,7 +214,7 @@ def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor:
class WarpAugmentationConfig(BaseConfig):
augmentation_type: Literal["warp"] = "warp"
name: Literal["warp"] = "warp"
probability: float = 0.2
delta: float = 0.04
@ -296,7 +296,7 @@ def warp_spectrogram(
class TimeMaskAugmentationConfig(BaseConfig):
augmentation_type: Literal["mask_time"] = "mask_time"
name: Literal["mask_time"] = "mask_time"
probability: float = 0.2
max_perc: float = 0.05
max_masks: int = 3
@ -353,7 +353,7 @@ def mask_time(
class FrequencyMaskAugmentationConfig(BaseConfig):
augmentation_type: Literal["mask_freq"] = "mask_freq"
name: Literal["mask_freq"] = "mask_freq"
probability: float = 0.2
max_perc: float = 0.10
max_masks: int = 3
@ -414,7 +414,7 @@ AudioAugmentationConfig = Annotated[
MixAugmentationConfig,
EchoAugmentationConfig,
],
Field(discriminator="augmentation_type"),
Field(discriminator="name"),
]
@ -425,7 +425,7 @@ SpectrogramAugmentationConfig = Annotated[
FrequencyMaskAugmentationConfig,
TimeMaskAugmentationConfig,
],
Field(discriminator="augmentation_type"),
Field(discriminator="name"),
]
AugmentationConfig = Annotated[
@ -437,7 +437,7 @@ AugmentationConfig = Annotated[
FrequencyMaskAugmentationConfig,
TimeMaskAugmentationConfig,
],
Field(discriminator="augmentation_type"),
Field(discriminator="name"),
]
"""Type alias for the discriminated union of individual augmentation config."""
@ -485,7 +485,7 @@ def build_augmentation_from_config(
audio_source: Optional[AudioSource] = None,
) -> Optional[Augmentation]:
"""Factory function to build a single augmentation from its config."""
if config.augmentation_type == "mix_audio":
if config.name == "mix_audio":
if audio_source is None:
warnings.warn(
"Mix audio augmentation ('mix_audio') requires an "
@ -500,31 +500,31 @@ def build_augmentation_from_config(
max_weight=config.max_weight,
)
if config.augmentation_type == "add_echo":
if config.name == "add_echo":
return AddEcho(
max_delay=int(config.max_delay * samplerate),
min_weight=config.min_weight,
max_weight=config.max_weight,
)
if config.augmentation_type == "scale_volume":
if config.name == "scale_volume":
return ScaleVolume(
max_scaling=config.max_scaling,
min_scaling=config.min_scaling,
)
if config.augmentation_type == "warp":
if config.name == "warp":
return WarpSpectrogram(
delta=config.delta,
)
if config.augmentation_type == "mask_time":
if config.name == "mask_time":
return MaskTime(
max_perc=config.max_perc,
max_masks=config.max_masks,
)
if config.augmentation_type == "mask_freq":
if config.name == "mask_freq":
return MaskFrequency(
max_perc=config.max_perc,
max_masks=config.max_masks,

View File

@ -67,8 +67,7 @@ class TargetProtocol(Protocol):
This protocol outlines the standard attributes and methods for an object
that encapsulates the complete, configured process for handling sound event
annotations (both tags and geometry). It defines how to:
- Filter relevant annotations.
- Transform annotation tags.
- Select relevant annotations.
- Encode an annotation into a specific target class name.
- Decode a class name back into representative tags.
- Extract a target reference position from an annotation's geometry (ROI).
@ -121,26 +120,6 @@ class TargetProtocol(Protocol):
"""
...
def transform(
self,
sound_event: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
"""Apply tag transformations to an annotation.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation whose tags should be transformed.
Returns
-------
data.SoundEventAnnotation
A new annotation object with the transformed tags. Implementations
should return the original annotation object if no transformations
were configured.
"""
...
def encode_class(
self,
sound_event: data.SoundEventAnnotation,

View File

@ -18,9 +18,7 @@ from batdetect2.targets import (
build_targets,
call_type,
)
from batdetect2.targets.classes import ClassesConfig, TargetClass
from batdetect2.targets.filtering import FilterConfig, FilterRule
from batdetect2.targets.terms import TagInfo
from batdetect2.targets.classes import TargetClassConfig
from batdetect2.train.clips import build_clipper
from batdetect2.train.labels import build_clip_labeler
from batdetect2.typing import (
@ -365,43 +363,37 @@ def sample_audio_loader() -> AudioLoader:
@pytest.fixture
def bat_tag() -> TagInfo:
return TagInfo(key="class", value="bat")
def bat_tag() -> data.Tag:
return data.Tag(key="class", value="bat")
@pytest.fixture
def noise_tag() -> TagInfo:
return TagInfo(key="class", value="noise")
def noise_tag() -> data.Tag:
return data.Tag(key="class", value="noise")
@pytest.fixture
def myomyo_tag() -> TagInfo:
return TagInfo(key="species", value="Myotis myotis")
def myomyo_tag() -> data.Tag:
return data.Tag(key="species", value="Myotis myotis")
@pytest.fixture
def pippip_tag() -> TagInfo:
return TagInfo(key="species", value="Pipistrellus pipistrellus")
def pippip_tag() -> data.Tag:
return data.Tag(key="species", value="Pipistrellus pipistrellus")
@pytest.fixture
def sample_target_config(
bat_tag: TagInfo,
noise_tag: TagInfo,
myomyo_tag: TagInfo,
pippip_tag: TagInfo,
bat_tag: data.Tag,
myomyo_tag: data.Tag,
pippip_tag: data.Tag,
) -> TargetConfig:
return TargetConfig(
filtering=FilterConfig(
rules=[FilterRule(match_type="exclude", tags=[noise_tag])]
),
classes=ClassesConfig(
classes=[
TargetClass(name="pippip", tags=[pippip_tag]),
TargetClass(name="myomyo", tags=[myomyo_tag]),
],
generic_class=[bat_tag],
),
detection_target=TargetClassConfig(name="bat", tags=[bat_tag]),
classification_targets=[
TargetClassConfig(name="pippip", tags=[pippip_tag]),
TargetClassConfig(name="myomyo", tags=[myomyo_tag]),
],
)

View File

@ -0,0 +1,516 @@
import textwrap
import pytest
import yaml
from pydantic import TypeAdapter
from soundevent import data
from batdetect2.data.conditions import (
SoundEventConditionConfig,
build_sound_event_condition,
)
def build_condition_from_str(content):
content = textwrap.dedent(content)
content = yaml.safe_load(content)
config = TypeAdapter(SoundEventConditionConfig).validate_python(content)
return build_sound_event_condition(config)
def test_has_tag(sound_event: data.SoundEvent):
condition = build_condition_from_str("""
name: has_tag
tag:
key: species
value: Myotis myotis
""")
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
assert condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
)
assert not condition(sound_event_annotation)
def test_has_all_tags(sound_event: data.SoundEvent):
condition = build_condition_from_str("""
name: has_all_tags
tags:
- key: species
value: Myotis myotis
- key: event
value: Echolocation
""")
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
assert not condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
data.Tag(key="event", value="Echolocation"), # type: ignore
],
)
assert not condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Myotis myotis"), # type: ignore
data.Tag(key="event", value="Echolocation"), # type: ignore
],
)
assert condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Myotis myotis"), # type: ignore
data.Tag(key="event", value="Echolocation"), # type: ignore
data.Tag(key="sex", value="Female"), # type: ignore
],
)
assert condition(sound_event_annotation)
def test_has_any_tags(sound_event: data.SoundEvent):
condition = build_condition_from_str("""
name: has_any_tag
tags:
- key: species
value: Myotis myotis
- key: event
value: Echolocation
""")
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
assert condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
data.Tag(key="event", value="Echolocation"), # type: ignore
],
)
assert condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Myotis myotis"), # type: ignore
data.Tag(key="event", value="Echolocation"), # type: ignore
],
)
assert condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
data.Tag(key="event", value="Social"), # type: ignore
],
)
assert not condition(sound_event_annotation)
def test_not(sound_event: data.SoundEvent):
condition = build_condition_from_str("""
name: not
condition:
name: has_tag
tag:
key: species
value: Myotis myotis
""")
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
assert not condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
)
assert condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Myotis myotis"), # type: ignore
data.Tag(key="event", value="Echolocation"), # type: ignore
],
)
assert not condition(sound_event_annotation)
def test_duration(recording: data.Recording):
se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
recording=recording, geometry=data.TimeInterval(coordinates=[0, 1])
),
)
se2 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
recording=recording, geometry=data.TimeInterval(coordinates=[0, 2])
),
)
se3 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
recording=recording, geometry=data.TimeInterval(coordinates=[0, 3])
),
)
condition = build_condition_from_str("""
name: duration
operator: lt
seconds: 2
""")
assert condition(se1)
assert not condition(se2)
assert not condition(se3)
condition = build_condition_from_str("""
name: duration
operator: lte
seconds: 2
""")
assert condition(se1)
assert condition(se2)
assert not condition(se3)
condition = build_condition_from_str("""
name: duration
operator: gt
seconds: 2
""")
assert not condition(se1)
assert not condition(se2)
assert condition(se3)
condition = build_condition_from_str("""
name: duration
operator: gte
seconds: 2
""")
assert not condition(se1)
assert condition(se2)
assert condition(se3)
condition = build_condition_from_str("""
name: duration
operator: eq
seconds: 2
""")
assert not condition(se1)
assert condition(se2)
assert not condition(se3)
def test_frequency(recording: data.Recording):
se12 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
recording=recording,
geometry=data.BoundingBox(coordinates=[0, 100, 1, 200]),
),
)
se13 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
recording=recording,
geometry=data.BoundingBox(coordinates=[0, 100, 2, 300]),
),
)
se14 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
recording=recording,
geometry=data.BoundingBox(coordinates=[0, 100, 3, 400]),
),
)
se24 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
recording=recording,
geometry=data.BoundingBox(coordinates=[0, 200, 3, 400]),
),
)
se34 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
recording=recording,
geometry=data.BoundingBox(coordinates=[0, 300, 3, 400]),
),
)
condition = build_condition_from_str("""
name: frequency
boundary: high
operator: lt
hertz: 300
""")
assert condition(se12)
assert not condition(se13)
assert not condition(se14)
condition = build_condition_from_str("""
name: frequency
boundary: high
operator: lte
hertz: 300
""")
assert condition(se12)
assert condition(se13)
assert not condition(se14)
condition = build_condition_from_str("""
name: frequency
boundary: high
operator: gt
hertz: 300
""")
assert not condition(se12)
assert not condition(se13)
assert condition(se14)
condition = build_condition_from_str("""
name: frequency
boundary: high
operator: gte
hertz: 300
""")
assert not condition(se12)
assert condition(se13)
assert condition(se14)
condition = build_condition_from_str("""
name: frequency
boundary: high
operator: eq
hertz: 300
""")
assert not condition(se12)
assert condition(se13)
assert not condition(se14)
# LOW
condition = build_condition_from_str("""
name: frequency
boundary: low
operator: lt
hertz: 200
""")
assert condition(se14)
assert not condition(se24)
assert not condition(se34)
condition = build_condition_from_str("""
name: frequency
boundary: low
operator: lte
hertz: 200
""")
assert condition(se14)
assert condition(se24)
assert not condition(se34)
condition = build_condition_from_str("""
name: frequency
boundary: low
operator: gt
hertz: 200
""")
assert not condition(se14)
assert not condition(se24)
assert condition(se34)
condition = build_condition_from_str("""
name: frequency
boundary: low
operator: gte
hertz: 200
""")
assert not condition(se14)
assert condition(se24)
assert condition(se34)
condition = build_condition_from_str("""
name: frequency
boundary: low
operator: eq
hertz: 200
""")
assert not condition(se14)
assert condition(se24)
assert not condition(se34)
def test_frequency_is_false_for_temporal_geometries(recording: data.Recording):
condition = build_condition_from_str("""
name: frequency
boundary: low
operator: eq
hertz: 200
""")
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.TimeInterval(coordinates=[0, 3]),
recording=recording,
)
)
assert not condition(se)
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.TimeStamp(coordinates=3),
recording=recording,
)
)
assert not condition(se)
def test_has_tags_fails_if_empty():
with pytest.raises(ValueError):
build_condition_from_str("""
name: has_tags
tags: []
""")
def test_frequency_is_false_if_no_geometry(recording: data.Recording):
condition = build_condition_from_str("""
name: frequency
boundary: low
operator: eq
hertz: 200
""")
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(geometry=None, recording=recording)
)
assert not condition(se)
def test_duration_is_false_if_no_geometry(recording: data.Recording):
condition = build_condition_from_str("""
name: duration
operator: eq
seconds: 1
""")
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(geometry=None, recording=recording)
)
assert not condition(se)
def test_all_of(recording: data.Recording):
condition = build_condition_from_str("""
name: all_of
conditions:
- name: has_tag
tag:
key: species
value: Myotis myotis
- name: duration
operator: lt
seconds: 1
""")
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.TimeInterval(coordinates=[0, 0.5]),
recording=recording,
),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
assert condition(se)
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.TimeInterval(coordinates=[0, 2]),
recording=recording,
),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
assert not condition(se)
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.TimeInterval(coordinates=[0, 0.5]),
recording=recording,
),
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
)
assert not condition(se)
def test_any_of(recording: data.Recording):
condition = build_condition_from_str("""
name: any_of
conditions:
- name: has_tag
tag:
key: species
value: Myotis myotis
- name: duration
operator: lt
seconds: 1
""")
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.TimeInterval(coordinates=[0, 2]),
recording=recording,
),
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
)
assert not condition(se)
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.TimeInterval(coordinates=[0, 0.5]),
recording=recording,
),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
assert condition(se)
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.TimeInterval(coordinates=[0, 2]),
recording=recording,
),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
assert condition(se)
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.TimeInterval(coordinates=[0, 0.5]),
recording=recording,
),
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
)
assert condition(se)

View File

@ -3,26 +3,15 @@ from typing import Callable
from uuid import uuid4
import pytest
from pydantic import ValidationError
from soundevent import data
from soundevent.terms import get_term
from batdetect2.targets.classes import (
DEFAULT_SPECIES_LIST,
ClassesConfig,
TargetClass,
_get_default_class_name,
_get_default_classes,
build_generic_class_tags,
TargetClassConfig,
build_sound_event_decoder,
build_sound_event_encoder,
get_class_names_from_config,
is_target_class,
load_classes_config,
load_decoder_from_config,
load_encoder_from_config,
)
from batdetect2.targets.terms import TagInfo
@pytest.fixture
@ -33,8 +22,8 @@ def sample_annotation(
return data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
data.Tag(key="quality", value="Good"), # type: ignore
data.Tag(key="species", value="Pipistrellus pipistrellus"),
data.Tag(key="quality", value="Good"),
],
)
@ -51,291 +40,71 @@ def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
return factory
def test_target_class_creation():
target_class = TargetClass(
name="pippip",
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
)
assert target_class.name == "pippip"
assert target_class.tags[0].key == "species"
assert target_class.tags[0].value == "Pipistrellus pipistrellus"
assert target_class.match_type == "all"
def test_classes_config_creation():
target_class = TargetClass(
name="pippip",
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
)
config = ClassesConfig(classes=[target_class])
assert len(config.classes) == 1
assert config.classes[0].name == "pippip"
def test_classes_config_unique_names():
target_class1 = TargetClass(
name="pippip",
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
)
target_class2 = TargetClass(
name="myodau",
tags=[TagInfo(key="species", value="Myotis daubentonii")],
)
ClassesConfig(classes=[target_class1, target_class2]) # No error
def test_classes_config_non_unique_names():
target_class1 = TargetClass(
name="pippip",
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
)
target_class2 = TargetClass(
name="pippip",
tags=[TagInfo(key="species", value="Myotis daubentonii")],
)
with pytest.raises(ValidationError):
ClassesConfig(classes=[target_class1, target_class2])
def test_load_classes_config_valid(create_temp_yaml: Callable[[str], Path]):
yaml_content = """
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
"""
temp_yaml_path = create_temp_yaml(yaml_content)
config = load_classes_config(temp_yaml_path)
assert len(config.classes) == 1
assert config.classes[0].name == "pippip"
def test_load_classes_config_invalid(create_temp_yaml: Callable[[str], Path]):
yaml_content = """
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
- name: pippip
tags:
- key: species
value: Myotis daubentonii
"""
temp_yaml_path = create_temp_yaml(yaml_content)
with pytest.raises(ValidationError):
load_classes_config(temp_yaml_path)
def test_is_target_class_match_all(
sample_annotation: data.SoundEventAnnotation,
):
tags = {
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
data.Tag(key="quality", value="Good"), # type: ignore
}
assert is_target_class(sample_annotation, tags, match_all=True) is True
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
assert is_target_class(sample_annotation, tags, match_all=True) is True
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
assert is_target_class(sample_annotation, tags, match_all=True) is False
def test_is_target_class_match_any(
sample_annotation: data.SoundEventAnnotation,
):
tags = {
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
data.Tag(key="quality", value="Good"), # type: ignore
}
assert is_target_class(sample_annotation, tags, match_all=False) is True
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
assert is_target_class(sample_annotation, tags, match_all=False) is True
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
assert is_target_class(sample_annotation, tags, match_all=False) is False
def test_get_class_names_from_config():
target_class1 = TargetClass(
target_class1 = TargetClassConfig(
name="pippip",
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
)
target_class2 = TargetClass(
target_class2 = TargetClassConfig(
name="myodau",
tags=[TagInfo(key="species", value="Myotis daubentonii")],
tags=[data.Tag(key="species", value="Myotis daubentonii")],
)
config = ClassesConfig(classes=[target_class1, target_class2])
names = get_class_names_from_config(config)
names = get_class_names_from_config([target_class1, target_class2])
assert names == ["pippip", "myodau"]
def test_build_encoder_from_config(
sample_annotation: data.SoundEventAnnotation,
):
config = ClassesConfig(
classes=[
TargetClass(
name="pippip",
tags=[
TagInfo(key="species", value="Pipistrellus pipistrellus")
],
)
]
)
encoder = build_sound_event_encoder(config)
classes = [
TargetClassConfig(
name="pippip",
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
)
]
encoder = build_sound_event_encoder(classes)
result = encoder(sample_annotation)
assert result == "pippip"
config = ClassesConfig(classes=[])
encoder = build_sound_event_encoder(config)
classes = []
encoder = build_sound_event_encoder(classes)
result = encoder(sample_annotation)
assert result is None
def test_load_encoder_from_config_valid(
sample_annotation: data.SoundEventAnnotation,
create_temp_yaml: Callable[[str], Path],
):
yaml_content = """
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
"""
temp_yaml_path = create_temp_yaml(yaml_content)
encoder = load_encoder_from_config(temp_yaml_path)
# We cannot directly compare the function, so we test it.
result = encoder(sample_annotation) # type: ignore
assert result == "pippip"
def test_load_encoder_from_config_invalid(
create_temp_yaml: Callable[[str], Path],
):
yaml_content = """
classes:
- name: pippip
tags:
- key: invalid_key
value: Pipistrellus pipistrellus
"""
temp_yaml_path = create_temp_yaml(yaml_content)
with pytest.raises(KeyError):
load_encoder_from_config(temp_yaml_path)
def test_get_default_class_name():
assert _get_default_class_name("Myotis daubentonii") == "myodau"
def test_get_default_classes():
default_classes = _get_default_classes()
assert len(default_classes) == len(DEFAULT_SPECIES_LIST)
first_class = default_classes[0]
assert isinstance(first_class, TargetClass)
assert first_class.name == _get_default_class_name(DEFAULT_SPECIES_LIST[0])
assert first_class.tags[0].key == "class"
assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0]
def test_build_decoder_from_config():
config = ClassesConfig(
classes=[
TargetClass(
name="pippip",
tags=[
TagInfo(key="species", value="Pipistrellus pipistrellus")
],
output_tags=[TagInfo(key="call_type", value="Echolocation")],
)
],
generic_class=[TagInfo(key="order", value="Chiroptera")],
)
decoder = build_sound_event_decoder(config)
classes = [
TargetClassConfig(
name="pippip",
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
assign_tags=[data.Tag(key="call_type", value="Echolocation")],
)
]
decoder = build_sound_event_decoder(classes)
tags = decoder("pippip")
assert len(tags) == 1
assert tags[0].term == get_term("event")
assert tags[0].value == "Echolocation"
# Test when output_tags is None, should fall back to tags
config = ClassesConfig(
classes=[
TargetClass(
name="pippip",
tags=[
TagInfo(key="species", value="Pipistrellus pipistrellus")
],
)
],
generic_class=[TagInfo(key="order", value="Chiroptera")],
)
decoder = build_sound_event_decoder(config)
classes = [
TargetClassConfig(
name="pippip",
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
)
]
decoder = build_sound_event_decoder(classes)
tags = decoder("pippip")
assert len(tags) == 1
assert tags[0].term == get_term("species")
assert tags[0].value == "Pipistrellus pipistrellus"
# Test raise_on_unmapped=True
decoder = build_sound_event_decoder(config, raise_on_unmapped=True)
decoder = build_sound_event_decoder(classes, raise_on_unmapped=True)
with pytest.raises(ValueError):
decoder("unknown_class")
# Test raise_on_unmapped=False
decoder = build_sound_event_decoder(config, raise_on_unmapped=False)
decoder = build_sound_event_decoder(classes, raise_on_unmapped=False)
tags = decoder("unknown_class")
assert len(tags) == 0
def test_load_decoder_from_config_valid(
create_temp_yaml: Callable[[str], Path],
):
yaml_content = """
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
output_tags:
- key: call_type
value: Echolocation
generic_class:
- key: order
value: Chiroptera
"""
temp_yaml_path = create_temp_yaml(yaml_content)
decoder = load_decoder_from_config(
temp_yaml_path,
)
tags = decoder("pippip")
assert len(tags) == 1
assert tags[0].term == get_term("call_type")
assert tags[0].value == "Echolocation"
def test_build_generic_class_tags_from_config():
config = ClassesConfig(
classes=[
TargetClass(
name="pippip",
tags=[
TagInfo(key="species", value="Pipistrellus pipistrellus")
],
)
],
generic_class=[
TagInfo(key="order", value="Chiroptera"),
TagInfo(key="call_type", value="Echolocation"),
],
)
generic_tags = build_generic_class_tags(config)
assert len(generic_tags) == 2
assert generic_tags[0].term == get_term("order")
assert generic_tags[0].value == "Chiroptera"
assert generic_tags[1].term == get_term("call_type")
assert generic_tags[1].value == "Echolocation"

View File

@ -1,210 +0,0 @@
from pathlib import Path
from typing import Callable, List, Set
import pytest
from soundevent import data
from batdetect2.targets import build_targets
from batdetect2.targets.filtering import (
FilterConfig,
FilterRule,
build_filter_from_rule,
build_sound_event_filter,
contains_tags,
does_not_have_tags,
equal_tags,
has_any_tag,
load_filter_config,
load_filter_from_config,
)
from batdetect2.targets.terms import TagInfo, generic_class
@pytest.fixture
def create_annotation(
sound_event: data.SoundEvent,
) -> Callable[[List[str]], data.SoundEventAnnotation]:
"""Helper function to create a SoundEventAnnotation with given tags."""
def factory(tags: List[str]) -> data.SoundEventAnnotation:
return data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(
term=generic_class,
value=tag,
)
for tag in tags
],
)
return factory
def create_tag_set(tags: List[str]) -> Set[data.Tag]:
"""Helper function to create a set of data.Tag objects from a list of strings."""
return {
data.Tag(
term=generic_class,
value=tag,
)
for tag in tags
}
def test_has_any_tag(create_annotation):
annotation = create_annotation(["tag1", "tag2"])
tags = create_tag_set(["tag1", "tag3"])
assert has_any_tag(annotation, tags) is True
annotation = create_annotation(["tag2", "tag4"])
tags = create_tag_set(["tag1", "tag3"])
assert has_any_tag(annotation, tags) is False
def test_contains_tags(create_annotation):
annotation = create_annotation(["tag1", "tag2", "tag3"])
tags = create_tag_set(["tag1", "tag2"])
assert contains_tags(annotation, tags) is True
annotation = create_annotation(["tag1", "tag2"])
tags = create_tag_set(["tag1", "tag2", "tag3"])
assert contains_tags(annotation, tags) is False
def test_does_not_have_tags(create_annotation):
annotation = create_annotation(["tag1", "tag2"])
tags = create_tag_set(["tag3", "tag4"])
assert does_not_have_tags(annotation, tags) is True
annotation = create_annotation(["tag1", "tag2"])
tags = create_tag_set(["tag1", "tag3"])
assert does_not_have_tags(annotation, tags) is False
def test_equal_tags(create_annotation):
annotation = create_annotation(["tag1", "tag2"])
tags = create_tag_set(["tag1", "tag2"])
assert equal_tags(annotation, tags) is True
annotation = create_annotation(["tag1", "tag2", "tag3"])
tags = create_tag_set(["tag1", "tag2"])
assert equal_tags(annotation, tags) is False
def test_build_filter_from_rule():
rule_any = FilterRule(match_type="any", tags=[TagInfo(value="tag1")])
build_filter_from_rule(rule_any)
rule_all = FilterRule(match_type="all", tags=[TagInfo(value="tag1")])
build_filter_from_rule(rule_all)
rule_exclude = FilterRule(
match_type="exclude", tags=[TagInfo(value="tag1")]
)
build_filter_from_rule(rule_exclude)
rule_equal = FilterRule(match_type="equal", tags=[TagInfo(value="tag1")])
build_filter_from_rule(rule_equal)
with pytest.raises(ValueError):
FilterRule(match_type="invalid", tags=[TagInfo(value="tag1")]) # type: ignore
build_filter_from_rule(
FilterRule(match_type="invalid", tags=[TagInfo(value="tag1")]) # type: ignore
)
def test_build_filter_from_config(create_annotation):
config = FilterConfig(
rules=[
FilterRule(match_type="any", tags=[TagInfo(value="tag1")]),
FilterRule(match_type="any", tags=[TagInfo(value="tag2")]),
]
)
filter_from_config = build_sound_event_filter(config)
annotation_pass = create_annotation(["tag1", "tag2"])
assert filter_from_config(annotation_pass)
annotation_fail = create_annotation(["tag1"])
assert not filter_from_config(annotation_fail)
def test_load_filter_config(tmp_path: Path):
test_config_path = tmp_path / "filtering.yaml"
test_config_path.write_text(
"""
rules:
- match_type: any
tags:
- value: tag1
"""
)
config = load_filter_config(test_config_path)
assert isinstance(config, FilterConfig)
assert len(config.rules) == 1
rule = config.rules[0]
assert rule.match_type == "any"
assert len(rule.tags) == 1
assert rule.tags[0].value == "tag1"
def test_load_filter_from_config(tmp_path: Path, create_annotation):
test_config_path = tmp_path / "filtering.yaml"
test_config_path.write_text(
"""
rules:
- match_type: any
tags:
- value: tag1
"""
)
filter_result = load_filter_from_config(test_config_path)
annotation = create_annotation(["tag1", "tag3"])
assert filter_result(annotation)
test_config_path = tmp_path / "filtering.yaml"
test_config_path.write_text(
"""
rules:
- match_type: any
tags:
- value: tag2
"""
)
filter_result = load_filter_from_config(test_config_path)
annotation = create_annotation(["tag1", "tag3"])
assert filter_result(annotation) is False
def test_default_filtering_over_example_dataset(
example_annotations: List[data.ClipAnnotation],
):
targets = build_targets()
clip1 = example_annotations[0]
clip2 = example_annotations[1]
clip3 = example_annotations[2]
assert (
sum(
[targets.filter(sound_event) for sound_event in clip1.sound_events]
)
== 9
)
assert (
sum(
[targets.filter(sound_event) for sound_event in clip2.sound_events]
)
== 15
)
assert (
sum(
[targets.filter(sound_event) for sound_event in clip3.sound_events]
)
== 20
)

View File

@ -8,6 +8,11 @@ from batdetect2.preprocess import (
build_preprocessor,
)
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.preprocess.spectrogram import (
ScaleAmplitudeConfig,
SpectralMeanSubstractionConfig,
SpectrogramConfig,
)
from batdetect2.targets.rois import (
DEFAULT_ANCHOR,
DEFAULT_FREQUENCY_SCALE,
@ -548,14 +553,7 @@ def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle):
# Instantiate the mapper.
preprocessor = build_preprocessor(
PreprocessingConfig.model_validate(
{
"spectrogram": {
"pcen": None,
"spectral_mean_substraction": False,
}
}
)
PreprocessingConfig(spectrogram=SpectrogramConfig(transforms=[]))
)
audio_loader = build_audio_loader()
mapper = PeakEnergyBBoxMapper(
@ -597,14 +595,13 @@ def test_build_roi_mapper_for_anchor_bbox():
def test_build_roi_mapper_for_peak_energy_bbox():
# Given
preproc_config = PreprocessingConfig.model_validate(
{
"spectrogram": {
"pcen": None,
"spectral_mean_substraction": True,
"scale": "dB",
}
}
preproc_config = PreprocessingConfig(
spectrogram=SpectrogramConfig(
transforms=[
ScaleAmplitudeConfig(scale="db"),
SpectralMeanSubstractionConfig(),
]
),
)
config = PeakEnergyBBoxMapperConfig(
loading_buffer=0.99,

View File

@ -11,25 +11,34 @@ def test_can_override_default_roi_mapper_per_class(
recording: data.Recording,
):
yaml_content = """
detection_target:
name: bat
match_if:
name: has_tag
tag:
key: order
value: Chiroptera
assign_tags:
- key: order
value: Chiroptera
classification_targets:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
- name: myomyo
tags:
- key: species
value: Myotis myotis
roi:
name: anchor_bbox
anchor: top-left
roi:
name: anchor_bbox
anchor: bottom-left
classes:
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
- name: myomyo
tags:
- key: species
value: Myotis myotis
roi:
name: anchor_bbox
anchor: top-left
generic_class:
- key: order
value: Chiroptera
"""
config_path = create_temp_yaml(yaml_content)
@ -65,25 +74,34 @@ def test_roi_is_recovered_roundtrip_even_with_overriders(
recording,
):
yaml_content = """
detection_target:
name: bat
match_if:
name: has_tag
tag:
key: order
value: Chiroptera
assign_tags:
- key: order
value: Chiroptera
classification_targets:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
- name: myomyo
tags:
- key: species
value: Myotis myotis
roi:
name: anchor_bbox
anchor: top-left
roi:
name: anchor_bbox
anchor: bottom-left
classes:
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
- name: myomyo
tags:
- key: species
value: Myotis myotis
roi:
name: anchor_bbox
anchor: top-left
generic_class:
- key: order
value: Chiroptera
"""
config_path = create_temp_yaml(yaml_content)

View File

@ -1,17 +0,0 @@
import pytest
from batdetect2.targets import terms
from batdetect2.targets.terms import TagInfo
def test_tag_info_and_get_tag_from_info():
tag_info = TagInfo(value="Myotis myotis", key="event")
tag = terms.get_tag_from_info(tag_info)
assert tag.value == "Myotis myotis"
assert tag.term == terms.call_type
def test_get_tag_from_info_key_not_found():
tag_info = TagInfo(value="test", key="non_existent_key")
with pytest.raises(KeyError):
terms.get_tag_from_info(tag_info)

View File

@ -1,360 +0,0 @@
from pathlib import Path
import pytest
from soundevent import data, terms
from batdetect2.targets import (
DeriveTagRule,
MapValueRule,
ReplaceRule,
TagInfo,
TransformConfig,
build_transformation_from_config,
)
from batdetect2.targets.transform import (
DerivationRegistry,
build_transform_from_rule,
)
@pytest.fixture
def derivation_registry():
return DerivationRegistry()
@pytest.fixture
def term1() -> data.Term:
term = data.Term(label="Term 1", definition="unknown", name="test:term1")
terms.add_term(term, key="term1", force=True)
return term
@pytest.fixture
def term2() -> data.Term:
term = data.Term(label="Term 2", definition="unknown", name="test:term2")
terms.add_term(term, key="term2", force=True)
return term
@pytest.fixture
def term3() -> data.Term:
term = data.Term(label="Term 3", definition="unknown", name="test:term3")
terms.add_term(term, key="term3", force=True)
return term
@pytest.fixture
def annotation(
sound_event: data.SoundEvent,
term1: data.Term,
) -> data.SoundEventAnnotation:
return data.SoundEventAnnotation(
sound_event=sound_event, tags=[data.Tag(term=term1, value="value1")]
)
@pytest.fixture
def annotation2(
sound_event: data.SoundEvent,
term2: data.Term,
) -> data.SoundEventAnnotation:
return data.SoundEventAnnotation(
sound_event=sound_event, tags=[data.Tag(term=term2, value="value2")]
)
def test_map_value_rule(annotation: data.SoundEventAnnotation):
rule = MapValueRule(
rule_type="map_value",
source_term_key="term1",
value_mapping={"value1": "value2"},
)
transform_fn = build_transform_from_rule(rule)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].value == "value2"
def test_map_value_rule_no_match(annotation: data.SoundEventAnnotation):
rule = MapValueRule(
rule_type="map_value",
source_term_key="term1",
value_mapping={"other_value": "value2"},
)
transform_fn = build_transform_from_rule(rule)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].value == "value1"
def test_replace_rule(annotation: data.SoundEventAnnotation, term2: data.Term):
rule = ReplaceRule(
rule_type="replace",
original=TagInfo(key="term1", value="value1"),
replacement=TagInfo(key="term2", value="value2"),
)
transform_fn = build_transform_from_rule(rule)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].term == term2
assert transformed_annotation.tags[0].value == "value2"
def test_replace_rule_no_match(
annotation: data.SoundEventAnnotation,
term1: data.Term,
term2: data.Term,
):
rule = ReplaceRule(
rule_type="replace",
original=TagInfo(key="term1", value="wrong_value"),
replacement=TagInfo(key="term2", value="value2"),
)
transform_fn = build_transform_from_rule(rule)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].term != term2
assert transformed_annotation.tags[0].value == "value1"
def test_build_transformation_from_config(
annotation: data.SoundEventAnnotation,
annotation2: data.SoundEventAnnotation,
term1: data.Term,
term2: data.Term,
term3: data.Term,
):
config = TransformConfig(
rules=[
MapValueRule(
rule_type="map_value",
source_term_key="term1",
value_mapping={"value1": "value2"},
),
ReplaceRule(
rule_type="replace",
original=TagInfo(key="term2", value="value2"),
replacement=TagInfo(key="term3", value="value3"),
),
]
)
transform = build_transformation_from_config(config)
transformed_annotation = transform(annotation)
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].term != term2
assert transformed_annotation.tags[0].value == "value2"
transformed_annotation = transform(annotation2)
assert transformed_annotation.tags[0].term == term3
assert transformed_annotation.tags[0].value == "value3"
def test_derive_tag_rule(
annotation: data.SoundEventAnnotation,
derivation_registry: DerivationRegistry,
term1: data.Term,
):
def derivation_func(x: str) -> str:
return x + "_derived"
derivation_registry.register("my_derivation", derivation_func)
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="my_derivation",
)
transform_fn = build_transform_from_rule(
rule,
derivation_registry=derivation_registry,
)
transformed_annotation = transform_fn(annotation)
assert len(transformed_annotation.tags) == 2
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value1"
assert transformed_annotation.tags[1].term == term1
assert transformed_annotation.tags[1].value == "value1_derived"
def test_derive_tag_rule_keep_source_false(
annotation: data.SoundEventAnnotation,
derivation_registry: DerivationRegistry,
term1: data.Term,
):
def derivation_func(x: str) -> str:
return x + "_derived"
derivation_registry.register("my_derivation", derivation_func)
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="my_derivation",
keep_source=False,
)
transform_fn = build_transform_from_rule(
rule,
derivation_registry=derivation_registry,
)
transformed_annotation = transform_fn(annotation)
assert len(transformed_annotation.tags) == 1
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value1_derived"
def test_derive_tag_rule_target_term(
annotation: data.SoundEventAnnotation,
derivation_registry: DerivationRegistry,
term1: data.Term,
term2: data.Term,
):
def derivation_func(x: str) -> str:
return x + "_derived"
derivation_registry.register("my_derivation", derivation_func)
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="my_derivation",
target_term_key="term2",
)
transform_fn = build_transform_from_rule(
rule,
derivation_registry=derivation_registry,
)
transformed_annotation = transform_fn(annotation)
assert len(transformed_annotation.tags) == 2
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value1"
assert transformed_annotation.tags[1].term == term2
assert transformed_annotation.tags[1].value == "value1_derived"
def test_derive_tag_rule_import_derivation(
annotation: data.SoundEventAnnotation,
term1: data.Term,
tmp_path: Path,
):
# Create a dummy derivation function in a temporary file
derivation_module_path = (
tmp_path / "temp_derivation.py"
) # Changed to /tmp since /home/santiago is not writable
derivation_module_path.write_text(
"""
def my_imported_derivation(x: str) -> str:
return x + "_imported"
"""
)
# Ensure the temporary file is importable by adding its directory to sys.path
import sys
sys.path.insert(0, str(tmp_path))
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="temp_derivation.my_imported_derivation",
import_derivation=True,
)
transform_fn = build_transform_from_rule(rule)
transformed_annotation = transform_fn(annotation)
assert len(transformed_annotation.tags) == 2
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value1"
assert transformed_annotation.tags[1].term == term1
assert transformed_annotation.tags[1].value == "value1_imported"
# Clean up the temporary file and sys.path
sys.path.remove(str(tmp_path))
def test_derive_tag_rule_invalid_derivation():
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="nonexistent_derivation",
)
with pytest.raises(KeyError):
build_transform_from_rule(rule)
def test_build_transform_from_rule_invalid_rule_type():
class InvalidRule:
rule_type = "invalid"
rule = InvalidRule() # type: ignore
with pytest.raises(ValueError):
build_transform_from_rule(rule) # type: ignore
def test_map_value_rule_target_term(
annotation: data.SoundEventAnnotation,
term2: data.Term,
):
rule = MapValueRule(
rule_type="map_value",
source_term_key="term1",
value_mapping={"value1": "value2"},
target_term_key="term2",
)
transform_fn = build_transform_from_rule(rule)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].term == term2
assert transformed_annotation.tags[0].value == "value2"
def test_map_value_rule_target_term_none(
annotation: data.SoundEventAnnotation,
term1: data.Term,
):
rule = MapValueRule(
rule_type="map_value",
source_term_key="term1",
value_mapping={"value1": "value2"},
target_term_key=None,
)
transform_fn = build_transform_from_rule(rule)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value2"
def test_derive_tag_rule_target_term_none(
annotation: data.SoundEventAnnotation,
derivation_registry: DerivationRegistry,
term1: data.Term,
):
def derivation_func(x: str) -> str:
return x + "_derived"
derivation_registry.register("my_derivation", derivation_func)
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="my_derivation",
target_term_key=None,
)
transform_fn = build_transform_from_rule(
rule,
derivation_registry=derivation_registry,
)
transformed_annotation = transform_fn(annotation)
assert len(transformed_annotation.tags) == 2
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value1"
assert transformed_annotation.tags[1].term == term1
assert transformed_annotation.tags[1].value == "value1_derived"
def test_build_transformation_from_config_empty(
annotation: data.SoundEventAnnotation,
):
config = TransformConfig(rules=[])
transform = build_transformation_from_config(config)
transformed_annotation = transform(annotation)
assert transformed_annotation == annotation

View File

@ -5,7 +5,6 @@ from soundevent import data
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.targets.rois import AnchorBBoxMapperConfig
from batdetect2.targets.terms import TagInfo
from batdetect2.train.labels import generate_heatmaps
recording = data.Recording(
@ -26,7 +25,7 @@ clip = data.Clip(
def test_generated_heatmap_are_non_zero_at_correct_positions(
sample_target_config: TargetConfig,
pippip_tag: TagInfo,
pippip_tag: data.Tag,
):
config = sample_target_config.model_copy(
update=dict(