Compare commits

...

8 Commits

Author SHA1 Message Date
mbsantiago
cd4955d4f3 Eval 2025-09-08 22:04:30 +01:00
mbsantiago
c73984b213 Small fixes 2025-09-08 18:35:02 +01:00
mbsantiago
d8d2e5a2c2 Remove preprocessing modules 2025-09-08 18:11:58 +01:00
mbsantiago
b056d7d28d Make sure training is still working 2025-09-08 18:03:56 +01:00
mbsantiago
95a884ea16 Update tests 2025-09-08 18:00:17 +01:00
mbsantiago
b7ae526071 Big changes in data module 2025-09-08 17:50:25 +01:00
mbsantiago
cf6d0d1ccc Remove stale tests 2025-09-07 11:03:46 +01:00
mbsantiago
709b6355c2 torch.multiprocessing didn't work, returning to serial processing 2025-09-01 11:27:23 +01:00
52 changed files with 2118 additions and 4647 deletions

View File

@ -133,7 +133,7 @@ When you need to specify a tag, you typically use a structure with two fields:
**It defaults to `class`** if you omit it, which is common when defining the main target classes.
- `value`: The specific _value_ of the tag (e.g., `Myotis daubentonii`, `Good`, `Rain`).
**Example YAML Configuration using TagInfo (e.g., inside a filter rule):**
**Example YAML Configuration (e.g., inside a filter rule):**
```yaml
# ... inside a filtering configuration section ...

View File

@ -1,44 +1,47 @@
targets:
classes:
classes:
- name: myomys
tags:
- value: Myotis mystacinus
- name: pippip
tags:
- value: Pipistrellus pipistrellus
- name: eptser
tags:
- value: Eptesicus serotinus
- name: rhifer
tags:
- value: Rhinolophus ferrumequinum
roi:
name: anchor_bbox
anchor: top-left
generic_class:
detection_target:
name: bat
match_if:
name: all_of
conditions:
- name: has_tag
tag: { key: event, value: Echolocation }
- name: not
condition:
name: has_tag
tag: { key: class, value: Unknown }
assign_tags:
- key: class
value: Bat
filtering:
rules:
- match_type: all
tags:
- key: event
value: Echolocation
- match_type: exclude
classification_targets:
- name: myomys
tags:
- key: class
value: Unknown
value: Myotis mystacinus
- name: pippip
tags:
- key: class
value: Pipistrellus pipistrellus
- name: eptser
tags:
- key: class
value: Eptesicus serotinus
- name: rhifer
tags:
- key: class
value: Rhinolophus ferrumequinum
roi:
name: anchor_bbox
anchor: top-left
preprocess:
audio:
resample:
samplerate: 256000
resample:
enabled: True
method: "poly"
scale: false
center: true
duration: null
spectrogram:
stft:
@ -48,66 +51,66 @@ preprocess:
frequencies:
max_freq: 120000
min_freq: 10000
pcen:
size:
height: 128
resize_factor: 0.5
transforms:
- name: pcen
time_constant: 0.1
gain: 0.98
bias: 2
power: 0.5
scale: "amplitude"
size:
height: 128
resize_factor: 0.5
spectral_mean_substraction: true
peak_normalize: false
- name: spectral_mean_substraction
postprocess:
nms_kernel_size: 9
detection_threshold: 0.01
min_freq: 10000
max_freq: 120000
top_k_per_sec: 200
labels:
sigma: 3
model:
input_height: 128
in_channels: 1
out_channels: 32
encoder:
layers:
- block_type: FreqCoordConvDown
- name: FreqCoordConvDown
out_channels: 32
- block_type: FreqCoordConvDown
- name: FreqCoordConvDown
out_channels: 64
- block_type: LayerGroup
- name: LayerGroup
layers:
- block_type: FreqCoordConvDown
- name: FreqCoordConvDown
out_channels: 128
- block_type: ConvBlock
- name: ConvBlock
out_channels: 256
bottleneck:
channels: 256
layers:
- block_type: SelfAttention
- name: SelfAttention
attention_channels: 256
decoder:
layers:
- block_type: FreqCoordConvUp
- name: FreqCoordConvUp
out_channels: 64
- block_type: FreqCoordConvUp
- name: FreqCoordConvUp
out_channels: 32
- block_type: LayerGroup
- name: LayerGroup
layers:
- block_type: FreqCoordConvUp
- name: FreqCoordConvUp
out_channels: 32
- block_type: ConvBlock
- name: ConvBlock
out_channels: 32
train:
learning_rate: 0.001
t_max: 100
labels:
sigma: 3
trainer:
max_epochs: 40
dataloaders:
train:
batch_size: 8
@ -115,7 +118,7 @@ train:
shuffle: True
val:
batch_size: 8
batch_size: 1
num_workers: 2
loss:
@ -134,32 +137,34 @@ train:
logger:
logger_type: csv
save_dir: outputs/log/
name: logs
# save_dir: outputs/log/
# name: logs
augmentations:
steps:
- augmentation_type: mix_audio
enabled: true
audio:
- name: mix_audio
probability: 0.2
min_weight: 0.3
max_weight: 0.7
- augmentation_type: add_echo
- name: add_echo
probability: 0.2
max_delay: 0.005
min_weight: 0.0
max_weight: 1.0
- augmentation_type: scale_volume
spectrogram:
- name: scale_volume
probability: 0.2
min_scaling: 0.0
max_scaling: 2.0
- augmentation_type: warp
- name: warp
probability: 0.2
delta: 0.04
- augmentation_type: mask_time
- name: mask_time
probability: 0.2
max_perc: 0.05
max_masks: 3
- augmentation_type: mask_freq
- name: mask_freq
probability: 0.2
max_perc: 0.10
max_masks: 3

View File

@ -92,19 +92,11 @@ clean-build:
clean: clean-build clean-pyc clean-test clean-docs
# Examples
# Preprocess example data.
example-preprocess OPTIONS="":
batdetect2 preprocess \
--base-dir . \
--dataset-field datasets.train \
--config example_data/config.yaml \
{{OPTIONS}} \
example_data/config.yaml example_data/preprocessed
# Train on example data.
example-train OPTIONS="":
batdetect2 train \
--val-dir example_data/preprocessed \
--val-dataset example_data/dataset.yaml \
--config example_data/config.yaml \
{{OPTIONS}} \
example_data/preprocessed
example_data/dataset.yaml

View File

@ -17,7 +17,7 @@ dependencies = [
"torch>=1.13.1,<2.5.0",
"torchaudio>=1.13.1,<2.5.0",
"torchvision>=0.14.0",
"soundevent[audio,geometry,plot]>=2.8.1",
"soundevent[audio,geometry,plot]>=2.9.1",
"click>=8.1.7",
"netcdf4>=1.6.5",
"tqdm>=4.66.2",

View File

@ -1,7 +1,7 @@
from batdetect2.cli.base import cli
from batdetect2.cli.compat import detect
from batdetect2.cli.data import data
from batdetect2.cli.preprocess import preprocess
from batdetect2.cli.evaluate import evaluate_command
from batdetect2.cli.train import train_command
__all__ = [
@ -9,7 +9,7 @@ __all__ = [
"detect",
"data",
"train_command",
"preprocess",
"evaluate_command",
]

View File

@ -0,0 +1,63 @@
import sys
from pathlib import Path
from typing import Optional
import click
from loguru import logger
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate.evaluate import evaluate
from batdetect2.train.lightning import load_model_from_checkpoint
__all__ = ["evaluate_command"]
@cli.command(name="evaluate")
@click.argument("model-path", type=click.Path(exists=True))
@click.argument("test_dataset", type=click.Path(exists=True))
@click.option("--output-dir", type=click.Path())
@click.option("--workers", type=int)
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def evaluate_command(
model_path: Path,
test_dataset: Path,
output_dir: Optional[Path] = None,
workers: Optional[int] = None,
verbose: int = 0,
):
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Initiating evaluation process...")
test_annotations = load_dataset_from_config(test_dataset)
logger.debug(
"Loaded {num_annotations} test examples",
num_annotations=len(test_annotations),
)
model, train_config = load_model_from_checkpoint(model_path)
df, results = evaluate(
model,
test_annotations,
config=train_config,
num_workers=workers,
)
print(results)
if output_dir:
df.to_csv(output_dir / "results.csv")

View File

@ -1,142 +0,0 @@
import sys
from pathlib import Path
from typing import Optional
import click
import yaml
from loguru import logger
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.train.preprocess import (
TrainPreprocessConfig,
load_train_preprocessing_config,
preprocess_dataset,
)
__all__ = ["preprocess"]
@cli.command()
@click.argument(
"dataset_config",
type=click.Path(exists=True),
)
@click.argument(
"output",
type=click.Path(),
)
@click.option(
"--dataset-field",
type=str,
help=(
"Specifies the key to access the dataset information within the "
"dataset configuration file, if the information is nested inside a "
"dictionary. If the dataset information is at the top level of the "
"config file, you don't need to specify this."
),
)
@click.option(
"--base-dir",
type=click.Path(exists=True),
help=(
"The main directory where your audio recordings and annotation "
"files are stored. This helps the program find your data, "
"especially if the paths in your dataset configuration file "
"are relative."
),
)
@click.option(
"--config",
type=click.Path(exists=True),
help=(
"Path to the configuration file. This file tells "
"the program how to prepare your audio data before training, such "
"as resampling or applying filters."
),
)
@click.option(
"--config-field",
type=str,
help=(
"If the preprocessing settings are inside a nested dictionary "
"within the preprocessing configuration file, specify the key "
"here to access them. If the preprocessing settings are at the "
"top level, you don't need to specify this."
),
)
@click.option(
"--num-workers",
type=int,
help=(
"The maximum number of computer cores to use when processing "
"your audio data. Using more cores can speed up the preprocessing, "
"but don't use more than your computer has available. By default, "
"the program will use all available cores."
),
)
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def preprocess(
dataset_config: Path,
output: Path,
base_dir: Optional[Path] = None,
config: Optional[Path] = None,
config_field: Optional[str] = None,
num_workers: Optional[int] = None,
dataset_field: Optional[str] = None,
verbose: int = 0,
):
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Starting preprocessing.")
output = Path(output)
logger.info("Will save outputs to {output}", output=output)
base_dir = base_dir or Path.cwd()
logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
if config:
logger.info(
"Loading preprocessing config from: {config}", config=config
)
conf = (
load_train_preprocessing_config(config, field=config_field)
if config is not None
else TrainPreprocessConfig()
)
logger.debug(
"Preprocessing config:\n{conf}",
conf=yaml.dump(conf.model_dump()),
)
dataset = load_dataset_from_config(
dataset_config,
field=dataset_field,
base_dir=base_dir,
)
logger.info(
"Loaded {num_examples} annotated clips from the configured dataset",
num_examples=len(dataset),
)
preprocess_dataset(
dataset,
conf,
output=output,
max_workers=num_workers,
)

View File

@ -20,6 +20,8 @@ __all__ = ["train_command"]
@click.argument("train_dataset", type=click.Path(exists=True))
@click.option("--val-dataset", type=click.Path(exists=True))
@click.option("--model-path", type=click.Path(exists=True))
@click.option("--ckpt-dir", type=click.Path(exists=True))
@click.option("--log-dir", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True))
@click.option("--config-field", type=str)
@click.option("--train-workers", type=int)
@ -34,6 +36,8 @@ def train_command(
train_dataset: Path,
val_dataset: Optional[Path] = None,
model_path: Optional[Path] = None,
ckpt_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
config: Optional[Path] = None,
config_field: Optional[str] = None,
train_workers: int = 0,
@ -83,4 +87,6 @@ def train_command(
model_path=model_path,
train_workers=train_workers,
val_workers=val_workers,
log_dir=log_dir,
checkpoint_dir=ckpt_dir,
)

View File

@ -0,0 +1,61 @@
from typing import Generic, Protocol, Type, TypeVar
from pydantic import BaseModel
__all__ = [
"Registry",
]
T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
T_Type = TypeVar("T_Type", covariant=True)
class LogicProtocol(Generic[T_Config, T_Type], Protocol):
"""A generic protocol for the logic classes (conditions or transforms)."""
@classmethod
def from_config(cls, config: T_Config) -> T_Type: ...
T_Proto = TypeVar("T_Proto", bound=LogicProtocol)
class Registry(Generic[T_Type]):
"""A generic class to create and manage a registry of items."""
def __init__(self, name: str):
self._name = name
self._registry = {}
def register(self, config_cls: Type[T_Config]):
"""A decorator factory to register a new item."""
fields = config_cls.model_fields
if "name" not in fields:
raise ValueError("Configuration object must have a 'name' field.")
name = fields["name"].default
if not isinstance(name, str):
raise ValueError("'name' field must be a string literal.")
def decorator(logic_cls: Type[T_Proto]) -> Type[T_Proto]:
self._registry[name] = logic_cls
return logic_cls
return decorator
def build(self, config: BaseModel) -> T_Type:
"""Builds a logic instance from a config object."""
name = getattr(config, "name") # noqa: B009
if name is None:
raise ValueError("Config does not have a name field")
if name not in self._registry:
raise NotImplementedError(
f"No {self._name} with name '{name}' is registered."
)
return self._registry[name].from_config(config)

View File

@ -14,8 +14,9 @@ format-specific loading function to retrieve the annotations as a standard
"""
from pathlib import Path
from typing import Optional, Union
from typing import Annotated, Optional, Union
from pydantic import Field
from soundevent import data
from batdetect2.data.annotations.aoef import (
@ -42,10 +43,13 @@ __all__ = [
]
AnnotationFormats = Union[
AnnotationFormats = Annotated[
Union[
BatDetect2MergedAnnotations,
BatDetect2FilesAnnotations,
AOEFAnnotations,
],
Field(discriminator="format"),
]
"""Type Alias representing all supported data source configurations.

View File

@ -8,8 +8,6 @@ from typing import Callable, List, Optional, Union
from pydantic import BaseModel, Field
from soundevent import data
from batdetect2.targets import get_term_from_key
PathLike = Union[Path, str, os.PathLike]
__all__ = []
@ -92,15 +90,15 @@ def annotation_to_sound_event(
sound_event=sound_event,
tags=[
data.Tag(
term=get_term_from_key(label_key),
key=label_key, # type: ignore
value=annotation.label,
),
data.Tag(
term=get_term_from_key(event_key),
key=event_key, # type: ignore
value=annotation.event,
),
data.Tag(
term=get_term_from_key(individual_key),
key=individual_key, # type: ignore
value=str(annotation.individual),
),
],
@ -125,7 +123,7 @@ def file_annotation_to_clip(
time_expansion=file_annotation.time_exp,
tags=[
data.Tag(
term=get_term_from_key(label_key),
key=label_key, # type: ignore
value=file_annotation.label,
)
],
@ -157,7 +155,8 @@ def file_annotation_to_clip_annotation(
notes=notes,
tags=[
data.Tag(
term=get_term_from_key(label_key), value=file_annotation.label
key=label_key, # type: ignore
value=file_annotation.label,
)
],
sound_events=[

View File

@ -0,0 +1,287 @@
from collections.abc import Callable
from typing import Annotated, List, Literal, Sequence, Union
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
_conditions: Registry[SoundEventCondition] = Registry("condition")
class HasTagConfig(BaseConfig):
name: Literal["has_tag"] = "has_tag"
tag: data.Tag
@_conditions.register(HasTagConfig)
class HasTag:
def __init__(self, tag: data.Tag):
self.tag = tag
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> bool:
return self.tag in sound_event_annotation.tags
@classmethod
def from_config(cls, config: HasTagConfig):
return cls(tag=config.tag)
class HasAllTagsConfig(BaseConfig):
name: Literal["has_all_tags"] = "has_all_tags"
tags: List[data.Tag]
@_conditions.register(HasAllTagsConfig)
class HasAllTags:
def __init__(self, tags: List[data.Tag]):
if not tags:
raise ValueError("Need to specify at least one tag")
self.tags = set(tags)
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> bool:
return self.tags.issubset(sound_event_annotation.tags)
@classmethod
def from_config(cls, config: HasAllTagsConfig):
return cls(tags=config.tags)
class HasAnyTagConfig(BaseConfig):
name: Literal["has_any_tag"] = "has_any_tag"
tags: List[data.Tag]
@_conditions.register(HasAnyTagConfig)
class HasAnyTag:
def __init__(self, tags: List[data.Tag]):
if not tags:
raise ValueError("Need to specify at least one tag")
self.tags = set(tags)
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> bool:
return bool(self.tags.intersection(sound_event_annotation.tags))
@classmethod
def from_config(cls, config: HasAnyTagConfig):
return cls(tags=config.tags)
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
class DurationConfig(BaseConfig):
name: Literal["duration"] = "duration"
operator: Operator
seconds: float
def _build_comparator(
operator: Operator, value: float
) -> Callable[[float], bool]:
if operator == "gt":
return lambda x: x > value
if operator == "gte":
return lambda x: x >= value
if operator == "lt":
return lambda x: x < value
if operator == "lte":
return lambda x: x <= value
if operator == "eq":
return lambda x: x == value
raise ValueError(f"Invalid operator {operator}")
@_conditions.register(DurationConfig)
class Duration:
def __init__(self, operator: Operator, seconds: float):
self.operator = operator
self.seconds = seconds
self._comparator = _build_comparator(self.operator, self.seconds)
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> bool:
geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
return False
start_time, _, end_time, _ = compute_bounds(geometry)
duration = end_time - start_time
return self._comparator(duration)
@classmethod
def from_config(cls, config: DurationConfig):
return cls(operator=config.operator, seconds=config.seconds)
class FrequencyConfig(BaseConfig):
name: Literal["frequency"] = "frequency"
boundary: Literal["low", "high"]
operator: Operator
hertz: float
@_conditions.register(FrequencyConfig)
class Frequency:
def __init__(
self,
operator: Operator,
boundary: Literal["low", "high"],
hertz: float,
):
self.operator = operator
self.hertz = hertz
self.boundary = boundary
self._comparator = _build_comparator(self.operator, self.hertz)
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> bool:
geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
return False
# Automatically false if geometry does not have a frequency range
if isinstance(geometry, (data.TimeInterval, data.TimeStamp)):
return False
_, low_freq, _, high_freq = compute_bounds(geometry)
if self.boundary == "low":
return self._comparator(low_freq)
return self._comparator(high_freq)
@classmethod
def from_config(cls, config: FrequencyConfig):
return cls(
operator=config.operator,
boundary=config.boundary,
hertz=config.hertz,
)
class AllOfConfig(BaseConfig):
name: Literal["all_of"] = "all_of"
conditions: Sequence["SoundEventConditionConfig"]
@_conditions.register(AllOfConfig)
class AllOf:
def __init__(self, conditions: List[SoundEventCondition]):
self.conditions = conditions
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> bool:
return all(c(sound_event_annotation) for c in self.conditions)
@classmethod
def from_config(cls, config: AllOfConfig):
conditions = [
build_sound_event_condition(cond) for cond in config.conditions
]
return cls(conditions)
class AnyOfConfig(BaseConfig):
name: Literal["any_of"] = "any_of"
conditions: List["SoundEventConditionConfig"]
@_conditions.register(AnyOfConfig)
class AnyOf:
def __init__(self, conditions: List[SoundEventCondition]):
self.conditions = conditions
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> bool:
return any(c(sound_event_annotation) for c in self.conditions)
@classmethod
def from_config(cls, config: AnyOfConfig):
conditions = [
build_sound_event_condition(cond) for cond in config.conditions
]
return cls(conditions)
class NotConfig(BaseConfig):
name: Literal["not"] = "not"
condition: "SoundEventConditionConfig"
@_conditions.register(NotConfig)
class Not:
def __init__(self, condition: SoundEventCondition):
self.condition = condition
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> bool:
return not self.condition(sound_event_annotation)
@classmethod
def from_config(cls, config: NotConfig):
condition = build_sound_event_condition(config.condition)
return cls(condition)
SoundEventConditionConfig = Annotated[
Union[
HasTagConfig,
HasAllTagsConfig,
HasAnyTagConfig,
DurationConfig,
FrequencyConfig,
AllOfConfig,
AnyOfConfig,
NotConfig,
],
Field(discriminator="name"),
]
def build_sound_event_condition(
config: SoundEventConditionConfig,
) -> SoundEventCondition:
return _conditions.build(config)
def filter_clip_annotation(
clip_annotation: data.ClipAnnotation,
condition: SoundEventCondition,
) -> data.ClipAnnotation:
return clip_annotation.model_copy(
update=dict(
sound_events=[
sound_event
for sound_event in clip_annotation.sound_events
if condition(sound_event)
]
)
)

View File

@ -19,7 +19,7 @@ The core components are:
"""
from pathlib import Path
from typing import Annotated, List, Optional
from typing import List, Optional
from loguru import logger
from pydantic import Field
@ -31,6 +31,17 @@ from batdetect2.data.annotations import (
AnnotationFormats,
load_annotated_dataset,
)
from batdetect2.data.conditions import (
SoundEventConditionConfig,
build_sound_event_condition,
filter_clip_annotation,
)
from batdetect2.data.transforms import (
ApplyAll,
SoundEventTransformConfig,
build_sound_event_transform,
transform_clip_annotation,
)
from batdetect2.targets.terms import data_source
__all__ = [
@ -52,79 +63,68 @@ sources.
class DatasetConfig(BaseConfig):
"""Configuration model defining the structure of a BatDetect2 dataset.
This class is typically loaded from a YAML file and describes the components
of the dataset, including metadata and a list of data sources.
Attributes
----------
name : str
A descriptive name for the dataset (e.g., "UK_Bats_Project_2024").
description : str
A longer description of the dataset's contents, origin, purpose, etc.
sources : List[AnnotationFormats]
A list defining the different data sources contributing to this
dataset. Each item in the list must conform to one of the Pydantic
models defined in the `AnnotationFormats` type union. The specific
model used for each source is determined by the mandatory `format`
field within the source's configuration, allowing BatDetect2 to use the
correct parser for different annotation styles.
"""
"""Configuration model defining the structure of a BatDetect2 dataset."""
name: str
description: str
sources: List[
Annotated[AnnotationFormats, Field(..., discriminator="format")]
]
sources: List[AnnotationFormats]
sound_event_filter: Optional[SoundEventConditionConfig] = None
sound_event_transforms: List[SoundEventTransformConfig] = Field(
default_factory=list
)
def load_dataset(
dataset: DatasetConfig,
config: DatasetConfig,
base_dir: Optional[Path] = None,
) -> Dataset:
"""Load all clip annotations from the sources defined in a DatasetConfig.
Iterates through each data source specified in the `dataset_config`,
delegates the loading and parsing of that source's annotations to
`batdetect2.data.annotations.load_annotated_dataset` (which handles
different data formats), and aggregates all resulting `ClipAnnotation`
objects into a single flat list.
Parameters
----------
dataset_config : DatasetConfig
The configuration object describing the dataset and its sources.
base_dir : Path, optional
An optional base directory path. If provided, relative paths for
metadata files or data directories within the `dataset_config`'s
sources might be resolved relative to this directory. Defaults to None.
Returns
-------
Dataset (List[data.ClipAnnotation])
A flat list containing all loaded `ClipAnnotation` metadata objects
from all specified sources.
Raises
------
Exception
Can raise various exceptions during the delegated loading process
(`load_annotated_dataset`) if files are not found, cannot be parsed
according to the specified format, or other I/O errors occur.
"""
"""Load all clip annotations from the sources defined in a DatasetConfig."""
clip_annotations = []
for source in dataset.sources:
condition = (
build_sound_event_condition(config.sound_event_filter)
if config.sound_event_filter is not None
else None
)
transform = (
ApplyAll(
[
build_sound_event_transform(step)
for step in config.sound_event_transforms
]
)
if config.sound_event_transforms
else None
)
for source in config.sources:
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
logger.debug(
"Loaded {num_examples} from dataset source '{source_name}'",
num_examples=len(annotated_source.clip_annotations),
source_name=source.name,
)
clip_annotations.extend(
insert_source_tag(clip_annotation, source)
for clip_annotation in annotated_source.clip_annotations
for clip_annotation in annotated_source.clip_annotations:
clip_annotation = insert_source_tag(clip_annotation, source)
if condition is not None:
clip_annotation = filter_clip_annotation(
clip_annotation,
condition,
)
if transform is not None:
clip_annotation = transform_clip_annotation(
clip_annotation,
transform,
)
clip_annotations.append(clip_annotation)
return clip_annotations
@ -161,7 +161,6 @@ def insert_source_tag(
)
# TODO: add documentation
def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
return load_config(path=path, schema=DatasetConfig, field=field)

View File

@ -10,16 +10,8 @@ from batdetect2.typing.targets import TargetProtocol
def iterate_over_sound_events(
dataset: Dataset,
targets: TargetProtocol,
apply_filter: bool = True,
apply_transform: bool = True,
exclude_generic: bool = True,
) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]:
"""Iterate over sound events in a dataset, applying filtering and
transformations.
This generator function processes sound event annotations from a given
dataset, allowing for optional filtering, transformation, and exclusion of
unclassifiable (generic) events based on the provided target definitions.
"""Iterate over sound events in a dataset.
Parameters
----------
@ -29,18 +21,6 @@ def iterate_over_sound_events(
targets : TargetProtocol
An object implementing the `TargetProtocol`, which provides methods
for filtering, transforming, and encoding sound events.
apply_filter : bool, optional
If True, sound events will be filtered using `targets.filter()`.
Only events for which `targets.filter()` returns True will be yielded.
Defaults to True.
apply_transform : bool, optional
If True, sound events will be transformed using `targets.transform()`
before being yielded. Defaults to True.
exclude_generic : bool, optional
If True, sound events that result in a `None` class name after
`targets.encode()` will be excluded. This is typically used to
filter out events that cannot be mapped to a specific target class.
Defaults to True.
Yields
------
@ -63,17 +43,9 @@ def iterate_over_sound_events(
"""
for clip_annotation in dataset:
for sound_event_annotation in clip_annotation.sound_events:
if apply_filter:
if not targets.filter(sound_event_annotation):
continue
if apply_transform:
sound_event_annotation = targets.transform(
sound_event_annotation
)
class_name = targets.encode_class(sound_event_annotation)
if class_name is None and exclude_generic:
continue
yield class_name, sound_event_annotation

View File

@ -0,0 +1,250 @@
from collections.abc import Callable
from typing import Annotated, Dict, List, Literal, Optional, Union
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.data.conditions import (
SoundEventCondition,
SoundEventConditionConfig,
build_sound_event_condition,
)
SoundEventTransform = Callable[
[data.SoundEventAnnotation],
data.SoundEventAnnotation,
]
_transforms: Registry[SoundEventTransform] = Registry("transform")
class SetFrequencyBoundConfig(BaseConfig):
name: Literal["set_frequency"] = "set_frequency"
boundary: Literal["low", "high"] = "low"
hertz: float
@_transforms.register(SetFrequencyBoundConfig)
class SetFrequencyBound:
def __init__(self, hertz: float, boundary: Literal["low", "high"] = "low"):
self.hertz = hertz
self.boundary = boundary
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
sound_event = sound_event_annotation.sound_event
geometry = sound_event.geometry
if geometry is None:
return sound_event_annotation
if not isinstance(geometry, data.BoundingBox):
return sound_event_annotation
start_time, low_freq, end_time, high_freq = geometry.coordinates
if self.boundary == "low":
low_freq = self.hertz
high_freq = max(high_freq, low_freq)
elif self.boundary == "high":
high_freq = self.hertz
low_freq = min(high_freq, low_freq)
geometry = data.BoundingBox(
coordinates=[start_time, low_freq, end_time, high_freq],
)
sound_event = sound_event.model_copy(update=dict(geometry=geometry))
return sound_event_annotation.model_copy(
update=dict(sound_event=sound_event)
)
@classmethod
def from_config(cls, config: SetFrequencyBoundConfig):
return cls(hertz=config.hertz, boundary=config.boundary)
class ApplyIfConfig(BaseConfig):
name: Literal["apply_if"] = "apply_if"
transform: "SoundEventTransformConfig"
condition: SoundEventConditionConfig
@_transforms.register(ApplyIfConfig)
class ApplyIf:
def __init__(
self,
condition: SoundEventCondition,
transform: SoundEventTransform,
):
self.condition = condition
self.transform = transform
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
if not self.condition(sound_event_annotation):
return sound_event_annotation
return self.transform(sound_event_annotation)
@classmethod
def from_config(cls, config: ApplyIfConfig):
transform = build_sound_event_transform(config.transform)
condition = build_sound_event_condition(config.condition)
return cls(condition=condition, transform=transform)
class ReplaceTagConfig(BaseConfig):
name: Literal["replace_tag"] = "replace_tag"
original: data.Tag
replacement: data.Tag
@_transforms.register(ReplaceTagConfig)
class ReplaceTag:
def __init__(
self,
original: data.Tag,
replacement: data.Tag,
):
self.original = original
self.replacement = replacement
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
tags = []
for tag in sound_event_annotation.tags:
if tag == self.original:
tags.append(self.replacement)
else:
tags.append(tag)
return sound_event_annotation.model_copy(update=dict(tags=tags))
@classmethod
def from_config(cls, config: ReplaceTagConfig):
return cls(original=config.original, replacement=config.replacement)
class MapTagValueConfig(BaseConfig):
name: Literal["map_tag_value"] = "map_tag_value"
tag_key: str
value_mapping: Dict[str, str]
target_key: Optional[str] = None
@_transforms.register(MapTagValueConfig)
class MapTagValue:
def __init__(
self,
tag_key: str,
value_mapping: Dict[str, str],
target_key: Optional[str] = None,
):
self.tag_key = tag_key
self.value_mapping = value_mapping
self.target_key = target_key
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
tags = []
for tag in sound_event_annotation.tags:
if tag.key != self.tag_key:
tags.append(tag)
continue
value = self.value_mapping.get(tag.value)
if value is None:
tags.append(tag)
continue
if self.target_key is None:
tags.append(tag.model_copy(update=dict(value=value)))
else:
tags.append(
data.Tag(
key=self.target_key, # type: ignore
value=value,
)
)
return sound_event_annotation.model_copy(update=dict(tags=tags))
@classmethod
def from_config(cls, config: MapTagValueConfig):
return cls(
tag_key=config.tag_key,
value_mapping=config.value_mapping,
target_key=config.target_key,
)
class ApplyAllConfig(BaseConfig):
name: Literal["apply_all"] = "apply_all"
steps: List["SoundEventTransformConfig"] = Field(default_factory=list)
@_transforms.register(ApplyAllConfig)
class ApplyAll:
def __init__(self, steps: List[SoundEventTransform]):
self.steps = steps
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
for step in self.steps:
sound_event_annotation = step(sound_event_annotation)
return sound_event_annotation
@classmethod
def from_config(cls, config: ApplyAllConfig):
steps = [build_sound_event_transform(step) for step in config.steps]
return cls(steps)
SoundEventTransformConfig = Annotated[
Union[
SetFrequencyBoundConfig,
ReplaceTagConfig,
MapTagValueConfig,
ApplyIfConfig,
ApplyAllConfig,
],
Field(discriminator="name"),
]
def build_sound_event_transform(
config: SoundEventTransformConfig,
) -> SoundEventTransform:
return _transforms.build(config)
def transform_clip_annotation(
clip_annotation: data.ClipAnnotation,
transform: SoundEventTransform,
) -> data.ClipAnnotation:
return clip_annotation.model_copy(
update=dict(
sound_events=[
transform(sound_event)
for sound_event in clip_annotation.sound_events
]
)
)

View File

@ -0,0 +1,62 @@
from typing import List
import pandas as pd
from soundevent.geometry import compute_bounds
from batdetect2.typing.evaluate import MatchEvaluation
def extract_matches_dataframe(matches: List[MatchEvaluation]) -> pd.DataFrame:
data = []
for match in matches:
gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None
sound_event_annotation = match.sound_event_annotation
if sound_event_annotation is not None:
geometry = sound_event_annotation.sound_event.geometry
assert geometry is not None
gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
compute_bounds(geometry)
)
if match.pred_geometry is not None:
pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
compute_bounds(match.pred_geometry)
)
data.append(
{
("recording", "uuid"): match.clip.recording.uuid,
("clip", "uuid"): match.clip.uuid,
("clip", "start_time"): match.clip.start_time,
("clip", "end_time"): match.clip.end_time,
("gt", "uuid"): match.sound_event_annotation.uuid
if match.sound_event_annotation is not None
else None,
("gt", "class"): match.gt_class,
("gt", "det"): match.gt_det,
("gt", "start_time"): gt_start_time,
("gt", "end_time"): gt_end_time,
("gt", "low_freq"): gt_low_freq,
("gt", "high_freq"): gt_high_freq,
("pred", "score"): match.pred_score,
("pred", "class"): match.pred_class,
("pred", "class_score"): match.pred_class_score,
("pred", "start_time"): pred_start_time,
("pred", "end_time"): pred_end_time,
("pred", "low_freq"): pred_low_freq,
("pred", "high_freq"): pred_high_freq,
("match", "affinity"): match.affinity,
**{
("pred_class_score", key): value
for key, value in match.pred_class_scores.items()
},
}
)
df = pd.DataFrame(data)
df.columns = pd.MultiIndex.from_tuples(df.columns) # type: ignore
return df

View File

@ -0,0 +1,100 @@
from typing import List, Optional, Tuple
import pandas as pd
from soundevent import data
from batdetect2.evaluate.dataframe import extract_matches_dataframe
from batdetect2.evaluate.match import match_all_predictions
from batdetect2.evaluate.metrics import (
ClassificationAccuracy,
ClassificationMeanAveragePrecision,
DetectionAveragePrecision,
)
from batdetect2.models import Model
from batdetect2.plotting.clips import build_audio_loader
from batdetect2.postprocess import get_raw_predictions
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.config import FullTrainingConfig
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.train import build_val_loader
def evaluate(
model: Model,
test_annotations: List[data.ClipAnnotation],
config: Optional[FullTrainingConfig] = None,
num_workers: Optional[int] = None,
) -> Tuple[pd.DataFrame, dict]:
config = config or FullTrainingConfig()
audio_loader = build_audio_loader(config.preprocess.audio)
preprocessor = build_preprocessor(config.preprocess)
targets = build_targets(config.targets)
labeller = build_clip_labeler(
targets,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
config=config.train.labels,
)
loader = build_val_loader(
test_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=preprocessor,
config=config.train,
num_workers=num_workers,
)
dataset: ValidationDataset = loader.dataset # type: ignore
clip_annotations = []
predictions = []
for batch in loader:
outputs = model.detector(batch.spec)
clip_annotations = [
dataset.clip_annotations[int(example_idx)]
for example_idx in batch.idx
]
predictions = get_raw_predictions(
outputs,
clips=[
clip_annotation.clip for clip_annotation in clip_annotations
],
targets=targets,
postprocessor=model.postprocessor,
)
clip_annotations.extend(clip_annotations)
predictions.extend(predictions)
matches = match_all_predictions(
clip_annotations,
predictions,
targets=targets,
config=config.evaluation.match,
)
df = extract_matches_dataframe(matches)
metrics = [
DetectionAveragePrecision(),
ClassificationMeanAveragePrecision(class_names=targets.class_names),
ClassificationAccuracy(class_names=targets.class_names),
]
results = {
name: value
for metric in metrics
for name, value in metric(matches).items()
}
return df, results

View File

@ -1,6 +1,5 @@
from collections.abc import Callable, Iterable, Mapping
from dataclasses import dataclass, field
from functools import partial
from typing import List, Literal, Optional, Protocol, Tuple
import numpy as np
@ -9,7 +8,6 @@ from soundevent import data
from soundevent.evaluation import compute_affinity
from soundevent.evaluation import match_geometries as optimal_match
from soundevent.geometry import compute_bounds
from torch.multiprocessing import Pool
from batdetect2.configs import BaseConfig
from batdetect2.typing import (
@ -284,7 +282,7 @@ def match_sound_events_and_raw_predictions(
config = config or MatchConfig()
target_sound_events = [
targets.transform(sound_event_annotation)
sound_event_annotation
for sound_event_annotation in clip_annotation.sound_events
if targets.filter(sound_event_annotation)
and sound_event_annotation.sound_event.geometry is not None
@ -430,17 +428,19 @@ def match_all_predictions(
config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions...")
with Pool() as p:
all_matches = p.starmap(
partial(
match_sound_events_and_raw_predictions,
return [
match
for clip_annotation, raw_predictions in zip(
clip_annotations,
predictions,
)
for match in match_sound_events_and_raw_predictions(
clip_annotation,
raw_predictions,
targets=targets,
config=config,
),
zip(clip_annotations, predictions),
)
return [match for matches in all_matches for match in matches]
]
@dataclass

View File

@ -29,7 +29,6 @@ provided here.
from typing import List, Optional
import torch
from lightning import LightningModule
from pydantic import Field
from soundevent.data import PathLike
@ -68,7 +67,10 @@ from batdetect2.postprocess import PostprocessConfig, build_postprocessor
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.typing.models import DetectionModel
from batdetect2.typing.postprocess import DetectionsArray, PostprocessorProtocol
from batdetect2.typing.postprocess import (
DetectionsTensor,
PostprocessorProtocol,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
@ -102,7 +104,16 @@ __all__ = [
]
class Model(LightningModule):
class ModelConfig(BaseConfig):
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
class Model(torch.nn.Module):
detector: DetectionModel
preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol
@ -114,43 +125,39 @@ class Model(LightningModule):
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
targets: TargetProtocol,
config: ModelConfig,
):
super().__init__()
self.detector = detector
self.preprocessor = preprocessor
self.postprocessor = postprocessor
self.targets = targets
self.save_hyperparameters()
self.config = config
def forward(self, wav: torch.Tensor) -> List[DetectionsArray]:
def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
spec = self.preprocessor(wav)
outputs = self.detector(spec)
return self.postprocessor(outputs)
class ModelConfig(BaseConfig):
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
def build_model(config: Optional[ModelConfig] = None):
config = config or ModelConfig()
targets = build_targets(config=config.targets)
preprocessor = build_preprocessor(config=config.preprocess)
postprocessor = build_postprocessor(
preprocessor=preprocessor,
config=config.postprocess,
)
detector = build_detector(
num_classes=len(targets.class_names),
config=config.model,
)
return Model(
config=config,
detector=detector,
postprocessor=postprocessor,
preprocessor=preprocessor,

View File

@ -56,7 +56,7 @@ __all__ = [
class SelfAttentionConfig(BaseConfig):
block_type: Literal["SelfAttention"] = "SelfAttention"
name: Literal["SelfAttention"] = "SelfAttention"
attention_channels: int
temperature: float = 1
@ -178,7 +178,7 @@ class SelfAttention(nn.Module):
class ConvConfig(BaseConfig):
"""Configuration for a basic ConvBlock."""
block_type: Literal["ConvBlock"] = "ConvBlock"
name: Literal["ConvBlock"] = "ConvBlock"
"""Discriminator field indicating the block type."""
out_channels: int
@ -300,7 +300,7 @@ class VerticalConv(nn.Module):
class FreqCoordConvDownConfig(BaseConfig):
"""Configuration for a FreqCoordConvDownBlock."""
block_type: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
name: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
"""Discriminator field indicating the block type."""
out_channels: int
@ -390,7 +390,7 @@ class FreqCoordConvDownBlock(nn.Module):
class StandardConvDownConfig(BaseConfig):
"""Configuration for a StandardConvDownBlock."""
block_type: Literal["StandardConvDown"] = "StandardConvDown"
name: Literal["StandardConvDown"] = "StandardConvDown"
"""Discriminator field indicating the block type."""
out_channels: int
@ -460,7 +460,7 @@ class StandardConvDownBlock(nn.Module):
class FreqCoordConvUpConfig(BaseConfig):
"""Configuration for a FreqCoordConvUpBlock."""
block_type: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
name: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
"""Discriminator field indicating the block type."""
out_channels: int
@ -569,7 +569,7 @@ class FreqCoordConvUpBlock(nn.Module):
class StandardConvUpConfig(BaseConfig):
"""Configuration for a StandardConvUpBlock."""
block_type: Literal["StandardConvUp"] = "StandardConvUp"
name: Literal["StandardConvUp"] = "StandardConvUp"
"""Discriminator field indicating the block type."""
out_channels: int
@ -664,13 +664,13 @@ LayerConfig = Annotated[
SelfAttentionConfig,
"LayerGroupConfig",
],
Field(discriminator="block_type"),
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configuration models."""
class LayerGroupConfig(BaseConfig):
block_type: Literal["LayerGroup"] = "LayerGroup"
name: Literal["LayerGroup"] = "LayerGroup"
layers: List[LayerConfig]
@ -686,7 +686,7 @@ def build_layer_from_config(
parameters derived from the config and the current pipeline state
(`input_height`, `in_channels`).
It uses the `block_type` field within the `config` object to determine
It uses the `name` field within the `config` object to determine
which block class to instantiate.
Parameters
@ -698,7 +698,7 @@ def build_layer_from_config(
config : LayerConfig
A Pydantic configuration object for the desired block (e.g., an
instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
by its `block_type` field.
by its `name` field.
Returns
-------
@ -711,11 +711,11 @@ def build_layer_from_config(
Raises
------
NotImplementedError
If the `config.block_type` does not correspond to a known block type.
If the `config.name` does not correspond to a known block type.
ValueError
If parameters derived from the config are invalid for the block.
"""
if config.block_type == "ConvBlock":
if config.name == "ConvBlock":
return (
ConvBlock(
in_channels=in_channels,
@ -727,7 +727,7 @@ def build_layer_from_config(
input_height,
)
if config.block_type == "FreqCoordConvDown":
if config.name == "FreqCoordConvDown":
return (
FreqCoordConvDownBlock(
in_channels=in_channels,
@ -740,7 +740,7 @@ def build_layer_from_config(
input_height // 2,
)
if config.block_type == "StandardConvDown":
if config.name == "StandardConvDown":
return (
StandardConvDownBlock(
in_channels=in_channels,
@ -752,7 +752,7 @@ def build_layer_from_config(
input_height // 2,
)
if config.block_type == "FreqCoordConvUp":
if config.name == "FreqCoordConvUp":
return (
FreqCoordConvUpBlock(
in_channels=in_channels,
@ -765,7 +765,7 @@ def build_layer_from_config(
input_height * 2,
)
if config.block_type == "StandardConvUp":
if config.name == "StandardConvUp":
return (
StandardConvUpBlock(
in_channels=in_channels,
@ -777,7 +777,7 @@ def build_layer_from_config(
input_height * 2,
)
if config.block_type == "SelfAttention":
if config.name == "SelfAttention":
return (
SelfAttention(
in_channels=in_channels,
@ -788,7 +788,7 @@ def build_layer_from_config(
input_height,
)
if config.block_type == "LayerGroup":
if config.name == "LayerGroup":
current_channels = in_channels
current_height = input_height
@ -804,4 +804,4 @@ def build_layer_from_config(
return nn.Sequential(*blocks), current_channels, current_height
raise NotImplementedError(f"Unknown block type {config.block_type}")
raise NotImplementedError(f"Unknown block type {config.name}")

View File

@ -128,7 +128,7 @@ class Bottleneck(nn.Module):
BottleneckLayerConfig = Annotated[
Union[SelfAttentionConfig,],
Field(discriminator="block_type"),
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""

View File

@ -47,7 +47,7 @@ DecoderLayerConfig = Annotated[
StandardConvUpConfig,
LayerGroupConfig,
],
Field(discriminator="block_type"),
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""
@ -63,7 +63,7 @@ class DecoderConfig(BaseConfig):
layers : List[DecoderLayerConfig]
An ordered list of configuration objects, each defining one layer or
block in the decoder sequence. Each item must be a valid block
config including a `block_type` field and necessary parameters like
config including a `name` field and necessary parameters like
`out_channels`. Input channels for each layer are inferred sequentially.
The list must contain at least one layer.
"""
@ -249,9 +249,9 @@ def build_decoder(
------
ValueError
If `in_channels` or `input_height` are not positive, or if the layer
configuration is invalid (e.g., empty list, unknown `block_type`).
configuration is invalid (e.g., empty list, unknown `name`).
NotImplementedError
If `build_layer_from_config` encounters an unknown `block_type`.
If `build_layer_from_config` encounters an unknown `name`.
"""
config = config or DEFAULT_DECODER_CONFIG

View File

@ -49,7 +49,7 @@ EncoderLayerConfig = Annotated[
StandardConvDownConfig,
LayerGroupConfig,
],
Field(discriminator="block_type"),
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Encoder."""
@ -66,7 +66,7 @@ class EncoderConfig(BaseConfig):
An ordered list of configuration objects, each defining one layer or
block in the encoder sequence. Each item must be a valid block config
(e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
`StandardConvDownConfig`) including a `block_type` field and necessary
`StandardConvDownConfig`) including a `name` field and necessary
parameters like `out_channels`. Input channels for each layer are
inferred sequentially. The list must contain at least one layer.
"""
@ -287,9 +287,9 @@ def build_encoder(
------
ValueError
If `in_channels` or `input_height` are not positive, or if the layer
configuration is invalid (e.g., empty list, unknown `block_type`).
configuration is invalid (e.g., empty list, unknown `name`).
NotImplementedError
If `build_layer_from_config` encounters an unknown `block_type`.
If `build_layer_from_config` encounters an unknown `name`.
"""
if in_channels <= 0 or input_height <= 0:
raise ValueError("in_channels and input_height must be positive.")

View File

@ -1,6 +1,14 @@
"""Computes spectrograms from audio waveforms with configurable parameters."""
from typing import Annotated, Callable, List, Literal, Optional, Union
from typing import (
Annotated,
Callable,
List,
Literal,
Optional,
Sequence,
Union,
)
import numpy as np
import torch
@ -306,7 +314,7 @@ class SpectrogramConfig(BaseConfig):
stft: STFTConfig = Field(default_factory=STFTConfig)
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
size: ResizeConfig = Field(default_factory=ResizeConfig)
transforms: List[SpectrogramTransform] = Field(
transforms: Sequence[SpectrogramTransform] = Field(
default_factory=lambda: [
PcenConfig(),
SpectralMeanSubstractionConfig(),

View File

@ -1,53 +1,26 @@
"""Main entry point for the BatDetect2 Target Definition subsystem.
This package (`batdetect2.targets`) provides the tools and configurations
necessary to define precisely what the BatDetect2 model should learn to detect,
classify, and localize from audio data. It involves several conceptual steps,
managed through configuration files and culminating in an executable pipeline:
1. **Terms (`.terms`)**: Defining vocabulary for annotation tags.
2. **Filtering (`.filtering`)**: Selecting relevant sound event annotations.
3. **Transformation (`.transform`)**: Modifying tags (standardization,
derivation).
4. **ROI Mapping (`.roi`)**: Defining how annotation geometry (ROIs) maps to
target position and size representations, and back.
5. **Class Definition (`.classes`)**: Mapping tags to target class names
(encoding) and mapping predicted names back to tags (decoding).
This module exposes the key components for users to configure and utilize this
target definition pipeline, primarily through the `TargetConfig` data structure
and the `Targets` class (implementing `TargetProtocol`), which encapsulates the
configured processing steps. The main way to create a functional `Targets`
object is via the `build_targets` or `load_targets` functions.
"""
"""BatDetect2 Target Definition system."""
from collections import Counter
from typing import Iterable, List, Optional, Tuple
from loguru import logger
from pydantic import Field
from pydantic import Field, field_validator
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.conditions import (
SoundEventCondition,
build_sound_event_condition,
)
from batdetect2.targets.classes import (
ClassesConfig,
DEFAULT_CLASSES,
DEFAULT_GENERIC_CLASS,
SoundEventDecoder,
SoundEventEncoder,
TargetClass,
build_generic_class_tags,
TargetClassConfig,
build_sound_event_decoder,
build_sound_event_encoder,
get_class_names_from_config,
load_classes_config,
load_decoder_from_config,
load_encoder_from_config,
)
from batdetect2.targets.filtering import (
FilterConfig,
FilterRule,
SoundEventFilter,
build_sound_event_filter,
load_filter_config,
load_filter_from_config,
)
from batdetect2.targets.rois import (
AnchorBBoxMapperConfig,
@ -55,114 +28,53 @@ from batdetect2.targets.rois import (
ROITargetMapper,
build_roi_mapper,
)
from batdetect2.targets.terms import (
TagInfo,
TermInfo,
TermRegistry,
call_type,
default_term_registry,
get_tag_from_info,
get_term_from_key,
individual,
register_term,
)
from batdetect2.targets.transform import (
DerivationRegistry,
DeriveTagRule,
MapValueRule,
ReplaceRule,
SoundEventTransformation,
TransformConfig,
build_transformation_from_config,
default_derivation_registry,
get_derivation,
load_transformation_config,
load_transformation_from_config,
register_derivation,
)
from batdetect2.targets.terms import call_type, individual
from batdetect2.typing.targets import Position, Size, TargetProtocol
__all__ = [
"ClassesConfig",
"DEFAULT_TARGET_CONFIG",
"DeriveTagRule",
"FilterConfig",
"FilterRule",
"MapValueRule",
"AnchorBBoxMapperConfig",
"ROITargetMapper",
"ReplaceRule",
"SoundEventDecoder",
"SoundEventEncoder",
"SoundEventFilter",
"SoundEventTransformation",
"TagInfo",
"TargetClass",
"TargetClassConfig",
"TargetConfig",
"Targets",
"TermInfo",
"TransformConfig",
"build_generic_class_tags",
"build_roi_mapper",
"build_sound_event_decoder",
"build_sound_event_encoder",
"build_sound_event_filter",
"build_transformation_from_config",
"call_type",
"get_class_names_from_config",
"get_derivation",
"get_tag_from_info",
"get_term_from_key",
"individual",
"load_classes_config",
"load_decoder_from_config",
"load_encoder_from_config",
"load_filter_config",
"load_filter_from_config",
"load_target_config",
"load_transformation_config",
"load_transformation_from_config",
"register_derivation",
"register_term",
]
class TargetConfig(BaseConfig):
"""Unified configuration for the entire target definition pipeline.
detection_target: TargetClassConfig = Field(default=DEFAULT_GENERIC_CLASS)
This model aggregates the configurations for semantic processing (filtering,
transformation, class definition) and geometric processing (ROI mapping).
It serves as the primary input for building a complete `Targets` object
via `build_targets` or `load_targets`.
Attributes
----------
filtering : FilterConfig, optional
Configuration for filtering sound event annotations based on tags.
If None or omitted, no filtering is applied.
transforms : TransformConfig, optional
Configuration for transforming annotation tags
(mapping, derivation, etc.). If None or omitted, no tag transformations
are applied.
classes : ClassesConfig
Configuration defining the specific target classes, their tag matching
rules for encoding, their representative tags for decoding
(`output_tags`), and the definition of the generic class tags.
This section is mandatory.
roi : ROIConfig, optional
Configuration defining how geometric ROIs (e.g., bounding boxes) are
mapped to target representations (reference point, scaled size).
Controls `position`, `time_scale`, `frequency_scale`. If None or
omitted, default ROI mapping settings are used.
"""
filtering: FilterConfig = Field(default_factory=FilterConfig)
transforms: TransformConfig = Field(default_factory=TransformConfig)
classes: ClassesConfig = Field(
default_factory=lambda: DEFAULT_CLASSES_CONFIG
classification_targets: List[TargetClassConfig] = Field(
default_factory=lambda: DEFAULT_CLASSES
)
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
@field_validator("classification_targets")
def check_unique_class_names(cls, v: List[TargetClassConfig]):
"""Ensure all defined class names are unique."""
names = [c.name for c in v]
if len(names) != len(set(names)):
name_counts = Counter(names)
duplicates = [
name for name, count in name_counts.items() if count > 1
]
raise ValueError(
"Class names must be unique. Found duplicates: "
f"{', '.join(duplicates)}"
)
return v
def load_target_config(
path: data.PathLike,
@ -238,8 +150,7 @@ class Targets(TargetProtocol):
roi_mapper: ROITargetMapper,
class_names: list[str],
generic_class_tags: List[data.Tag],
filter_fn: Optional[SoundEventFilter] = None,
transform_fn: Optional[SoundEventTransformation] = None,
filter_fn: Optional[SoundEventCondition] = None,
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
):
"""Initialize the Targets object.
@ -272,7 +183,6 @@ class Targets(TargetProtocol):
self._filter_fn = filter_fn
self._encode_fn = encode_fn
self._decode_fn = decode_fn
self._transform_fn = transform_fn
self._roi_mapper_overrides = roi_mapper_overrides or {}
for class_name in self._roi_mapper_overrides:
@ -344,27 +254,6 @@ class Targets(TargetProtocol):
"""
return self._decode_fn(class_label)
def transform(
self, sound_event: data.SoundEventAnnotation
) -> data.SoundEventAnnotation:
"""Apply the configured tag transformations to an annotation.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation whose tags should be transformed.
Returns
-------
data.SoundEventAnnotation
A new annotation object with the transformed tags. If no
transformations were configured, the original annotation object is
returned.
"""
if self._transform_fn:
return self._transform_fn(sound_event)
return sound_event
def encode_roi(
self, sound_event: data.SoundEventAnnotation
) -> tuple[Position, Size]:
@ -430,113 +319,14 @@ class Targets(TargetProtocol):
return self._roi_mapper.decode(position, size)
DEFAULT_CLASSES = [
TargetClass(
tags=[TagInfo(value="Myotis mystacinus")],
name="myomys",
),
TargetClass(
tags=[TagInfo(value="Myotis alcathoe")],
name="myoalc",
),
TargetClass(
tags=[TagInfo(value="Eptesicus serotinus")],
name="eptser",
),
TargetClass(
tags=[TagInfo(value="Pipistrellus nathusii")],
name="pipnat",
),
TargetClass(
tags=[TagInfo(value="Barbastellus barbastellus")],
name="barbar",
),
TargetClass(
tags=[TagInfo(value="Myotis nattereri")],
name="myonat",
),
TargetClass(
tags=[TagInfo(value="Myotis daubentonii")],
name="myodau",
),
TargetClass(
tags=[TagInfo(value="Myotis brandtii")],
name="myobra",
),
TargetClass(
tags=[TagInfo(value="Pipistrellus pipistrellus")],
name="pippip",
),
TargetClass(
tags=[TagInfo(value="Myotis bechsteinii")],
name="myobec",
),
TargetClass(
tags=[TagInfo(value="Pipistrellus pygmaeus")],
name="pippyg",
),
TargetClass(
tags=[TagInfo(value="Rhinolophus hipposideros")],
name="rhihip",
),
TargetClass(
tags=[TagInfo(value="Nyctalus leisleri")],
name="nyclei",
roi=AnchorBBoxMapperConfig(anchor="top-left"),
),
TargetClass(
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
name="rhifer",
roi=AnchorBBoxMapperConfig(anchor="top-left"),
),
TargetClass(
tags=[TagInfo(value="Plecotus auritus")],
name="pleaur",
),
TargetClass(
tags=[TagInfo(value="Nyctalus noctula")],
name="nycnoc",
),
TargetClass(
tags=[TagInfo(value="Plecotus austriacus")],
name="pleaus",
),
]
DEFAULT_CLASSES_CONFIG: ClassesConfig = ClassesConfig(
classes=DEFAULT_CLASSES,
generic_class=[TagInfo(value="Bat")],
)
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
filtering=FilterConfig(
rules=[
FilterRule(
match_type="all",
tags=[TagInfo(key="event", value="Echolocation")],
),
FilterRule(
match_type="exclude",
tags=[
TagInfo(key="event", value="Feeding"),
TagInfo(key="event", value="Unknown"),
TagInfo(key="event", value="Not Bat"),
],
),
]
),
classes=DEFAULT_CLASSES_CONFIG,
classification_targets=DEFAULT_CLASSES,
detection_target=DEFAULT_GENERIC_CLASS,
roi=AnchorBBoxMapperConfig(),
)
def build_targets(
config: Optional[TargetConfig] = None,
term_registry: TermRegistry = default_term_registry,
derivation_registry: DerivationRegistry = default_derivation_registry,
) -> Targets:
def build_targets(config: Optional[TargetConfig] = None) -> Targets:
"""Build a Targets object from a loaded TargetConfig.
This factory function takes the unified `TargetConfig` and constructs all
@ -550,13 +340,6 @@ def build_targets(
----------
config : TargetConfig
The loaded and validated unified target configuration object.
term_registry : TermRegistry, optional
The TermRegistry instance to use for resolving term keys. Defaults
to the global `batdetect2.targets.terms.term_registry`.
derivation_registry : DerivationRegistry, optional
The DerivationRegistry instance to use for resolving derivation
function names. Defaults to the global
`batdetect2.targets.transform.derivation_registry`.
Returns
-------
@ -577,40 +360,18 @@ def build_targets(
lambda: config.to_yaml_string(),
)
filter_fn = (
build_sound_event_filter(
config.filtering,
term_registry=term_registry,
)
if config.filtering
else None
)
encode_fn = build_sound_event_encoder(
config.classes,
term_registry=term_registry,
)
decode_fn = build_sound_event_decoder(
config.classes,
term_registry=term_registry,
)
transform_fn = (
build_transformation_from_config(
config.transforms,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
if config.transforms
else None
)
filter_fn = build_sound_event_condition(config.detection_target.match_if)
encode_fn = build_sound_event_encoder(config.classification_targets)
decode_fn = build_sound_event_decoder(config.classification_targets)
roi_mapper = build_roi_mapper(config.roi)
class_names = get_class_names_from_config(config.classes)
generic_class_tags = build_generic_class_tags(
config.classes,
term_registry=term_registry,
)
class_names = get_class_names_from_config(config.classification_targets)
generic_class_tags = config.detection_target.assign_tags
roi_overrides = {
class_config.name: build_roi_mapper(class_config.roi)
for class_config in config.classes.classes
for class_config in config.classification_targets
if class_config.roi is not None
}
@ -621,7 +382,6 @@ def build_targets(
class_names=class_names,
roi_mapper=roi_mapper,
generic_class_tags=generic_class_tags,
transform_fn=transform_fn,
roi_mapper_overrides=roi_overrides,
)
@ -629,8 +389,6 @@ def build_targets(
def load_targets(
config_path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = default_term_registry,
derivation_registry: DerivationRegistry = default_derivation_registry,
) -> Targets:
"""Load a Targets object directly from a configuration file.
@ -645,11 +403,6 @@ def load_targets(
field : str, optional
Dot-separated path to a nested section within the file containing
the target configuration. If None, the entire file content is used.
term_registry : TermRegistry, optional
The TermRegistry instance to use. Defaults to the global default.
derivation_registry : DerivationRegistry, optional
The DerivationRegistry instance to use. Defaults to the global
default.
Returns
-------
@ -670,11 +423,7 @@ def load_targets(
config_path,
field=field,
)
return build_targets(
config,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
return build_targets(config)
def iterate_encoded_sound_events(
@ -690,8 +439,6 @@ def iterate_encoded_sound_events(
if geometry is None:
continue
sound_event = targets.transform(sound_event)
class_name = targets.encode_class(sound_event)
position, size = targets.encode_roi(sound_event)

View File

@ -1,253 +1,172 @@
from collections import Counter
from functools import partial
from typing import Callable, Dict, List, Literal, Optional, Set, Tuple
from typing import Dict, List, Optional
from pydantic import Field, field_validator
from pydantic import Field, PrivateAttr, computed_field, model_validator
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.rois import ROIMapperConfig
from batdetect2.targets.terms import (
GENERIC_CLASS_KEY,
TagInfo,
TermRegistry,
default_term_registry,
get_tag_from_info,
from batdetect2.configs import BaseConfig
from batdetect2.data.conditions import (
AllOfConfig,
HasAllTagsConfig,
HasAnyTagConfig,
HasTagConfig,
NotConfig,
SoundEventCondition,
SoundEventConditionConfig,
build_sound_event_condition,
)
from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
__all__ = [
"DEFAULT_SPECIES_LIST",
"build_generic_class_tags",
"build_sound_event_decoder",
"build_sound_event_encoder",
"get_class_names_from_config",
"load_classes_config",
"load_decoder_from_config",
"load_encoder_from_config",
]
DEFAULT_SPECIES_LIST = [
"Barbastella barbastellus",
"Eptesicus serotinus",
"Myotis alcathoe",
"Myotis bechsteinii",
"Myotis brandtii",
"Myotis daubentonii",
"Myotis mystacinus",
"Myotis nattereri",
"Nyctalus leisleri",
"Nyctalus noctula",
"Pipistrellus nathusii",
"Pipistrellus pipistrellus",
"Pipistrellus pygmaeus",
"Plecotus auritus",
"Plecotus austriacus",
"Rhinolophus ferrumequinum",
"Rhinolophus hipposideros",
]
"""A default list of common bat species names found in the UK."""
class TargetClass(BaseConfig):
"""Defines criteria for encoding annotations and decoding predictions.
Each instance represents one potential output class for the classification
model. It specifies:
1. A unique `name` for the class.
2. The tag conditions (`tags` and `match_type`) an annotation must meet to
be assigned this class name during training data preparation (encoding).
3. An optional, alternative set of tags (`output_tags`) to be used when
converting a model's prediction of this class name back into annotation
tags (decoding).
Attributes
----------
name : str
The unique name assigned to this target class (e.g., 'pippip',
'myodau', 'noise'). This name is used as the label during model
training and is the expected output from the model's prediction.
Should be unique across all TargetClass definitions in a configuration.
tags : List[TagInfo]
A list of one or more tags (defined using `TagInfo`) used to identify
if an existing annotation belongs to this class during encoding (data
preparation for training). The `match_type` attribute determines how
these tags are evaluated.
match_type : Literal["all", "any"], default="all"
Determines how the `tags` list is evaluated during encoding:
- "all": The annotation must have *all* the tags listed to match.
- "any": The annotation must have *at least one* of the tags listed
to match.
output_tags: Optional[List[TagInfo]], default=None
An optional list of tags (defined using `TagInfo`) to be assigned to a
new annotation when the model predicts this class `name`. If `None`
(default), the tags listed in the `tags` field will be used for
decoding. If provided, this list overrides the `tags` field for the
purpose of decoding predictions back into meaningful annotation tags.
This allows, for example, training on broader categories but decoding
to more specific representative tags.
"""
class TargetClassConfig(BaseConfig):
"""Defines a target class of sound events."""
name: str
tags: List[TagInfo] = Field(min_length=1)
match_type: Literal["all", "any"] = Field(default="all")
output_tags: Optional[List[TagInfo]] = None
condition_input: Optional[SoundEventConditionConfig] = Field(
alias="match_if",
default=None,
)
tags: Optional[List[data.Tag]] = Field(default=None, exclude=True)
assign_tags: List[data.Tag] = Field(default_factory=list)
roi: Optional[ROIMapperConfig] = None
_match_if: SoundEventConditionConfig = PrivateAttr()
def _get_default_classes() -> List[TargetClass]:
"""Generate a list of default target classes.
@model_validator(mode="after")
def _process_tags(self) -> "TargetClassConfig":
if self.tags and self.condition_input:
raise ValueError("Use either 'tags' or 'match_if', not both.")
Returns
-------
List[TargetClass]
A list of TargetClass objects, one for each species in
DEFAULT_SPECIES_LIST. The class names are simplified versions of the
species names.
"""
return [
TargetClass(
name=_get_default_class_name(value),
tags=[TagInfo(key=GENERIC_CLASS_KEY, value=value)],
)
for value in DEFAULT_SPECIES_LIST
]
if self.condition_input is not None:
self._match_if = self.condition_input
return self
def _get_default_class_name(species: str) -> str:
"""Generate a default class name from a species name.
Parameters
----------
species : str
The species name (e.g., "Myotis daubentonii").
Returns
-------
str
A simplified class name (e.g., "myodau").
The genus and species names are converted to lowercase,
the first three letters of each are taken, and concatenated.
"""
genus, species = species.strip().split(" ")
return f"{genus.lower()[:3]}{species.lower()[:3]}"
def _get_default_generic_class() -> List[TagInfo]:
"""Generate the default list of TagInfo objects for the generic class.
Provides a default set of tags used to represent the generic "Bat" category
when decoding predictions that didn't match a specific class.
Returns
-------
List[TagInfo]
A list containing default TagInfo objects, typically representing
`call_type: Echolocation` and `order: Chiroptera`.
"""
return [
TagInfo(key="call_type", value="Echolocation"),
TagInfo(key="order", value="Chiroptera"),
]
class ClassesConfig(BaseConfig):
"""Configuration defining target classes and the generic fallback category.
Holds the ordered list of specific target class definitions (`TargetClass`)
and defines the tags representing the generic category for sounds that pass
filtering but do not match any specific class.
The order of `TargetClass` objects in the `classes` list defines the
priority for classification during encoding. The system checks annotations
against these definitions sequentially and assigns the name of the *first*
matching class.
Attributes
----------
classes : List[TargetClass]
An ordered list of specific target class definitions. The order
determines matching priority (first match wins). Defaults to a
standard set of classes via `get_default_classes`.
generic_class : List[TagInfo]
A list of tags defining the "generic" or "unclassified but relevant"
category (e.g., representing a generic 'Bat' call that wasn't
assigned to a specific species). These tags are typically assigned
during decoding when a sound event was detected and passed filtering
but did not match any specific class rule defined in the `classes` list.
Defaults to a standard set of tags via `get_default_generic_class`.
Raises
------
ValueError
If validation fails (e.g., non-unique class names in the `classes`
list).
Notes
-----
- It is crucial that the `name` attribute of each `TargetClass` in the
`classes` list is unique. This configuration includes a validator to
enforce this uniqueness.
- The `generic_class` tags provide a baseline identity for relevant sounds
that don't fit into more specific defined categories.
"""
classes: List[TargetClass] = Field(default_factory=_get_default_classes)
generic_class: List[TagInfo] = Field(
default_factory=_get_default_generic_class
)
@field_validator("classes")
def check_unique_class_names(cls, v: List[TargetClass]):
"""Ensure all defined class names are unique."""
names = [c.name for c in v]
if len(names) != len(set(names)):
name_counts = Counter(names)
duplicates = [
name for name, count in name_counts.items() if count > 1
]
if self.tags is None:
raise ValueError(
"Class names must be unique. Found duplicates: "
f"{', '.join(duplicates)}"
f"Class '{self.name}' must have a 'tags' or 'match_if' rule."
)
return v
self._match_if = HasAllTagsConfig(tags=self.tags)
if not self.assign_tags:
self.assign_tags = self.tags
return self
@computed_field
@property
def match_if(self) -> SoundEventConditionConfig:
return self._match_if
def is_target_class(
sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag],
match_all: bool = True,
) -> bool:
"""Check if a sound event annotation matches a set of required tags.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
required_tags : Set[data.Tag]
A set of `soundevent.data.Tag` objects that define the class criteria.
match_all : bool, default=True
If True, checks if *all* `required_tags` are present in the
annotation's tags (subset check). If False, checks if *at least one*
of the `required_tags` is present (intersection check).
Returns
-------
bool
True if the annotation meets the tag criteria, False otherwise.
"""
annotation_tags = set(sound_event_annotation.tags)
if match_all:
return tags <= annotation_tags
return bool(tags & annotation_tags)
DEFAULT_GENERIC_CLASS = TargetClassConfig(
name="bat",
match_if=AllOfConfig(
conditions=[
HasTagConfig(tag=data.Tag(key="event", value="Echolocation")),
NotConfig(
condition=HasAnyTagConfig(
tags=[
data.Tag(key="event", value="Feeding"),
data.Tag(key="event", value="Unknown"),
data.Tag(key="event", value="Not Bat"),
]
)
),
]
),
assign_tags=[
data.Tag(key="call_type", value="Echolocation"),
data.Tag(key="order", value="Chiroptera"),
],
)
def get_class_names_from_config(config: ClassesConfig) -> List[str]:
DEFAULT_CLASSES = [
TargetClassConfig(
name="myomys",
tags=[data.Tag(key="class", value="Myotis mystacinus")],
),
TargetClassConfig(
name="myoalc",
tags=[data.Tag(key="class", value="Myotis alcathoe")],
),
TargetClassConfig(
name="eptser",
tags=[data.Tag(key="class", value="Eptesicus serotinus")],
),
TargetClassConfig(
name="pipnat",
tags=[data.Tag(key="class", value="Pipistrellus nathusii")],
),
TargetClassConfig(
name="barbar",
tags=[data.Tag(key="class", value="Barbastellus barbastellus")],
),
TargetClassConfig(
name="myonat",
tags=[data.Tag(key="class", value="Myotis nattereri")],
),
TargetClassConfig(
name="myodau",
tags=[data.Tag(key="class", value="Myotis daubentonii")],
),
TargetClassConfig(
name="myobra",
tags=[data.Tag(key="class", value="Myotis brandtii")],
),
TargetClassConfig(
name="pippip",
tags=[data.Tag(key="class", value="Pipistrellus pipistrellus")],
),
TargetClassConfig(
name="myobec",
tags=[data.Tag(key="class", value="Myotis bechsteinii")],
),
TargetClassConfig(
name="pippyg",
tags=[data.Tag(key="class", value="Pipistrellus pygmaeus")],
),
TargetClassConfig(
name="rhihip",
tags=[data.Tag(key="class", value="Rhinolophus hipposideros")],
roi=AnchorBBoxMapperConfig(anchor="top-left"),
),
TargetClassConfig(
name="nyclei",
tags=[data.Tag(key="class", value="Nyctalus leisleri")],
),
TargetClassConfig(
name="rhifer",
tags=[data.Tag(key="class", value="Rhinolophus ferrumequinum")],
roi=AnchorBBoxMapperConfig(anchor="top-left"),
),
TargetClassConfig(
name="pleaur",
tags=[data.Tag(key="class", value="Plecotus auritus")],
),
TargetClassConfig(
name="nycnoc",
tags=[data.Tag(key="class", value="Nyctalus noctula")],
),
TargetClassConfig(
name="pleaus",
tags=[data.Tag(key="class", value="Plecotus austriacus")],
),
]
def get_class_names_from_config(configs: List[TargetClassConfig]) -> List[str]:
"""Extract the list of class names from a ClassesConfig object.
Parameters
@ -260,340 +179,60 @@ def get_class_names_from_config(config: ClassesConfig) -> List[str]:
List[str]
An ordered list of unique class names defined in the configuration.
"""
return [class_info.name for class_info in config.classes]
def _encode_with_multiple_classifiers(
sound_event_annotation: data.SoundEventAnnotation,
classifiers: List[Tuple[str, Callable[[data.SoundEventAnnotation], bool]]],
) -> Optional[str]:
"""Encode an annotation by checking against a list of classifiers.
Internal helper function used by the `SoundEventEncoder`. It iterates
through the provided list of (class_name, classifier_function) pairs.
Returns the name associated with the first classifier function that
returns True for the given annotation.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to encode.
classifiers : List[Tuple[str, Callable[[data.SoundEventAnnotation], bool]]]
An ordered list where each tuple contains a class name and a function
that returns True if the annotation matches that class. The order
determines priority.
Returns
-------
str or None
The name of the first matching class, or None if no classifier matches.
"""
for class_name, classifier in classifiers:
if classifier(sound_event_annotation):
return class_name
return None
return [class_info.name for class_info in configs]
def build_sound_event_encoder(
config: ClassesConfig,
term_registry: TermRegistry = default_term_registry,
configs: List[TargetClassConfig],
) -> SoundEventEncoder:
"""Build a sound event encoder function from the classes configuration.
"""Build a sound event encoder function from the classes configuration."""
conditions = {
class_config.name: build_sound_event_condition(class_config.match_if)
for class_config in configs
}
The returned encoder function iterates through the class definitions in the
order specified in the config. It assigns an annotation the name of the
first class definition it matches.
Parameters
----------
config : ClassesConfig
The loaded and validated classes configuration object.
term_registry : TermRegistry, optional
The TermRegistry instance used to look up term keys specified in the
`TagInfo` objects within the configuration. Defaults to the global
`batdetect2.targets.terms.registry`.
Returns
-------
SoundEventEncoder
A callable function that takes a `SoundEventAnnotation` and returns
an optional string representing the matched class name, or None if no
class matches.
Raises
------
KeyError
If a term key specified in the configuration is not found in the
provided `term_registry`.
"""
binary_classifiers = [
(
class_info.name,
partial(
is_target_class,
tags={
get_tag_from_info(tag_info, term_registry=term_registry)
for tag_info in class_info.tags
},
match_all=class_info.match_type == "all",
),
)
for class_info in config.classes
]
return partial(
_encode_with_multiple_classifiers,
classifiers=binary_classifiers,
)
return SoundEventClassifier(conditions)
def _decode_class(
name: str,
mapping: Dict[str, List[data.Tag]],
raise_on_error: bool = True,
) -> List[data.Tag]:
"""Decode a class name into a list of representative tags using a mapping.
class SoundEventClassifier:
def __init__(self, mapping: Dict[str, SoundEventCondition]):
self.mapping = mapping
Internal helper function used by the `SoundEventDecoder`. Looks up the
provided class `name` in the `mapping` dictionary.
Parameters
----------
name : str
The class name to decode.
mapping : Dict[str, List[data.Tag]]
A dictionary mapping class names to lists of `soundevent.data.Tag`
objects.
raise_on_error : bool, default=True
If True, raises a ValueError if the `name` is not found in the
`mapping`. If False, returns an empty list if the `name` is not found.
Returns
-------
List[data.Tag]
The list of tags associated with the class name, or an empty list if
not found and `raise_on_error` is False.
Raises
------
ValueError
If `name` is not found in `mapping` and `raise_on_error` is True.
"""
if name not in mapping and raise_on_error:
raise ValueError(f"Class {name} not found in mapping.")
if name not in mapping:
return []
return mapping[name]
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> Optional[str]:
for name, condition in self.mapping.items():
if condition(sound_event_annotation):
return name
def build_sound_event_decoder(
config: ClassesConfig,
term_registry: TermRegistry = default_term_registry,
configs: List[TargetClassConfig],
raise_on_unmapped: bool = False,
) -> SoundEventDecoder:
"""Build a sound event decoder function from the classes configuration.
Creates a callable `SoundEventDecoder` that maps a class name string
back to a list of representative `soundevent.data.Tag` objects based on
the `ClassesConfig`. It uses the `output_tags` field if provided in a
`TargetClass`, otherwise falls back to the `tags` field.
Parameters
----------
config : ClassesConfig
The loaded and validated classes configuration object.
term_registry : TermRegistry, optional
The TermRegistry instance used to look up term keys. Defaults to the
global `batdetect2.targets.terms.registry`.
raise_on_unmapped : bool, default=False
If True, the returned decoder function will raise a ValueError if asked
to decode a class name that is not in the configuration. If False, it
will return an empty list for unmapped names.
Returns
-------
SoundEventDecoder
A callable function that takes a class name string and returns a list
of `soundevent.data.Tag` objects.
Raises
------
KeyError
If a term key specified in the configuration (`output_tags`, `tags`, or
`generic_class`) is not found in the provided `term_registry`.
"""
mapping = {}
for class_info in config.classes:
tags_to_use = (
class_info.output_tags
if class_info.output_tags is not None
else class_info.tags
)
mapping[class_info.name] = [
get_tag_from_info(tag_info, term_registry=term_registry)
for tag_info in tags_to_use
]
return partial(
_decode_class,
mapping=mapping,
raise_on_error=raise_on_unmapped,
)
"""Build a sound event decoder function from the classes configuration."""
mapping = {
class_config.name: class_config.assign_tags for class_config in configs
}
return TagDecoder(mapping, raise_on_unknown=raise_on_unmapped)
def build_generic_class_tags(
config: ClassesConfig,
term_registry: TermRegistry = default_term_registry,
) -> List[data.Tag]:
"""Extract and build the list of tags for the generic class from config.
class TagDecoder:
def __init__(
self,
mapping: Dict[str, List[data.Tag]],
raise_on_unknown: bool = True,
):
self.mapping = mapping
self.raise_on_unknown = raise_on_unknown
Converts the list of `TagInfo` objects defined in `config.generic_class`
into a list of `soundevent.data.Tag` objects using the term registry.
def __call__(self, class_name: str) -> List[data.Tag]:
tags = self.mapping.get(class_name)
Parameters
----------
config : ClassesConfig
The loaded classes configuration object.
term_registry : TermRegistry, optional
The TermRegistry instance for term lookups. Defaults to the global
`batdetect2.targets.terms.registry`.
if tags is None:
if self.raise_on_unknown:
raise ValueError("Invalid class name")
Returns
-------
List[data.Tag]
The list of fully constructed tags representing the generic class.
tags = []
Raises
------
KeyError
If a term key specified in `config.generic_class` is not found in the
provided `term_registry`.
"""
return [
get_tag_from_info(tag_info, term_registry=term_registry)
for tag_info in config.generic_class
]
def load_classes_config(
path: data.PathLike,
field: Optional[str] = None,
) -> ClassesConfig:
"""Load the target classes configuration from a file.
Parameters
----------
path : data.PathLike
Path to the configuration file (YAML).
field : str, optional
If the classes configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
Returns
-------
ClassesConfig
The loaded and validated classes configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the ClassesConfig schema
or if class names are not unique.
"""
return load_config(path, schema=ClassesConfig, field=field)
def load_encoder_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventEncoder:
"""Load a class encoder function directly from a configuration file.
This is a convenience function that combines loading the `ClassesConfig`
from a file and building the final `SoundEventEncoder` function.
Parameters
----------
path : data.PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
If the classes configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
term_registry : TermRegistry, optional
The TermRegistry instance used for term lookups. Defaults to the
global `batdetect2.targets.terms.registry`.
Returns
-------
SoundEventEncoder
The final encoder function ready to classify annotations.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the ClassesConfig schema
or if class names are not unique.
KeyError
If a term key specified in the configuration is not found in the
provided `term_registry` during the build process.
"""
config = load_classes_config(path, field=field)
return build_sound_event_encoder(config, term_registry=term_registry)
def load_decoder_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = default_term_registry,
raise_on_unmapped: bool = False,
) -> SoundEventDecoder:
"""Load a class decoder function directly from a configuration file.
This is a convenience function that combines loading the `ClassesConfig`
from a file and building the final `SoundEventDecoder` function.
Parameters
----------
path : data.PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
If the classes configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
term_registry : TermRegistry, optional
The TermRegistry instance used for term lookups. Defaults to the
global `batdetect2.targets.terms.registry`.
raise_on_unmapped : bool, default=False
If True, the returned decoder function will raise a ValueError if asked
to decode a class name that is not in the configuration. If False, it
will return an empty list for unmapped names.
Returns
-------
SoundEventDecoder
The final decoder function ready to convert class names back into tags.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the ClassesConfig schema
or if class names are not unique.
KeyError
If a term key specified in the configuration is not found in the
provided `term_registry` during the build process.
"""
config = load_classes_config(path, field=field)
return build_sound_event_decoder(
config,
term_registry=term_registry,
raise_on_unmapped=raise_on_unmapped,
)
return tags

View File

@ -1,307 +0,0 @@
import logging
from functools import partial
from typing import List, Literal, Optional, Set
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
default_term_registry,
get_tag_from_info,
)
from batdetect2.typing.targets import SoundEventFilter
__all__ = [
"FilterConfig",
"FilterRule",
"build_sound_event_filter",
"build_filter_from_rule",
"load_filter_config",
"load_filter_from_config",
]
logger = logging.getLogger(__name__)
class FilterRule(BaseConfig):
"""Defines a single rule for filtering sound event annotations.
Based on the `match_type`, this rule checks if the tags associated with a
sound event annotation meet certain criteria relative to the `tags` list
defined in this rule.
Attributes
----------
match_type : Literal["any", "all", "exclude", "equal"]
Determines how the `tags` list is used:
- "any": Pass if the annotation has at least one tag from the list.
- "all": Pass if the annotation has all tags from the list (it can
have others too).
- "exclude": Pass if the annotation has none of the tags from the list.
- "equal": Pass if the annotation's tags are exactly the same set as
provided in the list.
tags : List[TagInfo]
A list of tags (defined using TagInfo for configuration) that this
rule operates on.
"""
match_type: Literal["any", "all", "exclude", "equal"]
tags: List[TagInfo]
def has_any_tag(
sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag],
) -> bool:
"""Check if the annotation has at least one of the specified tags.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
tags : Set[data.Tag]
The set of tags to look for.
Returns
-------
bool
True if the annotation has one or more tags from the specified set,
False otherwise.
"""
sound_event_tags = set(sound_event_annotation.tags)
return bool(tags & sound_event_tags)
def contains_tags(
sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag],
) -> bool:
"""Check if the annotation contains all of the specified tags.
The annotation may have additional tags beyond those specified.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
tags : Set[data.Tag]
The set of tags that must all be present in the annotation.
Returns
-------
bool
True if the annotation's tags are a superset of the specified tags,
False otherwise.
"""
sound_event_tags = set(sound_event_annotation.tags)
return tags <= sound_event_tags
def does_not_have_tags(
sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag],
):
"""Check if the annotation has none of the specified tags.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
tags : Set[data.Tag]
The set of tags that must *not* be present in the annotation.
Returns
-------
bool
True if the annotation has zero tags in common with the specified set,
False otherwise.
"""
return not has_any_tag(sound_event_annotation, tags)
def equal_tags(
sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag],
) -> bool:
"""Check if the annotation's tags are exactly equal to the specified set.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
tags : Set[data.Tag]
The exact set of tags the annotation must have.
Returns
-------
bool
True if the annotation's tags set is identical to the specified set,
False otherwise.
"""
sound_event_tags = set(sound_event_annotation.tags)
return tags == sound_event_tags
def build_filter_from_rule(
rule: FilterRule,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventFilter:
"""Creates a callable filter function from a single FilterRule.
Parameters
----------
rule : FilterRule
The filter rule configuration object.
Returns
-------
SoundEventFilter
A function that takes a SoundEventAnnotation and returns True if it
passes the rule, False otherwise.
Raises
------
ValueError
If the rule contains an invalid `match_type`.
"""
tag_set = {
get_tag_from_info(tag_info, term_registry=term_registry)
for tag_info in rule.tags
}
if rule.match_type == "any":
return partial(has_any_tag, tags=tag_set)
if rule.match_type == "all":
return partial(contains_tags, tags=tag_set)
if rule.match_type == "exclude":
return partial(does_not_have_tags, tags=tag_set)
if rule.match_type == "equal":
return partial(equal_tags, tags=tag_set)
raise ValueError(
f"Invalid match type {rule.match_type}. Valid types "
"are: 'any', 'all', 'exclude' and 'equal'"
)
def _passes_all_filters(
sound_event_annotation: data.SoundEventAnnotation,
filters: List[SoundEventFilter],
) -> bool:
"""Check if the annotation passes all provided filters.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
filters : List[SoundEventFilter]
A list of filter functions to apply.
Returns
-------
bool
True if the annotation passes all filters, False otherwise.
"""
for filter_fn in filters:
if not filter_fn(sound_event_annotation):
logging.debug(
f"Sound event annotation {sound_event_annotation.uuid} "
f"excluded due to rule {filter_fn}",
)
return False
return True
class FilterConfig(BaseConfig):
"""Configuration model for defining a list of filter rules.
Attributes
----------
rules : List[FilterRule]
A list of FilterRule objects. An annotation must pass all rules in
this list to be considered valid by the filter built from this config.
"""
rules: List[FilterRule] = Field(default_factory=list)
def build_sound_event_filter(
config: FilterConfig,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventFilter:
"""Builds a merged filter function from a FilterConfig object.
Creates individual filter functions for each rule in the configuration
and merges them using AND logic.
Parameters
----------
config : FilterConfig
The configuration object containing the list of filter rules.
Returns
-------
SoundEventFilter
A single callable filter function that applies all defined rules.
"""
filters = [
build_filter_from_rule(rule, term_registry=term_registry)
for rule in config.rules
]
return partial(_passes_all_filters, filters=filters)
def load_filter_config(
path: data.PathLike, field: Optional[str] = None
) -> FilterConfig:
"""Loads the filter configuration from a file.
Parameters
----------
path : data.PathLike
Path to the configuration file (YAML).
field : Optional[str], optional
If the filter configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
Returns
-------
FilterConfig
The loaded and validated filter configuration object.
"""
return load_config(path, schema=FilterConfig, field=field)
def load_filter_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventFilter:
"""Loads filter configuration from a file and builds the filter function.
This is a convenience function that combines loading the configuration
and building the final callable filter function.
Parameters
----------
path : data.PathLike
Path to the configuration file (YAML).
field : Optional[str], optional
If the filter configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
Returns
-------
SoundEventFilter
The final merged filter function ready to be used.
"""
config = load_filter_config(path=path, field=field)
return build_sound_event_filter(config, term_registry=term_registry)

View File

@ -20,7 +20,7 @@ selecting and configuring the desired mapper. This module separates the
*geometric* aspect of target definition from *semantic* classification.
"""
from typing import Annotated, List, Literal, Optional, Protocol, Tuple, Union
from typing import Annotated, Literal, Optional, Tuple, Union
import numpy as np
from pydantic import Field
@ -30,7 +30,7 @@ from batdetect2.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
from batdetect2.typing.targets import Position, Size
from batdetect2.typing.targets import Position, ROITargetMapper, Size
from batdetect2.utils.arrays import spec_to_xarray
__all__ = [
@ -83,73 +83,6 @@ DEFAULT_ANCHOR = "bottom-left"
"""Default reference position within the geometry ('bottom-left' corner)."""
class ROITargetMapper(Protocol):
"""Protocol defining the interface for ROI-to-target mapping.
Specifies the `encode` and `decode` methods required for converting a
`soundevent.data.SoundEvent` into a target representation (a reference
position and a size vector) and for recovering an approximate ROI from that
representation.
Attributes
----------
dimension_names : List[str]
A list containing the names of the dimensions in the `Size` array
returned by `encode` and expected by `decode`.
"""
dimension_names: List[str]
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
"""Encode a SoundEvent's geometry into a position and size.
Parameters
----------
sound_event : data.SoundEvent
The input sound event, which must have a geometry attribute.
Returns
-------
Tuple[Position, Size]
A tuple containing:
- The reference position as (time, frequency) coordinates.
- A NumPy array with the calculated size dimensions.
Raises
------
ValueError
If the sound event does not have a geometry.
"""
...
def decode(self, position: Position, size: Size) -> data.Geometry:
"""Decode a position and size back into a geometric ROI.
Performs the inverse mapping: takes a reference position and size
dimensions and reconstructs a geometric representation.
Parameters
----------
position : Position
The reference position (time, frequency).
size : Size
NumPy array containing the size dimensions, matching the order
and meaning specified by `dimension_names`.
Returns
-------
soundevent.data.Geometry
The reconstructed geometry, typically a `BoundingBox`.
Raises
------
ValueError
If the `size` array has an unexpected shape or if reconstruction
fails.
"""
...
class AnchorBBoxMapperConfig(BaseConfig):
"""Configuration for `AnchorBBoxMapper`.
@ -475,7 +408,10 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
ROIMapperConfig = Annotated[
Union[AnchorBBoxMapperConfig, PeakEnergyBBoxMapperConfig],
Union[
AnchorBBoxMapperConfig,
PeakEnergyBBoxMapperConfig,
],
Field(discriminator="name"),
]
"""A discriminated union of all supported ROI mapper configurations.
@ -553,7 +489,7 @@ def _build_bounding_box(
) -> data.BoundingBox:
"""Construct a BoundingBox from a reference point, size, and position type.
Internal helper for `BBoxEncoder.recover_roi`. Calculates the box
Internal helper for `BBoxEncoder.decode`. Calculates the box
coordinates [start_time, low_freq, end_time, high_freq] based on where
the input `pos` (time, freq) is located relative to the box (e.g.,
center, corner).

View File

@ -1,34 +1,11 @@
"""Manages the vocabulary (Terms and Tags) for defining training targets.
"""Manages the vocabulary for defining training targets."""
This module provides the necessary tools to declare, register, and manage the
set of `soundevent.data.Term` objects used throughout the `batdetect2.targets`
sub-package. It establishes a consistent vocabulary for filtering,
transforming, and classifying sound events based on their annotations (Tags).
The core component is the `TermRegistry`, which maps unique string keys
(aliases) to specific `Term` definitions. This allows users to refer to complex
terms using simple, consistent keys in configuration files and code.
Terms can be pre-defined, loaded from the `soundevent.terms` library, defined
programmatically, or loaded from external configuration files (e.g., YAML).
"""
from collections.abc import Mapping
from inspect import getmembers
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
from soundevent import data, terms
from batdetect2.configs import load_config
__all__ = [
"call_type",
"individual",
"data_source",
"get_tag_from_info",
"TermInfo",
"TagInfo",
]
# The default key used to reference the 'generic_class' term.
@ -96,430 +73,3 @@ terms.register_term_set(
),
override_existing=True,
)
class TermRegistry(Mapping[str, data.Term]):
"""Manages a registry mapping unique keys to Term definitions.
This class acts as the central repository for the vocabulary of terms
used within the target definition process. It allows registering terms
with simple string keys and retrieving them consistently.
"""
def __init__(self, terms: Optional[Dict[str, data.Term]] = None):
"""Initializes the TermRegistry.
Parameters
----------
terms : dict[str, soundevent.data.Term], optional
An optional dictionary of initial key-to-Term mappings
to populate the registry with. Defaults to an empty registry.
"""
self._terms: Dict[str, data.Term] = terms or {}
def __getitem__(self, key: str) -> data.Term:
return self._terms[key]
def __len__(self) -> int:
return len(self._terms)
def __iter__(self):
return iter(self._terms)
def add_term(self, key: str, term: data.Term) -> None:
"""Adds a Term object to the registry with the specified key.
Parameters
----------
key : str
The unique string key to associate with the term.
term : soundevent.data.Term
The soundevent.data.Term object to register.
Raises
------
KeyError
If a term with the provided key already exists in the
registry.
"""
if key in self._terms:
raise KeyError("A term with the provided key already exists.")
self._terms[key] = term
def get_term(self, key: str) -> data.Term:
"""Retrieves a registered term by its unique key.
Parameters
----------
key : str
The unique string key of the term to retrieve.
Returns
-------
soundevent.data.Term
The corresponding soundevent.data.Term object.
Raises
------
KeyError
If no term with the specified key is found, with a
helpful message suggesting listing available keys.
"""
try:
return self._terms[key]
except KeyError as err:
raise KeyError(
"No term found for key "
f"'{key}'. Ensure it is registered or loaded. "
f"Available keys: {', '.join(self.get_keys())}"
) from err
def add_custom_term(
self,
key: str,
name: Optional[str] = None,
uri: Optional[str] = None,
label: Optional[str] = None,
definition: Optional[str] = None,
) -> data.Term:
"""Creates a new Term from attributes and adds it to the registry.
This is useful for defining terms directly in code or when loading
from configuration files where only attributes are provided.
If optional fields (`name`, `label`, `definition`) are not provided,
reasonable defaults are used (`key` for name/label, "Unknown" for
definition).
Parameters
----------
key : str
The unique string key for the new term.
name : str, optional
The name for the new term (defaults to `key`).
uri : str, optional
The URI for the new term (optional).
label : str, optional
The display label for the new term (defaults to `key`).
definition : str, optional
The definition for the new term (defaults to "Unknown").
Returns
-------
soundevent.data.Term
The newly created and registered soundevent.data.Term object.
Raises
------
KeyError
If a term with the provided key already exists.
"""
term = data.Term(
name=name or key,
label=label or key,
uri=uri,
definition=definition or "Unknown",
)
self.add_term(key, term)
return term
def get_keys(self) -> List[str]:
"""Returns a list of all keys currently registered.
Returns
-------
list[str]
A list of strings representing the keys of all registered terms.
"""
return list(self._terms.keys())
def get_terms(self) -> List[data.Term]:
"""Returns a list of all registered terms.
Returns
-------
list[soundevent.data.Term]
A list containing all registered Term objects.
"""
return list(self._terms.values())
def remove_key(self, key: str) -> None:
del self._terms[key]
default_term_registry = TermRegistry(
terms=dict(
[
*getmembers(terms, lambda x: isinstance(x, data.Term)),
("event", call_type),
("species", terms.scientific_name),
("individual", individual),
("data_source", data_source),
(GENERIC_CLASS_KEY, generic_class),
]
)
)
"""The default, globally accessible TermRegistry instance.
It is pre-populated with standard terms from `soundevent.terms` and common
terms defined in this module (`call_type`, `individual`, `generic_class`).
Functions in this module use this registry by default unless another instance
is explicitly passed.
"""
def get_term_from_key(
key: str,
term_registry: Optional[TermRegistry] = None,
) -> data.Term:
"""Convenience function to retrieve a term by key from a registry.
Uses the global default registry unless a specific `term_registry`
instance is provided.
Parameters
----------
key : str
The unique key of the term to retrieve.
term_registry : TermRegistry, optional
The TermRegistry instance to search in. Defaults to the global
`registry`.
Returns
-------
soundevent.data.Term
The corresponding soundevent.data.Term object.
Raises
------
KeyError
If the key is not found in the specified registry.
"""
term = terms.get_term(key)
if term:
return term
term_registry = term_registry or default_term_registry
return term_registry.get_term(key)
def get_term_keys(
term_registry: TermRegistry = default_term_registry,
) -> List[str]:
"""Convenience function to get all registered keys from a registry.
Uses the global default registry unless a specific `term_registry`
instance is provided.
Parameters
----------
term_registry : TermRegistry, optional
The TermRegistry instance to query. Defaults to the global `registry`.
Returns
-------
list[str]
A list of strings representing the keys of all registered terms.
"""
return term_registry.get_keys()
def get_terms(
term_registry: TermRegistry = default_term_registry,
) -> List[data.Term]:
"""Convenience function to get all registered terms from a registry.
Uses the global default registry unless a specific `term_registry`
instance is provided.
Parameters
----------
term_registry : TermRegistry, optional
The TermRegistry instance to query. Defaults to the global `registry`.
Returns
-------
list[soundevent.data.Term]
A list containing all registered Term objects.
"""
return term_registry.get_terms()
class TagInfo(BaseModel):
"""Represents information needed to define a specific Tag.
This model is typically used in configuration files (e.g., YAML) to
specify tags used for filtering, target class definition, or associating
tags with output classes. It links a tag value to a term definition
via the term's registry key.
Attributes
----------
value : str
The value of the tag (e.g., "Myotis myotis", "Echolocation").
key : str, default="class"
The key (alias) of the term associated with this tag, as
registered in the TermRegistry. Defaults to "class", implying
it represents a classification target label by default.
"""
value: str
key: str = GENERIC_CLASS_KEY
def get_tag_from_info(
tag_info: TagInfo,
term_registry: Optional[TermRegistry] = None,
) -> data.Tag:
"""Creates a soundevent.data.Tag object from TagInfo data.
Looks up the term using the key in the provided `tag_info` from the
specified registry and constructs a Tag object.
Parameters
----------
tag_info : TagInfo
The TagInfo object containing the value and term key.
term_registry : TermRegistry, optional
The TermRegistry instance to use for term lookup. Defaults to the
global `registry`.
Returns
-------
soundevent.data.Tag
A soundevent.data.Tag object corresponding to the input info.
Raises
------
KeyError
If the term key specified in `tag_info.key` is not found
in the registry.
"""
term_registry = term_registry or default_term_registry
term = get_term_from_key(tag_info.key, term_registry=term_registry)
return data.Tag(term=term, value=tag_info.value)
class TermInfo(BaseModel):
"""Represents the definition of a Term within a configuration file.
This model allows users to define custom terms directly in configuration
files (e.g., YAML) which can then be loaded into the TermRegistry.
It mirrors the parameters of `TermRegistry.add_custom_term`.
Attributes
----------
key : str
The unique key (alias) that will be used to register and
reference this term.
label : str, optional
The optional display label for the term. Defaults to `key`
if not provided during registration.
name : str, optional
The optional formal name for the term. Defaults to `key`
if not provided during registration.
uri : str, optional
The optional URI identifying the term (e.g., from a standard
vocabulary).
definition : str, optional
The optional textual definition of the term. Defaults to
"Unknown" if not provided during registration.
"""
key: str
label: Optional[str] = None
name: Optional[str] = None
uri: Optional[str] = None
definition: Optional[str] = None
class TermConfig(BaseModel):
"""Pydantic schema for loading a list of term definitions from config.
This model typically corresponds to a section in a configuration file
(e.g., YAML) containing a list of term definitions to be registered.
Attributes
----------
terms : list[TermInfo]
A list of TermInfo objects, each defining a term to be
registered. Defaults to an empty list.
Examples
--------
Example YAML structure:
```yaml
terms:
- key: species
uri: dwc:scientificName
label: Scientific Name
- key: my_custom_term
name: My Custom Term
definition: Describes a specific project attribute.
# ... more TermInfo definitions
```
"""
terms: List[TermInfo] = Field(default_factory=list)
def load_terms_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = default_term_registry,
) -> Dict[str, data.Term]:
"""Loads term definitions from a configuration file and registers them.
Parses a configuration file (e.g., YAML) using the TermConfig schema,
extracts the list of TermInfo definitions, and adds each one as a
custom term to the specified TermRegistry instance.
Parameters
----------
path : data.PathLike
The path to the configuration file.
field : str, optional
Optional key indicating a specific section within the config
file where the 'terms' list is located. If None, expects the
list directly at the top level or within a structure matching
TermConfig schema.
term_registry : TermRegistry, optional
The TermRegistry instance to add the loaded terms to. Defaults to
the global `registry`.
Returns
-------
dict[str, soundevent.data.Term]
A dictionary mapping the keys of the newly added terms to their
corresponding Term objects.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the TermConfig schema.
KeyError
If a term key loaded from the config conflicts with a key
already present in the registry.
"""
data = load_config(path, schema=TermConfig, field=field)
return {
info.key: term_registry.add_custom_term(
info.key,
name=info.name,
uri=info.uri,
label=info.label,
definition=info.definition,
)
for info in data.terms
}
def register_term(
key: str, term: data.Term, registry: TermRegistry = default_term_registry
) -> None:
registry.add_term(key, term)

View File

@ -1,708 +0,0 @@
import importlib
from functools import partial
from typing import (
Annotated,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Union,
)
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
get_tag_from_info,
get_term_from_key,
)
__all__ = [
"DerivationRegistry",
"DeriveTagRule",
"MapValueRule",
"ReplaceRule",
"SoundEventTransformation",
"TransformConfig",
"build_transform_from_rule",
"build_transformation_from_config",
"default_derivation_registry",
"get_derivation",
"load_transformation_config",
"load_transformation_from_config",
"register_derivation",
]
SoundEventTransformation = Callable[
[data.SoundEventAnnotation], data.SoundEventAnnotation
]
"""Type alias for a sound event transformation function.
A function that accepts a sound event annotation object and returns a
(potentially) modified sound event annotation object. Transformations
should generally return a copy of the annotation rather than modifying
it in place.
"""
Derivation = Callable[[str], str]
"""Type alias for a derivation function.
A function that accepts a single string (typically a tag value) and returns
a new string (the derived value).
"""
class MapValueRule(BaseConfig):
"""Configuration for mapping specific values of a source term.
This rule replaces tags matching a specific term and one of the
original values with a new tag (potentially having a different term)
containing the corresponding replacement value. Useful for standardizing
or grouping tag values.
Attributes
----------
rule_type : Literal["map_value"]
Discriminator field identifying this rule type.
source_term_key : str
The key (registered in `TermRegistry`) of the term whose tags' values
should be checked against the `value_mapping`.
value_mapping : Dict[str, str]
A dictionary mapping original string values to replacement string
values. Only tags whose value is a key in this dictionary will be
affected.
target_term_key : str, optional
The key (registered in `TermRegistry`) for the term of the *output*
tag. If None (default), the output tag uses the same term as the
source (`source_term_key`). If provided, the term of the affected
tag is changed to this target term upon replacement.
"""
rule_type: Literal["map_value"] = "map_value"
source_term_key: str
value_mapping: Dict[str, str]
target_term_key: Optional[str] = None
def map_value_transform(
sound_event_annotation: data.SoundEventAnnotation,
source_term: data.Term,
target_term: data.Term,
mapping: Dict[str, str],
) -> data.SoundEventAnnotation:
"""Apply a value mapping transformation to an annotation's tags.
Iterates through the annotation's tags. If a tag matches the `source_term`
and its value is found in the `mapping`, it is replaced by a new tag with
the `target_term` and the mapped value. Other tags are kept unchanged.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to transform.
source_term : data.Term
The term of tags whose values should be mapped.
target_term : data.Term
The term to use for the newly created tags after mapping.
mapping : Dict[str, str]
The dictionary mapping original values to new values.
Returns
-------
data.SoundEventAnnotation
A new annotation object with the transformed tags.
"""
tags = []
for tag in sound_event_annotation.tags:
if tag.term != source_term or tag.value not in mapping:
tags.append(tag)
continue
new_value = mapping[tag.value]
tags.append(data.Tag(term=target_term, value=new_value))
return sound_event_annotation.model_copy(update=dict(tags=tags))
class DeriveTagRule(BaseConfig):
"""Configuration for deriving a new tag from an existing tag's value.
This rule applies a specified function (`derivation_function`) to the
value of tags matching the `source_term_key`. It then adds a new tag
with the `target_term_key` and the derived value.
Attributes
----------
rule_type : Literal["derive_tag"]
Discriminator field identifying this rule type.
source_term_key : str
The key (registered in `TermRegistry`) of the term whose tag values
will be used as input to the derivation function.
derivation_function : str
The name/key identifying the derivation function to use. This can be
a key registered in the `DerivationRegistry` or, if
`import_derivation` is True, a full Python path like
`'my_module.my_submodule.my_function'`.
target_term_key : str, optional
The key (registered in `TermRegistry`) for the term of the new tag
that will be created with the derived value. If None (default), the
derived tag uses the same term as the source (`source_term_key`),
effectively performing an in-place value transformation.
import_derivation : bool, default=False
If True, treat `derivation_function` as a Python import path and
attempt to dynamically import it if not found in the registry.
Requires the function to be accessible in the Python environment.
keep_source : bool, default=True
If True, the original source tag (whose value was used for derivation)
is kept in the annotation's tag list alongside the newly derived tag.
If False, the original source tag is removed.
"""
rule_type: Literal["derive_tag"] = "derive_tag"
source_term_key: str
derivation_function: str
target_term_key: Optional[str] = None
import_derivation: bool = False
keep_source: bool = True
def derivation_tag_transform(
sound_event_annotation: data.SoundEventAnnotation,
source_term: data.Term,
target_term: data.Term,
derivation: Derivation,
keep_source: bool = True,
) -> data.SoundEventAnnotation:
"""Apply a derivation transformation to an annotation's tags.
Iterates through the annotation's tags. For each tag matching the
`source_term`, its value is passed to the `derivation` function.
A new tag is created with the `target_term` and the derived value,
and added to the output tag list. The original source tag is kept
or discarded based on `keep_source`. Other tags are kept unchanged.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to transform.
source_term : data.Term
The term of tags whose values serve as input for the derivation.
target_term : data.Term
The term to use for the newly created derived tags.
derivation : Derivation
The function to apply to the source tag's value.
keep_source : bool, default=True
Whether to keep the original source tag in the output.
Returns
-------
data.SoundEventAnnotation
A new annotation object with the transformed tags (including derived
ones).
"""
tags = []
for tag in sound_event_annotation.tags:
if tag.term != source_term:
tags.append(tag)
continue
if keep_source:
tags.append(tag)
new_value = derivation(tag.value)
tags.append(data.Tag(term=target_term, value=new_value))
return sound_event_annotation.model_copy(update=dict(tags=tags))
class ReplaceRule(BaseConfig):
"""Configuration for exactly replacing one specific tag with another.
This rule looks for an exact match of the `original` tag (both term and
value) and replaces it with the specified `replacement` tag.
Attributes
----------
rule_type : Literal["replace"]
Discriminator field identifying this rule type.
original : TagInfo
The exact tag to search for, defined using its value and term key.
replacement : TagInfo
The tag to substitute in place of the original tag, defined using
its value and term key.
"""
rule_type: Literal["replace"] = "replace"
original: TagInfo
replacement: TagInfo
def replace_tag_transform(
sound_event_annotation: data.SoundEventAnnotation,
source: data.Tag,
target: data.Tag,
) -> data.SoundEventAnnotation:
"""Apply an exact tag replacement transformation.
Iterates through the annotation's tags. If a tag exactly matches the
`source` tag, it is replaced by the `target` tag. Other tags are kept
unchanged.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to transform.
source : data.Tag
The exact tag to find and replace.
target : data.Tag
The tag to replace the source tag with.
Returns
-------
data.SoundEventAnnotation
A new annotation object with the replaced tag (if found).
"""
tags = []
for tag in sound_event_annotation.tags:
if tag == source:
tags.append(target)
else:
tags.append(tag)
return sound_event_annotation.model_copy(update=dict(tags=tags))
class TransformConfig(BaseConfig):
"""Configuration model for defining a sequence of transformation rules.
Attributes
----------
rules : List[Union[ReplaceRule, MapValueRule, DeriveTagRule]]
A list of transformation rules to apply. The rules are applied
sequentially in the order they appear in the list. The output of
one rule becomes the input for the next. The `rule_type` field
discriminates between the different rule models.
"""
rules: List[
Annotated[
Union[ReplaceRule, MapValueRule, DeriveTagRule],
Field(discriminator="rule_type"),
]
] = Field(
default_factory=list,
)
class DerivationRegistry(Mapping[str, Derivation]):
"""A registry for managing named derivation functions.
Derivation functions are callables that take a string value and return
a transformed string value, used by `DeriveTagRule`. This registry
allows functions to be registered with a key and retrieved later.
"""
def __init__(self):
"""Initialize an empty DerivationRegistry."""
self._derivations: Dict[str, Derivation] = {}
def __getitem__(self, key: str) -> Derivation:
"""Retrieve a derivation function by key."""
return self._derivations[key]
def __len__(self) -> int:
"""Return the number of registered derivation functions."""
return len(self._derivations)
def __iter__(self):
"""Return an iterator over the keys of registered functions."""
return iter(self._derivations)
def register(self, key: str, derivation: Derivation) -> None:
"""Register a derivation function with a unique key.
Parameters
----------
key : str
The unique key to associate with the derivation function.
derivation : Derivation
The callable derivation function (takes str, returns str).
Raises
------
KeyError
If a derivation function with the same key is already registered.
"""
if key in self._derivations:
raise KeyError(
f"A derivation with the provided key {key} already exists"
)
self._derivations[key] = derivation
def get_derivation(self, key: str) -> Derivation:
"""Retrieve a derivation function by its registered key.
Parameters
----------
key : str
The key of the derivation function to retrieve.
Returns
-------
Derivation
The requested derivation function.
Raises
------
KeyError
If no derivation function with the specified key is registered.
"""
try:
return self._derivations[key]
except KeyError as err:
raise KeyError(
f"No derivation with key {key} is registered."
) from err
def get_keys(self) -> List[str]:
"""Get a list of all registered derivation function keys.
Returns
-------
List[str]
The keys of all registered functions.
"""
return list(self._derivations.keys())
def get_derivations(self) -> List[Derivation]:
"""Get a list of all registered derivation functions.
Returns
-------
List[Derivation]
The registered derivation function objects.
"""
return list(self._derivations.values())
default_derivation_registry = DerivationRegistry()
"""Global instance of the DerivationRegistry.
Register custom derivation functions here to make them available by key
in `DeriveTagRule` configuration.
"""
def get_derivation(
key: str,
import_derivation: bool = False,
registry: Optional[DerivationRegistry] = None,
):
"""Retrieve a derivation function by key, optionally importing it.
First attempts to find the function in the provided `registry`.
If not found and `import_derivation` is True, attempts to dynamically
import the function using the `key` as a full Python path
(e.g., 'my_module.submodule.my_func').
Parameters
----------
key : str
The key or Python path of the derivation function.
import_derivation : bool, default=False
If True, attempt dynamic import if key is not in the registry.
registry : DerivationRegistry, optional
The registry instance to check first. Defaults to the global
`derivation_registry`.
Returns
-------
Derivation
The requested derivation function.
Raises
------
KeyError
If the key is not found in the registry and either
`import_derivation` is False or the dynamic import fails.
ImportError
If dynamic import fails specifically due to module not found.
AttributeError
If dynamic import fails because the function name isn't in the module.
"""
registry = registry or default_derivation_registry
if not import_derivation or key in registry:
return registry.get_derivation(key)
try:
module_path, func_name = key.rsplit(".", 1)
module = importlib.import_module(module_path)
func = getattr(module, func_name)
return func
except ImportError as err:
raise KeyError(
f"Unable to load derivation '{key}'. Check the path and ensure "
"it points to a valid callable function in an importable module."
) from err
TranformationRule = Annotated[
Union[ReplaceRule, MapValueRule, DeriveTagRule],
Field(discriminator="rule_type"),
]
def build_transform_from_rule(
rule: TranformationRule,
derivation_registry: Optional[DerivationRegistry] = None,
term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation:
"""Build a specific SoundEventTransformation function from a rule config.
Selects the appropriate transformation logic based on the rule's
`rule_type`, fetches necessary terms and derivation functions, and
returns a partially applied function ready to transform an annotation.
Parameters
----------
rule : Union[ReplaceRule, MapValueRule, DeriveTagRule]
The configuration object for a single transformation rule.
registry : DerivationRegistry, optional
The derivation registry to use for `DeriveTagRule`. Defaults to the
global `derivation_registry`.
Returns
-------
SoundEventTransformation
A callable that applies the specified rule to a SoundEventAnnotation.
Raises
------
KeyError
If required term keys or derivation keys are not found.
ValueError
If the rule has an unknown `rule_type`.
ImportError, AttributeError, TypeError
If dynamic import of a derivation function fails.
"""
if rule.rule_type == "replace":
source = get_tag_from_info(
rule.original,
term_registry=term_registry,
)
target = get_tag_from_info(
rule.replacement,
term_registry=term_registry,
)
return partial(replace_tag_transform, source=source, target=target)
if rule.rule_type == "derive_tag":
source_term = get_term_from_key(
rule.source_term_key,
term_registry=term_registry,
)
target_term = (
get_term_from_key(
rule.target_term_key,
term_registry=term_registry,
)
if rule.target_term_key
else source_term
)
derivation = get_derivation(
key=rule.derivation_function,
import_derivation=rule.import_derivation,
registry=derivation_registry,
)
return partial(
derivation_tag_transform,
source_term=source_term,
target_term=target_term,
derivation=derivation,
keep_source=rule.keep_source,
)
if rule.rule_type == "map_value":
source_term = get_term_from_key(
rule.source_term_key,
term_registry=term_registry,
)
target_term = (
get_term_from_key(
rule.target_term_key,
term_registry=term_registry,
)
if rule.target_term_key
else source_term
)
return partial(
map_value_transform,
source_term=source_term,
target_term=target_term,
mapping=rule.value_mapping,
)
# Handle unknown rule type
valid_options = ["replace", "derive_tag", "map_value"]
# Should be caught by Pydantic validation, but good practice
raise ValueError(
f"Invalid transform rule type '{getattr(rule, 'rule_type', 'N/A')}'. "
f"Valid options are: {valid_options}"
)
def build_transformation_from_config(
config: TransformConfig,
derivation_registry: Optional[DerivationRegistry] = None,
term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation:
"""Build a composite transformation function from a TransformConfig.
Creates a sequence of individual transformation functions based on the
rules defined in the configuration. Returns a single function that
applies these transformations sequentially to an annotation.
Parameters
----------
config : TransformConfig
The configuration object containing the list of transformation rules.
derivation_reg : DerivationRegistry, optional
The derivation registry to use when building `DeriveTagRule`
transformations. Defaults to the global `derivation_registry`.
Returns
-------
SoundEventTransformation
A single function that applies all configured transformations in order.
"""
transforms = [
build_transform_from_rule(
rule,
derivation_registry=derivation_registry,
term_registry=term_registry,
)
for rule in config.rules
]
return partial(apply_sequence_of_transforms, transforms=transforms)
def apply_sequence_of_transforms(
sound_event_annotation: data.SoundEventAnnotation,
transforms: list[SoundEventTransformation],
) -> data.SoundEventAnnotation:
for transform in transforms:
sound_event_annotation = transform(sound_event_annotation)
return sound_event_annotation
def load_transformation_config(
path: data.PathLike, field: Optional[str] = None
) -> TransformConfig:
"""Load the transformation configuration from a file.
Parameters
----------
path : data.PathLike
Path to the configuration file (YAML).
field : str, optional
If the transformation configuration is nested under a specific key
in the file, specify the key here. Defaults to None.
Returns
-------
TransformConfig
The loaded and validated transformation configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the TransformConfig schema.
"""
return load_config(path=path, schema=TransformConfig, field=field)
def load_transformation_from_config(
path: data.PathLike,
field: Optional[str] = None,
derivation_registry: Optional[DerivationRegistry] = None,
term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation:
"""Load transformation config from a file and build the final function.
This is a convenience function that combines loading the configuration
and building the final callable transformation function that applies
all rules sequentially.
Parameters
----------
path : data.PathLike
Path to the configuration file (YAML).
field : str, optional
If the transformation configuration is nested under a specific key
in the file, specify the key here. Defaults to None.
Returns
-------
SoundEventTransformation
The final composite transformation function ready to be used.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the TransformConfig schema.
KeyError
If required term keys or derivation keys specified in the config
are not found during the build process.
ImportError, AttributeError, TypeError
If dynamic import of a derivation function specified in the config
fails.
"""
config = load_transformation_config(path=path, field=field)
return build_transformation_from_config(
config,
derivation_registry=derivation_registry,
term_registry=term_registry,
)
def register_derivation(
key: str,
derivation: Derivation,
derivation_registry: Optional[DerivationRegistry] = None,
) -> None:
"""Register a new derivation function in the global registry.
Parameters
----------
key : str
The unique key to associate with the derivation function.
derivation : Derivation
The callable derivation function (takes str, returns str).
derivation_registry : DerivationRegistry, optional
The registry instance to register the derivation function with.
Defaults to the global `derivation_registry`.
Raises
------
KeyError
If a derivation function with the same key is already registered.
"""
derivation_registry = derivation_registry or default_derivation_registry
derivation_registry.register(key, derivation)

View File

@ -33,10 +33,6 @@ from batdetect2.train.losses import (
SizeLossConfig,
build_loss,
)
from batdetect2.train.preprocess import (
generate_train_example,
preprocess_annotations,
)
from batdetect2.train.train import (
build_train_dataset,
build_train_loader,
@ -74,14 +70,12 @@ __all__ = [
"build_trainer",
"build_val_dataset",
"build_val_loader",
"generate_train_example",
"load_full_training_config",
"load_label_config",
"load_train_config",
"mask_frequency",
"mask_time",
"mix_audio",
"preprocess_annotations",
"scale_volume",
"select_subclip",
"train",

View File

@ -44,7 +44,7 @@ AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]
class MixAugmentationConfig(BaseConfig):
"""Configuration for MixUp augmentation (mixing two examples)."""
augmentation_type: Literal["mix_audio"] = "mix_audio"
name: Literal["mix_audio"] = "mix_audio"
probability: float = 0.2
"""Probability of applying this augmentation to an example."""
@ -140,7 +140,7 @@ def combine_clip_annotations(
class EchoAugmentationConfig(BaseConfig):
"""Configuration for adding synthetic echo/reverb."""
augmentation_type: Literal["add_echo"] = "add_echo"
name: Literal["add_echo"] = "add_echo"
probability: float = 0.2
max_delay: float = 0.005
min_weight: float = 0.0
@ -187,7 +187,7 @@ def add_echo(
class VolumeAugmentationConfig(BaseConfig):
"""Configuration for random volume scaling of the spectrogram."""
augmentation_type: Literal["scale_volume"] = "scale_volume"
name: Literal["scale_volume"] = "scale_volume"
probability: float = 0.2
min_scaling: float = 0.0
max_scaling: float = 2.0
@ -214,7 +214,7 @@ def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor:
class WarpAugmentationConfig(BaseConfig):
augmentation_type: Literal["warp"] = "warp"
name: Literal["warp"] = "warp"
probability: float = 0.2
delta: float = 0.04
@ -296,7 +296,7 @@ def warp_spectrogram(
class TimeMaskAugmentationConfig(BaseConfig):
augmentation_type: Literal["mask_time"] = "mask_time"
name: Literal["mask_time"] = "mask_time"
probability: float = 0.2
max_perc: float = 0.05
max_masks: int = 3
@ -353,7 +353,7 @@ def mask_time(
class FrequencyMaskAugmentationConfig(BaseConfig):
augmentation_type: Literal["mask_freq"] = "mask_freq"
name: Literal["mask_freq"] = "mask_freq"
probability: float = 0.2
max_perc: float = 0.10
max_masks: int = 3
@ -414,7 +414,7 @@ AudioAugmentationConfig = Annotated[
MixAugmentationConfig,
EchoAugmentationConfig,
],
Field(discriminator="augmentation_type"),
Field(discriminator="name"),
]
@ -425,7 +425,7 @@ SpectrogramAugmentationConfig = Annotated[
FrequencyMaskAugmentationConfig,
TimeMaskAugmentationConfig,
],
Field(discriminator="augmentation_type"),
Field(discriminator="name"),
]
AugmentationConfig = Annotated[
@ -437,7 +437,7 @@ AugmentationConfig = Annotated[
FrequencyMaskAugmentationConfig,
TimeMaskAugmentationConfig,
],
Field(discriminator="augmentation_type"),
Field(discriminator="name"),
]
"""Type alias for the discriminated union of individual augmentation config."""
@ -485,7 +485,7 @@ def build_augmentation_from_config(
audio_source: Optional[AudioSource] = None,
) -> Optional[Augmentation]:
"""Factory function to build a single augmentation from its config."""
if config.augmentation_type == "mix_audio":
if config.name == "mix_audio":
if audio_source is None:
warnings.warn(
"Mix audio augmentation ('mix_audio') requires an "
@ -500,31 +500,31 @@ def build_augmentation_from_config(
max_weight=config.max_weight,
)
if config.augmentation_type == "add_echo":
if config.name == "add_echo":
return AddEcho(
max_delay=int(config.max_delay * samplerate),
min_weight=config.min_weight,
max_weight=config.max_weight,
)
if config.augmentation_type == "scale_volume":
if config.name == "scale_volume":
return ScaleVolume(
max_scaling=config.max_scaling,
min_scaling=config.min_scaling,
)
if config.augmentation_type == "warp":
if config.name == "warp":
return WarpSpectrogram(
delta=config.delta,
)
if config.augmentation_type == "mask_time":
if config.name == "mask_time":
return MaskTime(
max_perc=config.max_perc,
max_masks=config.max_masks,
)
if config.augmentation_type == "mask_freq":
if config.name == "mask_freq":
return MaskFrequency(
max_perc=config.max_perc,
max_masks=config.max_masks,

View File

@ -6,7 +6,6 @@ from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate import EvaluationConfig
from batdetect2.models import ModelConfig
from batdetect2.targets import TargetConfig
from batdetect2.train.augmentations import (
DEFAULT_AUGMENTATION_CONFIG,
AugmentationsConfig,
@ -75,7 +74,6 @@ class TrainingConfig(BaseConfig):
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
labels: LabelConfig = Field(default_factory=LabelConfig)

View File

@ -1,9 +1,14 @@
from typing import Optional, Tuple
import lightning as L
import torch
from soundevent.data import PathLike
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.models import Model
from batdetect2.models import Model, build_model
from batdetect2.train.config import FullTrainingConfig
from batdetect2.train.losses import build_loss
from batdetect2.typing import ModelOutput, TrainExample
__all__ = [
@ -16,22 +21,28 @@ class TrainingModule(L.LightningModule):
def __init__(
self,
model: Model,
loss: torch.nn.Module,
config: FullTrainingConfig,
learning_rate: float = 0.001,
t_max: int = 100,
model: Optional[Model] = None,
loss: Optional[torch.nn.Module] = None,
):
super().__init__()
self.save_hyperparameters(logger=False)
self.config = config
self.learning_rate = learning_rate
self.t_max = t_max
if loss is None:
loss = build_loss(self.config.train.loss)
if model is None:
model = build_model(self.config)
self.loss = loss
self.model = model
self.save_hyperparameters(logger=False)
def forward(self, spec: torch.Tensor) -> ModelOutput:
return self.model(spec)
def training_step(self, batch: TrainExample):
outputs = self.model.detector(batch.spec)
@ -59,3 +70,10 @@ class TrainingModule(L.LightningModule):
optimizer = Adam(self.parameters(), lr=self.learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
return [optimizer], [scheduler]
def load_model_from_checkpoint(
path: PathLike,
) -> Tuple[Model, FullTrainingConfig]:
module = TrainingModule.load_from_checkpoint(path) # type: ignore
return module.model, module.config

View File

@ -5,10 +5,11 @@ import numpy as np
from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
from loguru import logger
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig
DEFAULT_LOGS_DIR: str = "logs"
DEFAULT_LOGS_DIR: str = "outputs"
class DVCLiveConfig(BaseConfig):
@ -31,7 +32,7 @@ class CSVLoggerConfig(BaseConfig):
class TensorBoardLoggerConfig(BaseConfig):
logger_type: Literal["tensorboard"] = "tensorboard"
save_dir: str = DEFAULT_LOGS_DIR
name: Optional[str] = "default"
name: Optional[str] = "logs"
version: Optional[str] = None
log_graph: bool = False
@ -57,7 +58,10 @@ LoggerConfig = Annotated[
]
def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
def create_dvclive_logger(
config: DVCLiveConfig,
log_dir: Optional[data.PathLike] = None,
) -> Logger:
try:
from dvclive.lightning import DVCLiveLogger # type: ignore
except ImportError as error:
@ -68,7 +72,7 @@ def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
) from error
return DVCLiveLogger(
dir=config.dir,
dir=log_dir if log_dir is not None else config.dir,
run_name=config.run_name,
prefix=config.prefix,
log_model=config.log_model,
@ -76,29 +80,38 @@ def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
)
def create_csv_logger(config: CSVLoggerConfig) -> Logger:
def create_csv_logger(
config: CSVLoggerConfig,
log_dir: Optional[data.PathLike] = None,
) -> Logger:
from lightning.pytorch.loggers import CSVLogger
return CSVLogger(
save_dir=config.save_dir,
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
name=config.name,
version=config.version,
flush_logs_every_n_steps=config.flush_logs_every_n_steps,
)
def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
def create_tensorboard_logger(
config: TensorBoardLoggerConfig,
log_dir: Optional[data.PathLike] = None,
) -> Logger:
from lightning.pytorch.loggers import TensorBoardLogger
return TensorBoardLogger(
save_dir=config.save_dir,
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
name=config.name,
version=config.version,
log_graph=config.log_graph,
)
def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
def create_mlflow_logger(
config: MLFlowLoggerConfig,
log_dir: Optional[data.PathLike] = None,
) -> Logger:
try:
from lightning.pytorch.loggers import MLFlowLogger
except ImportError as error:
@ -111,7 +124,7 @@ def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
return MLFlowLogger(
experiment_name=config.experiment_name,
run_name=config.run_name,
save_dir=config.save_dir,
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
tracking_uri=config.tracking_uri,
tags=config.tags,
log_model=config.log_model,
@ -126,7 +139,10 @@ LOGGER_FACTORY = {
}
def build_logger(config: LoggerConfig) -> Logger:
def build_logger(
config: LoggerConfig,
log_dir: Optional[data.PathLike] = None,
) -> Logger:
"""
Creates a logger instance from a validated Pydantic config object.
"""
@ -141,7 +157,7 @@ def build_logger(config: LoggerConfig) -> Logger:
creation_func = LOGGER_FACTORY[logger_type]
return creation_func(config)
return creation_func(config, log_dir=log_dir)
def get_image_plotter(logger: Logger):

View File

@ -1,243 +0,0 @@
"""Preprocesses datasets for BatDetect2 model training."""
import os
from pathlib import Path
from typing import Callable, List, Optional, Sequence, TypedDict
import numpy as np
import torch
import torch.utils.data
from loguru import logger
from pydantic import Field
from soundevent import data
from tqdm import tqdm
from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.datasets import Dataset
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.train.labels import LabelConfig, build_clip_labeler
from batdetect2.typing import ClipLabeller, PreprocessorProtocol
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.typing.train import PreprocessedExample
__all__ = [
"preprocess_annotations",
"generate_train_example",
"preprocess_dataset",
"TrainPreprocessConfig",
"load_train_preprocessing_config",
"save_preprocessed_example",
"load_preprocessed_example",
]
FilenameFn = Callable[[data.ClipAnnotation], str]
"""Type alias for a function that generates an output filename."""
class TrainPreprocessConfig(BaseConfig):
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
targets: TargetConfig = Field(default_factory=TargetConfig)
labels: LabelConfig = Field(default_factory=LabelConfig)
def load_train_preprocessing_config(
path: data.PathLike,
field: Optional[str] = None,
) -> TrainPreprocessConfig:
return load_config(path=path, schema=TrainPreprocessConfig, field=field)
def preprocess_dataset(
dataset: Dataset,
config: TrainPreprocessConfig,
output: Path,
max_workers: Optional[int] = None,
) -> None:
targets = build_targets(config=config.targets)
preprocessor = build_preprocessor(config=config.preprocess)
labeller = build_clip_labeler(
targets,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
config=config.labels,
)
audio_loader = build_audio_loader(config=config.preprocess.audio)
if not output.exists():
logger.debug("Creating directory {directory}", directory=output)
output.mkdir(parents=True)
preprocess_annotations(
dataset,
output_dir=output,
audio_loader=audio_loader,
preprocessor=preprocessor,
labeller=labeller,
max_workers=max_workers,
)
class Example(TypedDict):
audio: torch.Tensor
spectrogram: torch.Tensor
detection_heatmap: torch.Tensor
class_heatmap: torch.Tensor
size_heatmap: torch.Tensor
def generate_train_example(
clip_annotation: data.ClipAnnotation,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
labeller: ClipLabeller,
) -> PreprocessedExample:
"""Generate a complete training example for one annotation."""
wave = torch.tensor(
audio_loader.load_clip(clip_annotation.clip)
).unsqueeze(0)
spectrogram = preprocessor(wave.unsqueeze(0)).squeeze(0)
heatmaps = labeller(clip_annotation, spectrogram)
return PreprocessedExample(
audio=wave,
spectrogram=spectrogram,
detection_heatmap=heatmaps.detection,
class_heatmap=heatmaps.classes,
size_heatmap=heatmaps.size,
)
class PreprocessingDataset(torch.utils.data.Dataset):
def __init__(
self,
clips: Dataset,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
labeller: ClipLabeller,
filename_fn: FilenameFn,
output_dir: Path,
force: bool = False,
):
self.clips = clips
self.audio_loader = audio_loader
self.preprocessor = preprocessor
self.labeller = labeller
self.filename_fn = filename_fn
self.output_dir = output_dir
self.force = force
def __getitem__(self, idx) -> int:
clip_annotation = self.clips[idx]
filename = self.filename_fn(clip_annotation)
path = self.output_dir / filename
if path.exists() and not self.force:
return idx
if not path.parent.exists():
path.parent.mkdir()
example = generate_train_example(
clip_annotation,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
labeller=self.labeller,
)
save_preprocessed_example(example, clip_annotation, path)
return idx
def __len__(self) -> int:
return len(self.clips)
def save_preprocessed_example(
example: PreprocessedExample,
clip_annotation: data.ClipAnnotation,
path: data.PathLike,
) -> None:
np.savez_compressed(
path,
audio=example.audio.numpy(),
spectrogram=example.spectrogram.numpy(),
detection_heatmap=example.detection_heatmap.numpy(),
class_heatmap=example.class_heatmap.numpy(),
size_heatmap=example.size_heatmap.numpy(),
clip_annotation=clip_annotation,
)
def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
item = np.load(path, mmap_mode="r+")
return PreprocessedExample(
audio=torch.tensor(item["audio"]),
spectrogram=torch.tensor(item["spectrogram"]),
size_heatmap=torch.tensor(item["size_heatmap"]),
detection_heatmap=torch.tensor(item["detection_heatmap"]),
class_heatmap=torch.tensor(item["class_heatmap"]),
)
def list_preprocessed_files(
directory: data.PathLike, extension: str = ".npz"
) -> List[Path]:
return list(Path(directory).glob(f"*{extension}"))
def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
"""Generate a default output filename based on the annotation UUID."""
return f"{clip_annotation.uuid}"
def preprocess_annotations(
clip_annotations: Sequence[data.ClipAnnotation],
output_dir: data.PathLike,
preprocessor: PreprocessorProtocol,
audio_loader: AudioLoader,
labeller: ClipLabeller,
filename_fn: FilenameFn = _get_filename,
max_workers: Optional[int] = None,
) -> None:
"""Preprocess a sequence of ClipAnnotations and save results to disk."""
output_dir = Path(output_dir)
if not output_dir.is_dir():
logger.info(
"Creating output directory: {output_dir}", output_dir=output_dir
)
output_dir.mkdir(parents=True)
logger.info(
"Starting preprocessing of {num_annotations} annotations with {max_workers} workers.",
num_annotations=len(clip_annotations),
max_workers=max_workers or "all available",
)
if max_workers is None:
max_workers = os.cpu_count() or 0
dataset = PreprocessingDataset(
clips=list(clip_annotations),
audio_loader=audio_loader,
preprocessor=preprocessor,
labeller=labeller,
output_dir=Path(output_dir),
filename_fn=filename_fn,
)
loader = torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=max_workers,
prefetch_factor=16,
)
for _ in tqdm(loader, total=len(dataset)):
pass

View File

@ -14,9 +14,9 @@ from batdetect2.evaluate.metrics import (
ClassificationMeanAveragePrecision,
DetectionAveragePrecision,
)
from batdetect2.models import Model, build_model
from batdetect2.plotting.clips import AudioLoader, build_audio_loader
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.augmentations import (
RandomAudioSource,
build_augmentations,
@ -28,7 +28,6 @@ from batdetect2.train.dataset import TrainingDataset, ValidationDataset
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import build_logger
from batdetect2.train.losses import build_loss
from batdetect2.typing import (
PreprocessorProtocol,
TargetProtocol,
@ -54,19 +53,21 @@ def train(
model_path: Optional[data.PathLike] = None,
train_workers: Optional[int] = None,
val_workers: Optional[int] = None,
checkpoint_dir: Optional[data.PathLike] = None,
log_dir: Optional[data.PathLike] = None,
):
config = config or FullTrainingConfig()
model = build_model(config=config)
targets = build_targets(config.targets)
trainer = build_trainer(config, targets=model.targets)
preprocessor = build_preprocessor(config.preprocess)
audio_loader = build_audio_loader(config=config.preprocess.audio)
labeller = build_clip_labeler(
model.targets,
min_freq=model.preprocessor.min_freq,
max_freq=model.preprocessor.max_freq,
targets,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
config=config.train.labels,
)
@ -74,7 +75,7 @@ def train(
train_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=build_preprocessor(config.preprocess),
preprocessor=preprocessor,
config=config.train,
num_workers=train_workers,
)
@ -84,7 +85,7 @@ def train(
val_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=build_preprocessor(config.preprocess),
preprocessor=preprocessor,
config=config.train,
num_workers=val_workers,
)
@ -97,9 +98,15 @@ def train(
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
else:
module = build_training_module(
model,
config,
batches_per_epoch=len(train_dataloader),
t_max=config.train.t_max * len(train_dataloader),
)
trainer = build_trainer(
config,
targets=targets,
checkpoint_dir=checkpoint_dir,
log_dir=log_dir,
)
logger.info("Starting main training loop...")
@ -112,16 +119,14 @@ def train(
def build_training_module(
model: Model,
config: FullTrainingConfig,
batches_per_epoch: int,
config: Optional[FullTrainingConfig] = None,
t_max: int = 200,
) -> TrainingModule:
loss = build_loss(config=config.train.loss)
config = config or FullTrainingConfig()
return TrainingModule(
model=model,
loss=loss,
config=config,
learning_rate=config.train.learning_rate,
t_max=config.train.t_max * batches_per_epoch,
t_max=t_max,
)
@ -129,10 +134,14 @@ def build_trainer_callbacks(
targets: TargetProtocol,
preprocessor: PreprocessorProtocol,
config: EvaluationConfig,
checkpoint_dir: Optional[data.PathLike] = None,
) -> List[Callback]:
if checkpoint_dir is None:
checkpoint_dir = "outputs/checkpoints"
return [
ModelCheckpoint(
dirpath="outputs/checkpoints",
dirpath=str(checkpoint_dir),
save_top_k=1,
monitor="total_loss/val",
),
@ -153,15 +162,22 @@ def build_trainer_callbacks(
def build_trainer(
conf: FullTrainingConfig,
targets: TargetProtocol,
checkpoint_dir: Optional[data.PathLike] = None,
log_dir: Optional[data.PathLike] = None,
) -> Trainer:
trainer_conf = conf.train.trainer
logger.opt(lazy=True).debug(
"Building trainer with config: \n{config}",
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
)
train_logger = build_logger(conf.train.logger)
train_logger = build_logger(conf.train.logger, log_dir=log_dir)
train_logger.log_hyperparams(conf.model_dump(mode="json"))
train_logger.log_hyperparams(
conf.model_dump(
mode="json",
exclude_none=True,
)
)
return Trainer(
**trainer_conf.model_dump(exclude_none=True),
@ -170,6 +186,7 @@ def build_trainer(
targets,
config=conf.evaluation,
preprocessor=build_preprocessor(conf.preprocess),
checkpoint_dir=checkpoint_dir,
),
)

View File

@ -67,8 +67,7 @@ class TargetProtocol(Protocol):
This protocol outlines the standard attributes and methods for an object
that encapsulates the complete, configured process for handling sound event
annotations (both tags and geometry). It defines how to:
- Filter relevant annotations.
- Transform annotation tags.
- Select relevant annotations.
- Encode an annotation into a specific target class name.
- Decode a class name back into representative tags.
- Extract a target reference position from an annotation's geometry (ROI).
@ -121,26 +120,6 @@ class TargetProtocol(Protocol):
"""
...
def transform(
self,
sound_event: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
"""Apply tag transformations to an annotation.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation whose tags should be transformed.
Returns
-------
data.SoundEventAnnotation
A new annotation object with the transformed tags. Implementations
should return the original annotation object if no transformations
were configured.
"""
...
def encode_class(
self,
sound_event: data.SoundEventAnnotation,
@ -248,3 +227,70 @@ class TargetProtocol(Protocol):
if reconstruction fails based on the configured position type.
"""
...
class ROITargetMapper(Protocol):
"""Protocol defining the interface for ROI-to-target mapping.
Specifies the `encode` and `decode` methods required for converting a
`soundevent.data.SoundEvent` into a target representation (a reference
position and a size vector) and for recovering an approximate ROI from that
representation.
Attributes
----------
dimension_names : List[str]
A list containing the names of the dimensions in the `Size` array
returned by `encode` and expected by `decode`.
"""
dimension_names: List[str]
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
"""Encode a SoundEvent's geometry into a position and size.
Parameters
----------
sound_event : data.SoundEvent
The input sound event, which must have a geometry attribute.
Returns
-------
Tuple[Position, Size]
A tuple containing:
- The reference position as (time, frequency) coordinates.
- A NumPy array with the calculated size dimensions.
Raises
------
ValueError
If the sound event does not have a geometry.
"""
...
def decode(self, position: Position, size: Size) -> data.Geometry:
"""Decode a position and size back into a geometric ROI.
Performs the inverse mapping: takes a reference position and size
dimensions and reconstructs a geometric representation.
Parameters
----------
position : Position
The reference position (time, frequency).
size : Size
NumPy array containing the size dimensions, matching the order
and meaning specified by `dimension_names`.
Returns
-------
soundevent.data.Geometry
The reconstructed geometry, typically a `BoundingBox`.
Raises
------
ValueError
If the `size` array has an unexpected shape or if reconstruction
fails.
"""
...

View File

@ -15,13 +15,10 @@ from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.targets import (
TargetConfig,
TermRegistry,
build_targets,
call_type,
)
from batdetect2.targets.classes import ClassesConfig, TargetClass
from batdetect2.targets.filtering import FilterConfig, FilterRule
from batdetect2.targets.terms import TagInfo
from batdetect2.targets.classes import TargetClassConfig
from batdetect2.train.clips import build_clipper
from batdetect2.train.labels import build_clip_labeler
from batdetect2.typing import (
@ -355,18 +352,6 @@ def create_annotation_project():
return factory
@pytest.fixture
def sample_term_registry() -> TermRegistry:
"""Fixture for a sample TermRegistry."""
registry = TermRegistry()
registry.add_custom_term("class")
registry.add_custom_term("order")
registry.add_custom_term("species")
registry.add_custom_term("call_type")
registry.add_custom_term("quality")
return registry
@pytest.fixture
def sample_preprocessor() -> PreprocessorProtocol:
return build_preprocessor()
@ -378,56 +363,45 @@ def sample_audio_loader() -> AudioLoader:
@pytest.fixture
def bat_tag() -> TagInfo:
return TagInfo(key="class", value="bat")
def bat_tag() -> data.Tag:
return data.Tag(key="class", value="bat")
@pytest.fixture
def noise_tag() -> TagInfo:
return TagInfo(key="class", value="noise")
def noise_tag() -> data.Tag:
return data.Tag(key="class", value="noise")
@pytest.fixture
def myomyo_tag() -> TagInfo:
return TagInfo(key="species", value="Myotis myotis")
def myomyo_tag() -> data.Tag:
return data.Tag(key="species", value="Myotis myotis")
@pytest.fixture
def pippip_tag() -> TagInfo:
return TagInfo(key="species", value="Pipistrellus pipistrellus")
def pippip_tag() -> data.Tag:
return data.Tag(key="species", value="Pipistrellus pipistrellus")
@pytest.fixture
def sample_target_config(
sample_term_registry: TermRegistry,
bat_tag: TagInfo,
noise_tag: TagInfo,
myomyo_tag: TagInfo,
pippip_tag: TagInfo,
bat_tag: data.Tag,
myomyo_tag: data.Tag,
pippip_tag: data.Tag,
) -> TargetConfig:
return TargetConfig(
filtering=FilterConfig(
rules=[FilterRule(match_type="exclude", tags=[noise_tag])]
),
classes=ClassesConfig(
classes=[
TargetClass(name="pippip", tags=[pippip_tag]),
TargetClass(name="myomyo", tags=[myomyo_tag]),
detection_target=TargetClassConfig(name="bat", tags=[bat_tag]),
classification_targets=[
TargetClassConfig(name="pippip", tags=[pippip_tag]),
TargetClassConfig(name="myomyo", tags=[myomyo_tag]),
],
generic_class=[bat_tag],
),
)
@pytest.fixture
def sample_targets(
sample_target_config: TargetConfig,
sample_term_registry: TermRegistry,
) -> TargetProtocol:
return build_targets(
sample_target_config,
term_registry=sample_term_registry,
)
return build_targets(sample_target_config)
@pytest.fixture
@ -443,10 +417,8 @@ def sample_labeller(
@pytest.fixture
def sample_clipper(
sample_preprocessor: PreprocessorProtocol,
) -> ClipperProtocol:
return build_clipper(preprocessor=sample_preprocessor)
def sample_clipper() -> ClipperProtocol:
return build_clipper()
@pytest.fixture

View File

@ -0,0 +1,516 @@
import textwrap
import pytest
import yaml
from pydantic import TypeAdapter
from soundevent import data
from batdetect2.data.conditions import (
SoundEventConditionConfig,
build_sound_event_condition,
)
def build_condition_from_str(content):
content = textwrap.dedent(content)
content = yaml.safe_load(content)
config = TypeAdapter(SoundEventConditionConfig).validate_python(content)
return build_sound_event_condition(config)
def test_has_tag(sound_event: data.SoundEvent):
condition = build_condition_from_str("""
name: has_tag
tag:
key: species
value: Myotis myotis
""")
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
assert condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
)
assert not condition(sound_event_annotation)
def test_has_all_tags(sound_event: data.SoundEvent):
condition = build_condition_from_str("""
name: has_all_tags
tags:
- key: species
value: Myotis myotis
- key: event
value: Echolocation
""")
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
assert not condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
data.Tag(key="event", value="Echolocation"), # type: ignore
],
)
assert not condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Myotis myotis"), # type: ignore
data.Tag(key="event", value="Echolocation"), # type: ignore
],
)
assert condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Myotis myotis"), # type: ignore
data.Tag(key="event", value="Echolocation"), # type: ignore
data.Tag(key="sex", value="Female"), # type: ignore
],
)
assert condition(sound_event_annotation)
def test_has_any_tags(sound_event: data.SoundEvent):
condition = build_condition_from_str("""
name: has_any_tag
tags:
- key: species
value: Myotis myotis
- key: event
value: Echolocation
""")
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
assert condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
data.Tag(key="event", value="Echolocation"), # type: ignore
],
)
assert condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Myotis myotis"), # type: ignore
data.Tag(key="event", value="Echolocation"), # type: ignore
],
)
assert condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
data.Tag(key="event", value="Social"), # type: ignore
],
)
assert not condition(sound_event_annotation)
def test_not(sound_event: data.SoundEvent):
condition = build_condition_from_str("""
name: not
condition:
name: has_tag
tag:
key: species
value: Myotis myotis
""")
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
assert not condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
)
assert condition(sound_event_annotation)
sound_event_annotation = data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Myotis myotis"), # type: ignore
data.Tag(key="event", value="Echolocation"), # type: ignore
],
)
assert not condition(sound_event_annotation)
def test_duration(recording: data.Recording):
se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
recording=recording, geometry=data.TimeInterval(coordinates=[0, 1])
),
)
se2 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
recording=recording, geometry=data.TimeInterval(coordinates=[0, 2])
),
)
se3 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
recording=recording, geometry=data.TimeInterval(coordinates=[0, 3])
),
)
condition = build_condition_from_str("""
name: duration
operator: lt
seconds: 2
""")
assert condition(se1)
assert not condition(se2)
assert not condition(se3)
condition = build_condition_from_str("""
name: duration
operator: lte
seconds: 2
""")
assert condition(se1)
assert condition(se2)
assert not condition(se3)
condition = build_condition_from_str("""
name: duration
operator: gt
seconds: 2
""")
assert not condition(se1)
assert not condition(se2)
assert condition(se3)
condition = build_condition_from_str("""
name: duration
operator: gte
seconds: 2
""")
assert not condition(se1)
assert condition(se2)
assert condition(se3)
condition = build_condition_from_str("""
name: duration
operator: eq
seconds: 2
""")
assert not condition(se1)
assert condition(se2)
assert not condition(se3)
def test_frequency(recording: data.Recording):
se12 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
recording=recording,
geometry=data.BoundingBox(coordinates=[0, 100, 1, 200]),
),
)
se13 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
recording=recording,
geometry=data.BoundingBox(coordinates=[0, 100, 2, 300]),
),
)
se14 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
recording=recording,
geometry=data.BoundingBox(coordinates=[0, 100, 3, 400]),
),
)
se24 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
recording=recording,
geometry=data.BoundingBox(coordinates=[0, 200, 3, 400]),
),
)
se34 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
recording=recording,
geometry=data.BoundingBox(coordinates=[0, 300, 3, 400]),
),
)
condition = build_condition_from_str("""
name: frequency
boundary: high
operator: lt
hertz: 300
""")
assert condition(se12)
assert not condition(se13)
assert not condition(se14)
condition = build_condition_from_str("""
name: frequency
boundary: high
operator: lte
hertz: 300
""")
assert condition(se12)
assert condition(se13)
assert not condition(se14)
condition = build_condition_from_str("""
name: frequency
boundary: high
operator: gt
hertz: 300
""")
assert not condition(se12)
assert not condition(se13)
assert condition(se14)
condition = build_condition_from_str("""
name: frequency
boundary: high
operator: gte
hertz: 300
""")
assert not condition(se12)
assert condition(se13)
assert condition(se14)
condition = build_condition_from_str("""
name: frequency
boundary: high
operator: eq
hertz: 300
""")
assert not condition(se12)
assert condition(se13)
assert not condition(se14)
# LOW
condition = build_condition_from_str("""
name: frequency
boundary: low
operator: lt
hertz: 200
""")
assert condition(se14)
assert not condition(se24)
assert not condition(se34)
condition = build_condition_from_str("""
name: frequency
boundary: low
operator: lte
hertz: 200
""")
assert condition(se14)
assert condition(se24)
assert not condition(se34)
condition = build_condition_from_str("""
name: frequency
boundary: low
operator: gt
hertz: 200
""")
assert not condition(se14)
assert not condition(se24)
assert condition(se34)
condition = build_condition_from_str("""
name: frequency
boundary: low
operator: gte
hertz: 200
""")
assert not condition(se14)
assert condition(se24)
assert condition(se34)
condition = build_condition_from_str("""
name: frequency
boundary: low
operator: eq
hertz: 200
""")
assert not condition(se14)
assert condition(se24)
assert not condition(se34)
def test_frequency_is_false_for_temporal_geometries(recording: data.Recording):
condition = build_condition_from_str("""
name: frequency
boundary: low
operator: eq
hertz: 200
""")
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.TimeInterval(coordinates=[0, 3]),
recording=recording,
)
)
assert not condition(se)
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.TimeStamp(coordinates=3),
recording=recording,
)
)
assert not condition(se)
def test_has_tags_fails_if_empty():
with pytest.raises(ValueError):
build_condition_from_str("""
name: has_tags
tags: []
""")
def test_frequency_is_false_if_no_geometry(recording: data.Recording):
condition = build_condition_from_str("""
name: frequency
boundary: low
operator: eq
hertz: 200
""")
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(geometry=None, recording=recording)
)
assert not condition(se)
def test_duration_is_false_if_no_geometry(recording: data.Recording):
condition = build_condition_from_str("""
name: duration
operator: eq
seconds: 1
""")
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(geometry=None, recording=recording)
)
assert not condition(se)
def test_all_of(recording: data.Recording):
condition = build_condition_from_str("""
name: all_of
conditions:
- name: has_tag
tag:
key: species
value: Myotis myotis
- name: duration
operator: lt
seconds: 1
""")
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.TimeInterval(coordinates=[0, 0.5]),
recording=recording,
),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
assert condition(se)
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.TimeInterval(coordinates=[0, 2]),
recording=recording,
),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
assert not condition(se)
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.TimeInterval(coordinates=[0, 0.5]),
recording=recording,
),
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
)
assert not condition(se)
def test_any_of(recording: data.Recording):
condition = build_condition_from_str("""
name: any_of
conditions:
- name: has_tag
tag:
key: species
value: Myotis myotis
- name: duration
operator: lt
seconds: 1
""")
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.TimeInterval(coordinates=[0, 2]),
recording=recording,
),
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
)
assert not condition(se)
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.TimeInterval(coordinates=[0, 0.5]),
recording=recording,
),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
assert condition(se)
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.TimeInterval(coordinates=[0, 2]),
recording=recording,
),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
assert condition(se)
se = data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.TimeInterval(coordinates=[0, 0.5]),
recording=recording,
),
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
)
assert condition(se)

View File

@ -3,26 +3,15 @@ from typing import Callable
from uuid import uuid4
import pytest
from pydantic import ValidationError
from soundevent import data
from soundevent.terms import get_term
from batdetect2.targets.classes import (
DEFAULT_SPECIES_LIST,
ClassesConfig,
TargetClass,
_get_default_class_name,
_get_default_classes,
build_generic_class_tags,
TargetClassConfig,
build_sound_event_decoder,
build_sound_event_encoder,
get_class_names_from_config,
is_target_class,
load_classes_config,
load_decoder_from_config,
load_encoder_from_config,
)
from batdetect2.targets.terms import TagInfo
@pytest.fixture
@ -33,8 +22,8 @@ def sample_annotation(
return data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
data.Tag(key="quality", value="Good"), # type: ignore
data.Tag(key="species", value="Pipistrellus pipistrellus"),
data.Tag(key="quality", value="Good"),
],
)
@ -51,291 +40,71 @@ def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
return factory
def test_target_class_creation():
target_class = TargetClass(
name="pippip",
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
)
assert target_class.name == "pippip"
assert target_class.tags[0].key == "species"
assert target_class.tags[0].value == "Pipistrellus pipistrellus"
assert target_class.match_type == "all"
def test_classes_config_creation():
target_class = TargetClass(
name="pippip",
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
)
config = ClassesConfig(classes=[target_class])
assert len(config.classes) == 1
assert config.classes[0].name == "pippip"
def test_classes_config_unique_names():
target_class1 = TargetClass(
name="pippip",
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
)
target_class2 = TargetClass(
name="myodau",
tags=[TagInfo(key="species", value="Myotis daubentonii")],
)
ClassesConfig(classes=[target_class1, target_class2]) # No error
def test_classes_config_non_unique_names():
target_class1 = TargetClass(
name="pippip",
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
)
target_class2 = TargetClass(
name="pippip",
tags=[TagInfo(key="species", value="Myotis daubentonii")],
)
with pytest.raises(ValidationError):
ClassesConfig(classes=[target_class1, target_class2])
def test_load_classes_config_valid(create_temp_yaml: Callable[[str], Path]):
yaml_content = """
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
"""
temp_yaml_path = create_temp_yaml(yaml_content)
config = load_classes_config(temp_yaml_path)
assert len(config.classes) == 1
assert config.classes[0].name == "pippip"
def test_load_classes_config_invalid(create_temp_yaml: Callable[[str], Path]):
yaml_content = """
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
- name: pippip
tags:
- key: species
value: Myotis daubentonii
"""
temp_yaml_path = create_temp_yaml(yaml_content)
with pytest.raises(ValidationError):
load_classes_config(temp_yaml_path)
def test_is_target_class_match_all(
sample_annotation: data.SoundEventAnnotation,
):
tags = {
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
data.Tag(key="quality", value="Good"), # type: ignore
}
assert is_target_class(sample_annotation, tags, match_all=True) is True
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
assert is_target_class(sample_annotation, tags, match_all=True) is True
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
assert is_target_class(sample_annotation, tags, match_all=True) is False
def test_is_target_class_match_any(
sample_annotation: data.SoundEventAnnotation,
):
tags = {
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
data.Tag(key="quality", value="Good"), # type: ignore
}
assert is_target_class(sample_annotation, tags, match_all=False) is True
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
assert is_target_class(sample_annotation, tags, match_all=False) is True
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
assert is_target_class(sample_annotation, tags, match_all=False) is False
def test_get_class_names_from_config():
target_class1 = TargetClass(
target_class1 = TargetClassConfig(
name="pippip",
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
)
target_class2 = TargetClass(
target_class2 = TargetClassConfig(
name="myodau",
tags=[TagInfo(key="species", value="Myotis daubentonii")],
tags=[data.Tag(key="species", value="Myotis daubentonii")],
)
config = ClassesConfig(classes=[target_class1, target_class2])
names = get_class_names_from_config(config)
names = get_class_names_from_config([target_class1, target_class2])
assert names == ["pippip", "myodau"]
def test_build_encoder_from_config(
sample_annotation: data.SoundEventAnnotation,
):
config = ClassesConfig(
classes=[
TargetClass(
classes = [
TargetClassConfig(
name="pippip",
tags=[
TagInfo(key="species", value="Pipistrellus pipistrellus")
],
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
)
]
)
encoder = build_sound_event_encoder(config)
encoder = build_sound_event_encoder(classes)
result = encoder(sample_annotation)
assert result == "pippip"
config = ClassesConfig(classes=[])
encoder = build_sound_event_encoder(config)
classes = []
encoder = build_sound_event_encoder(classes)
result = encoder(sample_annotation)
assert result is None
def test_load_encoder_from_config_valid(
sample_annotation: data.SoundEventAnnotation,
create_temp_yaml: Callable[[str], Path],
):
yaml_content = """
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
"""
temp_yaml_path = create_temp_yaml(yaml_content)
encoder = load_encoder_from_config(temp_yaml_path)
# We cannot directly compare the function, so we test it.
result = encoder(sample_annotation) # type: ignore
assert result == "pippip"
def test_load_encoder_from_config_invalid(
create_temp_yaml: Callable[[str], Path],
):
yaml_content = """
classes:
- name: pippip
tags:
- key: invalid_key
value: Pipistrellus pipistrellus
"""
temp_yaml_path = create_temp_yaml(yaml_content)
with pytest.raises(KeyError):
load_encoder_from_config(temp_yaml_path)
def test_get_default_class_name():
assert _get_default_class_name("Myotis daubentonii") == "myodau"
def test_get_default_classes():
default_classes = _get_default_classes()
assert len(default_classes) == len(DEFAULT_SPECIES_LIST)
first_class = default_classes[0]
assert isinstance(first_class, TargetClass)
assert first_class.name == _get_default_class_name(DEFAULT_SPECIES_LIST[0])
assert first_class.tags[0].key == "class"
assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0]
def test_build_decoder_from_config():
config = ClassesConfig(
classes=[
TargetClass(
classes = [
TargetClassConfig(
name="pippip",
tags=[
TagInfo(key="species", value="Pipistrellus pipistrellus")
],
output_tags=[TagInfo(key="call_type", value="Echolocation")],
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
assign_tags=[data.Tag(key="call_type", value="Echolocation")],
)
],
generic_class=[TagInfo(key="order", value="Chiroptera")],
)
decoder = build_sound_event_decoder(config)
]
decoder = build_sound_event_decoder(classes)
tags = decoder("pippip")
assert len(tags) == 1
assert tags[0].term == get_term("event")
assert tags[0].value == "Echolocation"
# Test when output_tags is None, should fall back to tags
config = ClassesConfig(
classes=[
TargetClass(
classes = [
TargetClassConfig(
name="pippip",
tags=[
TagInfo(key="species", value="Pipistrellus pipistrellus")
],
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
)
],
generic_class=[TagInfo(key="order", value="Chiroptera")],
)
decoder = build_sound_event_decoder(config)
]
decoder = build_sound_event_decoder(classes)
tags = decoder("pippip")
assert len(tags) == 1
assert tags[0].term == get_term("species")
assert tags[0].value == "Pipistrellus pipistrellus"
# Test raise_on_unmapped=True
decoder = build_sound_event_decoder(config, raise_on_unmapped=True)
decoder = build_sound_event_decoder(classes, raise_on_unmapped=True)
with pytest.raises(ValueError):
decoder("unknown_class")
# Test raise_on_unmapped=False
decoder = build_sound_event_decoder(config, raise_on_unmapped=False)
decoder = build_sound_event_decoder(classes, raise_on_unmapped=False)
tags = decoder("unknown_class")
assert len(tags) == 0
def test_load_decoder_from_config_valid(
create_temp_yaml: Callable[[str], Path],
):
yaml_content = """
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
output_tags:
- key: call_type
value: Echolocation
generic_class:
- key: order
value: Chiroptera
"""
temp_yaml_path = create_temp_yaml(yaml_content)
decoder = load_decoder_from_config(
temp_yaml_path,
)
tags = decoder("pippip")
assert len(tags) == 1
assert tags[0].term == get_term("call_type")
assert tags[0].value == "Echolocation"
def test_build_generic_class_tags_from_config():
config = ClassesConfig(
classes=[
TargetClass(
name="pippip",
tags=[
TagInfo(key="species", value="Pipistrellus pipistrellus")
],
)
],
generic_class=[
TagInfo(key="order", value="Chiroptera"),
TagInfo(key="call_type", value="Echolocation"),
],
)
generic_tags = build_generic_class_tags(config)
assert len(generic_tags) == 2
assert generic_tags[0].term == get_term("order")
assert generic_tags[0].value == "Chiroptera"
assert generic_tags[1].term == get_term("call_type")
assert generic_tags[1].value == "Echolocation"

View File

@ -1,210 +0,0 @@
from pathlib import Path
from typing import Callable, List, Set
import pytest
from soundevent import data
from batdetect2.targets import build_targets
from batdetect2.targets.filtering import (
FilterConfig,
FilterRule,
build_filter_from_rule,
build_sound_event_filter,
contains_tags,
does_not_have_tags,
equal_tags,
has_any_tag,
load_filter_config,
load_filter_from_config,
)
from batdetect2.targets.terms import TagInfo, generic_class
@pytest.fixture
def create_annotation(
sound_event: data.SoundEvent,
) -> Callable[[List[str]], data.SoundEventAnnotation]:
"""Helper function to create a SoundEventAnnotation with given tags."""
def factory(tags: List[str]) -> data.SoundEventAnnotation:
return data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(
term=generic_class,
value=tag,
)
for tag in tags
],
)
return factory
def create_tag_set(tags: List[str]) -> Set[data.Tag]:
"""Helper function to create a set of data.Tag objects from a list of strings."""
return {
data.Tag(
term=generic_class,
value=tag,
)
for tag in tags
}
def test_has_any_tag(create_annotation):
annotation = create_annotation(["tag1", "tag2"])
tags = create_tag_set(["tag1", "tag3"])
assert has_any_tag(annotation, tags) is True
annotation = create_annotation(["tag2", "tag4"])
tags = create_tag_set(["tag1", "tag3"])
assert has_any_tag(annotation, tags) is False
def test_contains_tags(create_annotation):
annotation = create_annotation(["tag1", "tag2", "tag3"])
tags = create_tag_set(["tag1", "tag2"])
assert contains_tags(annotation, tags) is True
annotation = create_annotation(["tag1", "tag2"])
tags = create_tag_set(["tag1", "tag2", "tag3"])
assert contains_tags(annotation, tags) is False
def test_does_not_have_tags(create_annotation):
annotation = create_annotation(["tag1", "tag2"])
tags = create_tag_set(["tag3", "tag4"])
assert does_not_have_tags(annotation, tags) is True
annotation = create_annotation(["tag1", "tag2"])
tags = create_tag_set(["tag1", "tag3"])
assert does_not_have_tags(annotation, tags) is False
def test_equal_tags(create_annotation):
annotation = create_annotation(["tag1", "tag2"])
tags = create_tag_set(["tag1", "tag2"])
assert equal_tags(annotation, tags) is True
annotation = create_annotation(["tag1", "tag2", "tag3"])
tags = create_tag_set(["tag1", "tag2"])
assert equal_tags(annotation, tags) is False
def test_build_filter_from_rule():
rule_any = FilterRule(match_type="any", tags=[TagInfo(value="tag1")])
build_filter_from_rule(rule_any)
rule_all = FilterRule(match_type="all", tags=[TagInfo(value="tag1")])
build_filter_from_rule(rule_all)
rule_exclude = FilterRule(
match_type="exclude", tags=[TagInfo(value="tag1")]
)
build_filter_from_rule(rule_exclude)
rule_equal = FilterRule(match_type="equal", tags=[TagInfo(value="tag1")])
build_filter_from_rule(rule_equal)
with pytest.raises(ValueError):
FilterRule(match_type="invalid", tags=[TagInfo(value="tag1")]) # type: ignore
build_filter_from_rule(
FilterRule(match_type="invalid", tags=[TagInfo(value="tag1")]) # type: ignore
)
def test_build_filter_from_config(create_annotation):
config = FilterConfig(
rules=[
FilterRule(match_type="any", tags=[TagInfo(value="tag1")]),
FilterRule(match_type="any", tags=[TagInfo(value="tag2")]),
]
)
filter_from_config = build_sound_event_filter(config)
annotation_pass = create_annotation(["tag1", "tag2"])
assert filter_from_config(annotation_pass)
annotation_fail = create_annotation(["tag1"])
assert not filter_from_config(annotation_fail)
def test_load_filter_config(tmp_path: Path):
test_config_path = tmp_path / "filtering.yaml"
test_config_path.write_text(
"""
rules:
- match_type: any
tags:
- value: tag1
"""
)
config = load_filter_config(test_config_path)
assert isinstance(config, FilterConfig)
assert len(config.rules) == 1
rule = config.rules[0]
assert rule.match_type == "any"
assert len(rule.tags) == 1
assert rule.tags[0].value == "tag1"
def test_load_filter_from_config(tmp_path: Path, create_annotation):
test_config_path = tmp_path / "filtering.yaml"
test_config_path.write_text(
"""
rules:
- match_type: any
tags:
- value: tag1
"""
)
filter_result = load_filter_from_config(test_config_path)
annotation = create_annotation(["tag1", "tag3"])
assert filter_result(annotation)
test_config_path = tmp_path / "filtering.yaml"
test_config_path.write_text(
"""
rules:
- match_type: any
tags:
- value: tag2
"""
)
filter_result = load_filter_from_config(test_config_path)
annotation = create_annotation(["tag1", "tag3"])
assert filter_result(annotation) is False
def test_default_filtering_over_example_dataset(
example_annotations: List[data.ClipAnnotation],
):
targets = build_targets()
clip1 = example_annotations[0]
clip2 = example_annotations[1]
clip3 = example_annotations[2]
assert (
sum(
[targets.filter(sound_event) for sound_event in clip1.sound_events]
)
== 9
)
assert (
sum(
[targets.filter(sound_event) for sound_event in clip2.sound_events]
)
== 15
)
assert (
sum(
[targets.filter(sound_event) for sound_event in clip3.sound_events]
)
== 20
)

View File

@ -8,6 +8,11 @@ from batdetect2.preprocess import (
build_preprocessor,
)
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.preprocess.spectrogram import (
ScaleAmplitudeConfig,
SpectralMeanSubstractionConfig,
SpectrogramConfig,
)
from batdetect2.targets.rois import (
DEFAULT_ANCHOR,
DEFAULT_FREQUENCY_SCALE,
@ -548,14 +553,7 @@ def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle):
# Instantiate the mapper.
preprocessor = build_preprocessor(
PreprocessingConfig.model_validate(
{
"spectrogram": {
"pcen": None,
"spectral_mean_substraction": False,
}
}
)
PreprocessingConfig(spectrogram=SpectrogramConfig(transforms=[]))
)
audio_loader = build_audio_loader()
mapper = PeakEnergyBBoxMapper(
@ -597,14 +595,13 @@ def test_build_roi_mapper_for_anchor_bbox():
def test_build_roi_mapper_for_peak_energy_bbox():
# Given
preproc_config = PreprocessingConfig.model_validate(
{
"spectrogram": {
"pcen": None,
"spectral_mean_substraction": True,
"scale": "dB",
}
}
preproc_config = PreprocessingConfig(
spectrogram=SpectrogramConfig(
transforms=[
ScaleAmplitudeConfig(scale="db"),
SpectralMeanSubstractionConfig(),
]
),
)
config = PeakEnergyBBoxMapperConfig(
loading_buffer=0.99,

View File

@ -1,27 +1,33 @@
from collections.abc import Callable
from pathlib import Path
from soundevent import data
from soundevent import data, terms
from batdetect2.targets import build_targets, load_target_config
from batdetect2.targets.terms import get_term_from_key
def test_can_override_default_roi_mapper_per_class(
create_temp_yaml: Callable[..., Path],
recording: data.Recording,
sample_term_registry,
):
yaml_content = """
roi:
name: anchor_bbox
anchor: bottom-left
classes:
classes:
detection_target:
name: bat
match_if:
name: has_tag
tag:
key: order
value: Chiroptera
assign_tags:
- key: order
value: Chiroptera
classification_targets:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
- name: myomyo
tags:
- key: species
@ -29,18 +35,21 @@ def test_can_override_default_roi_mapper_per_class(
roi:
name: anchor_bbox
anchor: top-left
generic_class:
- key: order
value: Chiroptera
roi:
name: anchor_bbox
anchor: bottom-left
"""
config_path = create_temp_yaml(yaml_content)
config = load_target_config(config_path)
targets = build_targets(config, term_registry=sample_term_registry)
targets = build_targets(config)
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
species = get_term_from_key("species", term_registry=sample_term_registry)
species = terms.get_term("species")
assert species is not None
se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
@ -62,19 +71,26 @@ def test_can_override_default_roi_mapper_per_class(
# TODO: rename this test function
def test_roi_is_recovered_roundtrip_even_with_overriders(
create_temp_yaml,
sample_term_registry,
recording,
):
yaml_content = """
roi:
name: anchor_bbox
anchor: bottom-left
classes:
classes:
detection_target:
name: bat
match_if:
name: has_tag
tag:
key: order
value: Chiroptera
assign_tags:
- key: order
value: Chiroptera
classification_targets:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
- name: myomyo
tags:
- key: species
@ -82,18 +98,20 @@ def test_roi_is_recovered_roundtrip_even_with_overriders(
roi:
name: anchor_bbox
anchor: top-left
generic_class:
- key: order
value: Chiroptera
roi:
name: anchor_bbox
anchor: bottom-left
"""
config_path = create_temp_yaml(yaml_content)
config = load_target_config(config_path)
targets = build_targets(config, term_registry=sample_term_registry)
targets = build_targets(config)
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
species = get_term_from_key("species", term_registry=sample_term_registry)
species = terms.get_term("species")
assert species is not None
se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],

View File

@ -1,179 +0,0 @@
import pytest
import yaml
from soundevent import data
from batdetect2.targets import terms
from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
load_terms_from_config,
)
def test_term_registry_initialization():
registry = TermRegistry()
assert registry._terms == {}
initial_terms = {
"test_term": data.Term(name="test", label="Test", definition="test")
}
registry = TermRegistry(terms=initial_terms)
assert registry._terms == initial_terms
def test_term_registry_add_term():
registry = TermRegistry()
term = data.Term(name="test", label="Test", definition="test")
registry.add_term("test_key", term)
assert registry._terms["test_key"] == term
def test_term_registry_get_term():
registry = TermRegistry()
term = data.Term(name="test", label="Test", definition="test")
registry.add_term("test_key", term)
retrieved_term = registry.get_term("test_key")
assert retrieved_term == term
def test_term_registry_add_custom_term():
registry = TermRegistry()
term = registry.add_custom_term(
"custom_key", name="custom", label="Custom", definition="A custom term"
)
assert registry._terms["custom_key"] == term
assert term.name == "custom"
assert term.label == "Custom"
assert term.definition == "A custom term"
def test_term_registry_add_duplicate_term():
registry = TermRegistry()
term = data.Term(name="test", label="Test", definition="test")
registry.add_term("test_key", term)
with pytest.raises(KeyError):
registry.add_term("test_key", term)
def test_term_registry_get_term_not_found():
registry = TermRegistry()
with pytest.raises(KeyError):
registry.get_term("non_existent_key")
def test_term_registry_get_keys():
registry = TermRegistry()
term1 = data.Term(name="test1", label="Test1", definition="test")
term2 = data.Term(name="test2", label="Test2", definition="test")
registry.add_term("key1", term1)
registry.add_term("key2", term2)
keys = registry.get_keys()
assert set(keys) == {"key1", "key2"}
def test_get_term_from_key():
term = terms.get_term_from_key("event")
assert term == terms.call_type
custom_registry = TermRegistry()
custom_term = data.Term(name="custom", label="Custom", definition="test")
custom_registry.add_term("custom_key", custom_term)
term = terms.get_term_from_key("custom_key", term_registry=custom_registry)
assert term == custom_term
def test_get_term_keys():
keys = terms.get_term_keys()
assert "event" in keys
assert "individual" in keys
assert terms.GENERIC_CLASS_KEY in keys
custom_registry = TermRegistry()
custom_term = data.Term(name="custom", label="Custom", definition="test")
custom_registry.add_term("custom_key", custom_term)
keys = terms.get_term_keys(term_registry=custom_registry)
assert "custom_key" in keys
def test_tag_info_and_get_tag_from_info():
tag_info = TagInfo(value="Myotis myotis", key="event")
tag = terms.get_tag_from_info(tag_info)
assert tag.value == "Myotis myotis"
assert tag.term == terms.call_type
def test_get_tag_from_info_key_not_found():
tag_info = TagInfo(value="test", key="non_existent_key")
with pytest.raises(KeyError):
terms.get_tag_from_info(tag_info)
def test_load_terms_from_config(tmp_path):
term_registry = TermRegistry()
config_data = {
"terms": [
{
"key": "species",
"name": "dwc:scientificName",
"label": "Scientific Name",
},
{
"key": "my_custom_term",
"name": "soundevent:custom_term",
"definition": "Describes a specific project attribute",
},
]
}
config_file = tmp_path / "config.yaml"
with open(config_file, "w") as f:
yaml.dump(config_data, f)
loaded_terms = load_terms_from_config(
config_file,
term_registry=term_registry,
)
assert "species" in loaded_terms
assert "my_custom_term" in loaded_terms
assert loaded_terms["species"].name == "dwc:scientificName"
assert loaded_terms["my_custom_term"].name == "soundevent:custom_term"
def test_load_terms_from_config_file_not_found():
with pytest.raises(FileNotFoundError):
load_terms_from_config("non_existent_file.yaml")
def test_load_terms_from_config_validation_error(tmp_path):
config_data = {
"terms": [
{
"key": "species",
"uri": "dwc:scientificName",
"label": 123,
}, # Invalid label type
]
}
config_file = tmp_path / "config.yaml"
with open(config_file, "w") as f:
yaml.dump(config_data, f)
with pytest.raises(ValueError):
load_terms_from_config(config_file)
def test_load_terms_from_config_key_already_exists(tmp_path):
config_data = {
"terms": [
{
"key": "event",
"uri": "dwc:scientificName",
"label": "Scientific Name",
}, # Duplicate key
]
}
config_file = tmp_path / "config.yaml"
with open(config_file, "w") as f:
yaml.dump(config_data, f)
with pytest.raises(KeyError):
load_terms_from_config(config_file)

View File

@ -1,363 +0,0 @@
from pathlib import Path
import pytest
from soundevent import data
from batdetect2.targets import (
DeriveTagRule,
MapValueRule,
ReplaceRule,
TagInfo,
TransformConfig,
build_transformation_from_config,
)
from batdetect2.targets.terms import TermRegistry
from batdetect2.targets.transform import (
DerivationRegistry,
build_transform_from_rule,
)
@pytest.fixture
def term_registry():
return TermRegistry()
@pytest.fixture
def derivation_registry():
return DerivationRegistry()
@pytest.fixture
def term1(term_registry: TermRegistry) -> data.Term:
return term_registry.add_custom_term(key="term1")
@pytest.fixture
def term2(term_registry: TermRegistry) -> data.Term:
return term_registry.add_custom_term(key="term2")
@pytest.fixture
def annotation(
sound_event: data.SoundEvent,
term1: data.Term,
) -> data.SoundEventAnnotation:
return data.SoundEventAnnotation(
sound_event=sound_event, tags=[data.Tag(term=term1, value="value1")]
)
def test_map_value_rule(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
):
rule = MapValueRule(
rule_type="map_value",
source_term_key="term1",
value_mapping={"value1": "value2"},
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].value == "value2"
def test_map_value_rule_no_match(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
):
rule = MapValueRule(
rule_type="map_value",
source_term_key="term1",
value_mapping={"other_value": "value2"},
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].value == "value1"
def test_replace_rule(
annotation: data.SoundEventAnnotation,
term2: data.Term,
term_registry: TermRegistry,
):
rule = ReplaceRule(
rule_type="replace",
original=TagInfo(key="term1", value="value1"),
replacement=TagInfo(key="term2", value="value2"),
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].term == term2
assert transformed_annotation.tags[0].value == "value2"
def test_replace_rule_no_match(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
term2: data.Term,
):
rule = ReplaceRule(
rule_type="replace",
original=TagInfo(key="term1", value="wrong_value"),
replacement=TagInfo(key="term2", value="value2"),
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].key == "term1"
assert transformed_annotation.tags[0].term != term2
assert transformed_annotation.tags[0].value == "value1"
def test_build_transformation_from_config(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
):
config = TransformConfig(
rules=[
MapValueRule(
rule_type="map_value",
source_term_key="term1",
value_mapping={"value1": "value2"},
),
ReplaceRule(
rule_type="replace",
original=TagInfo(key="term2", value="value2"),
replacement=TagInfo(key="term3", value="value3"),
),
]
)
term_registry.add_custom_term("term2")
term_registry.add_custom_term("term3")
transform = build_transformation_from_config(
config,
term_registry=term_registry,
)
transformed_annotation = transform(annotation)
assert transformed_annotation.tags[0].key == "term1"
assert transformed_annotation.tags[0].value == "value2"
def test_derive_tag_rule(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
derivation_registry: DerivationRegistry,
term1: data.Term,
):
def derivation_func(x: str) -> str:
return x + "_derived"
derivation_registry.register("my_derivation", derivation_func)
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="my_derivation",
)
transform_fn = build_transform_from_rule(
rule,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
transformed_annotation = transform_fn(annotation)
assert len(transformed_annotation.tags) == 2
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value1"
assert transformed_annotation.tags[1].term == term1
assert transformed_annotation.tags[1].value == "value1_derived"
def test_derive_tag_rule_keep_source_false(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
derivation_registry: DerivationRegistry,
term1: data.Term,
):
def derivation_func(x: str) -> str:
return x + "_derived"
derivation_registry.register("my_derivation", derivation_func)
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="my_derivation",
keep_source=False,
)
transform_fn = build_transform_from_rule(
rule,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
transformed_annotation = transform_fn(annotation)
assert len(transformed_annotation.tags) == 1
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value1_derived"
def test_derive_tag_rule_target_term(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
derivation_registry: DerivationRegistry,
term1: data.Term,
term2: data.Term,
):
def derivation_func(x: str) -> str:
return x + "_derived"
derivation_registry.register("my_derivation", derivation_func)
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="my_derivation",
target_term_key="term2",
)
transform_fn = build_transform_from_rule(
rule,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
transformed_annotation = transform_fn(annotation)
assert len(transformed_annotation.tags) == 2
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value1"
assert transformed_annotation.tags[1].term == term2
assert transformed_annotation.tags[1].value == "value1_derived"
def test_derive_tag_rule_import_derivation(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
term1: data.Term,
tmp_path: Path,
):
# Create a dummy derivation function in a temporary file
derivation_module_path = (
tmp_path / "temp_derivation.py"
) # Changed to /tmp since /home/santiago is not writable
derivation_module_path.write_text(
"""
def my_imported_derivation(x: str) -> str:
return x + "_imported"
"""
)
# Ensure the temporary file is importable by adding its directory to sys.path
import sys
sys.path.insert(0, str(tmp_path))
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="temp_derivation.my_imported_derivation",
import_derivation=True,
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transformed_annotation = transform_fn(annotation)
assert len(transformed_annotation.tags) == 2
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value1"
assert transformed_annotation.tags[1].term == term1
assert transformed_annotation.tags[1].value == "value1_imported"
# Clean up the temporary file and sys.path
sys.path.remove(str(tmp_path))
def test_derive_tag_rule_invalid_derivation(term_registry: TermRegistry):
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="nonexistent_derivation",
)
with pytest.raises(KeyError):
build_transform_from_rule(rule, term_registry=term_registry)
def test_build_transform_from_rule_invalid_rule_type():
class InvalidRule:
rule_type = "invalid"
rule = InvalidRule() # type: ignore
with pytest.raises(ValueError):
build_transform_from_rule(rule) # type: ignore
def test_map_value_rule_target_term(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
term2: data.Term,
):
rule = MapValueRule(
rule_type="map_value",
source_term_key="term1",
value_mapping={"value1": "value2"},
target_term_key="term2",
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].term == term2
assert transformed_annotation.tags[0].value == "value2"
def test_map_value_rule_target_term_none(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
term1: data.Term,
):
rule = MapValueRule(
rule_type="map_value",
source_term_key="term1",
value_mapping={"value1": "value2"},
target_term_key=None,
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value2"
def test_derive_tag_rule_target_term_none(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
derivation_registry: DerivationRegistry,
term1: data.Term,
):
def derivation_func(x: str) -> str:
return x + "_derived"
derivation_registry.register("my_derivation", derivation_func)
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="my_derivation",
target_term_key=None,
)
transform_fn = build_transform_from_rule(
rule,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
transformed_annotation = transform_fn(annotation)
assert len(transformed_annotation.tags) == 2
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value1"
assert transformed_annotation.tags[1].term == term1
assert transformed_annotation.tags[1].value == "value1_derived"
def test_build_transformation_from_config_empty(
annotation: data.SoundEventAnnotation,
):
config = TransformConfig(rules=[])
transform = build_transformation_from_config(config)
transformed_annotation = transform(annotation)
assert transformed_annotation == annotation

View File

@ -1,170 +0,0 @@
from collections.abc import Callable
import pytest
import torch
from soundevent import data
from batdetect2.train.augmentations import (
add_echo,
mix_audio,
)
from batdetect2.train.clips import select_subclip
from batdetect2.train.preprocess import generate_train_example
from batdetect2.typing import AudioLoader, ClipLabeller, PreprocessorProtocol
def test_mix_examples(
sample_preprocessor: PreprocessorProtocol,
sample_audio_loader: AudioLoader,
sample_labeller: ClipLabeller,
create_recording: Callable[..., data.Recording],
):
recording1 = create_recording()
recording2 = create_recording()
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
clip2 = data.Clip(recording=recording2, start_time=0.3, end_time=0.8)
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
clip_annotation_2 = data.ClipAnnotation(clip=clip2)
example1 = generate_train_example(
clip_annotation_1,
audio_loader=sample_audio_loader,
preprocessor=sample_preprocessor,
labeller=sample_labeller,
)
example2 = generate_train_example(
clip_annotation_2,
audio_loader=sample_audio_loader,
preprocessor=sample_preprocessor,
labeller=sample_labeller,
)
mixed = mix_audio(
example1,
example2,
weight=0.3,
preprocessor=sample_preprocessor,
)
assert mixed.spectrogram.shape == example1.spectrogram.shape
assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
assert mixed.size_heatmap.shape == example1.size_heatmap.shape
assert mixed.class_heatmap.shape == example1.class_heatmap.shape
@pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
@pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
def test_mix_examples_of_different_durations(
sample_preprocessor: PreprocessorProtocol,
sample_audio_loader: AudioLoader,
sample_labeller: ClipLabeller,
create_recording: Callable[..., data.Recording],
duration1: float,
duration2: float,
):
recording1 = create_recording()
recording2 = create_recording()
clip1 = data.Clip(recording=recording1, start_time=0, end_time=duration1)
clip2 = data.Clip(recording=recording2, start_time=0, end_time=duration2)
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
clip_annotation_2 = data.ClipAnnotation(clip=clip2)
example1 = generate_train_example(
clip_annotation_1,
audio_loader=sample_audio_loader,
preprocessor=sample_preprocessor,
labeller=sample_labeller,
)
example2 = generate_train_example(
clip_annotation_2,
audio_loader=sample_audio_loader,
preprocessor=sample_preprocessor,
labeller=sample_labeller,
)
mixed = mix_audio(
example1,
example2,
weight=0.3,
preprocessor=sample_preprocessor,
)
assert mixed.spectrogram.shape == example1.spectrogram.shape
assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
assert mixed.size_heatmap.shape == example1.size_heatmap.shape
assert mixed.class_heatmap.shape == example1.class_heatmap.shape
def test_add_echo(
sample_preprocessor: PreprocessorProtocol,
sample_audio_loader: AudioLoader,
sample_labeller: ClipLabeller,
create_recording: Callable[..., data.Recording],
):
recording1 = create_recording()
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
original = generate_train_example(
clip_annotation_1,
audio_loader=sample_audio_loader,
preprocessor=sample_preprocessor,
labeller=sample_labeller,
)
with_echo = add_echo(
original,
preprocessor=sample_preprocessor,
delay=0.1,
weight=0.3,
)
assert with_echo.spectrogram.shape == original.spectrogram.shape
torch.testing.assert_close(
with_echo.size_heatmap,
original.size_heatmap,
atol=0,
rtol=0,
)
torch.testing.assert_close(
with_echo.class_heatmap,
original.class_heatmap,
atol=0,
rtol=0,
)
torch.testing.assert_close(
with_echo.detection_heatmap,
original.detection_heatmap,
atol=0,
rtol=0,
)
def test_selected_random_subclip_has_the_correct_width(
sample_preprocessor: PreprocessorProtocol,
sample_audio_loader: AudioLoader,
sample_labeller: ClipLabeller,
create_recording: Callable[..., data.Recording],
):
recording1 = create_recording()
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
original = generate_train_example(
clip_annotation_1,
audio_loader=sample_audio_loader,
preprocessor=sample_preprocessor,
labeller=sample_labeller,
)
subclip = select_subclip(
original,
input_samplerate=256_000,
output_samplerate=1000,
start=0,
duration=0.512,
)
assert subclip.spectrogram.shape[1] == 512

View File

@ -1,27 +0,0 @@
from soundevent import data
from batdetect2.train import generate_train_example
from batdetect2.typing import (
AudioLoader,
ClipLabeller,
ClipperProtocol,
PreprocessorProtocol,
)
def test_default_clip_size_is_correct(
sample_clipper: ClipperProtocol,
sample_labeller: ClipLabeller,
sample_audio_loader: AudioLoader,
clip_annotation: data.ClipAnnotation,
sample_preprocessor: PreprocessorProtocol,
):
example = generate_train_example(
clip_annotation=clip_annotation,
audio_loader=sample_audio_loader,
preprocessor=sample_preprocessor,
labeller=sample_labeller,
)
clip, _, _ = sample_clipper(example)
assert clip.spectrogram.shape == (1, 128, 256)

View File

@ -5,7 +5,6 @@ from soundevent import data
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.targets.rois import AnchorBBoxMapperConfig
from batdetect2.targets.terms import TagInfo
from batdetect2.train.labels import generate_heatmaps
recording = data.Recording(
@ -26,7 +25,8 @@ clip = data.Clip(
def test_generated_heatmap_are_non_zero_at_correct_positions(
sample_target_config: TargetConfig,
pippip_tag: TagInfo,
pippip_tag: data.Tag,
bat_tag: data.Tag,
):
config = sample_target_config.model_copy(
update=dict(
@ -49,14 +49,14 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
coordinates=[10, 10, 20, 30],
),
),
tags=[data.Tag(key=pippip_tag.key, value=pippip_tag.value)], # type: ignore
tags=[pippip_tag, bat_tag],
)
],
)
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
clip_annotation,
torch.rand([100, 100]),
torch.rand([1, 100, 100]),
min_freq=0,
max_freq=100,
targets=targets,
@ -67,4 +67,4 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
assert size_heatmap[1, 10, 10] == 20
assert class_heatmap[pippip_index, 10, 10] == 1.0
assert class_heatmap[myomyo_index, 10, 10] == 0.0
assert detection_heatmap[10, 10] == 1.0
assert detection_heatmap[0, 10, 10] == 1.0

View File

@ -4,14 +4,16 @@ import lightning as L
import torch
from soundevent import data
from batdetect2.models import build_model
from batdetect2.train import FullTrainingConfig, TrainingModule
from batdetect2.train.train import build_training_module
from batdetect2.typing.preprocess import AudioLoader
def build_default_module():
model = build_model()
config = FullTrainingConfig()
return build_training_module(config)
return build_training_module(model, config=config)
def test_can_initialize_default_module():
@ -32,14 +34,14 @@ def test_can_save_checkpoint(
recovered = TrainingModule.load_from_checkpoint(path)
wav = torch.tensor(sample_audio_loader.load_clip(clip))
wav = torch.tensor(sample_audio_loader.load_clip(clip)).unsqueeze(0)
spec1 = module.model.preprocessor(wav)
spec2 = recovered.model.preprocessor(wav)
torch.testing.assert_close(spec1, spec2, rtol=0, atol=0)
output1 = module(spec1.unsqueeze(0).unsqueeze(0))
output2 = recovered(spec2.unsqueeze(0).unsqueeze(0))
output1 = module(spec1.unsqueeze(0))
output2 = recovered(spec2.unsqueeze(0))
torch.testing.assert_close(output1, output2, rtol=0, atol=0)

View File

@ -1,230 +0,0 @@
import pytest
from soundevent import data
from soundevent.terms import get_term
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config
from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.preprocess import generate_train_example
from batdetect2.typing import ModelOutput
from batdetect2.typing.preprocess import AudioLoader
@pytest.fixture
def build_from_config(
create_temp_yaml,
):
def build(yaml_content):
config_path = create_temp_yaml(yaml_content)
targets_config = load_target_config(config_path, field="targets")
preprocessing_config = load_preprocessing_config(
config_path,
field="preprocessing",
)
labels_config = load_label_config(config_path, field="labels")
postprocessing_config = load_postprocess_config(
config_path,
field="postprocessing",
)
targets = build_targets(targets_config)
preprocessor = build_preprocessor(preprocessing_config)
labeller = build_clip_labeler(
targets=targets,
config=labels_config,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
)
postprocessor = build_postprocessor(
preprocessor=preprocessor,
config=postprocessing_config,
)
return targets, preprocessor, labeller, postprocessor
return build
def test_encoding_decoding_roundtrip_recovers_object(
sample_audio_loader: AudioLoader,
build_from_config,
recording,
):
yaml_content = """
labels:
targets:
roi:
name: anchor_bbox
anchor: bottom-left
classes:
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
generic_class:
- key: order
value: Chiroptera
preprocessing:
"""
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
],
)
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
encoded = generate_train_example(
clip_annotation,
sample_audio_loader,
preprocessor,
labeller,
)
predictions = postprocessor.get_predictions(
ModelOutput(
detection_probs=encoded.detection_heatmap.unsqueeze(0).unsqueeze(
0
),
size_preds=encoded.size_heatmap.unsqueeze(0),
class_probs=encoded.class_heatmap.unsqueeze(0),
features=encoded.spectrogram.unsqueeze(0).unsqueeze(0),
),
[clip],
)[0]
assert isinstance(predictions, data.ClipPrediction)
assert len(predictions.sound_events) == 1
recovered = predictions.sound_events[0]
assert recovered.sound_event.geometry is not None
assert isinstance(recovered.sound_event.geometry, data.BoundingBox)
start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
recovered.sound_event.geometry.coordinates
)
start_time_or, low_freq_or, end_time_or, high_freq_or = (
geometry.coordinates
)
assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)
assert len(recovered.tags) == 2
predicted_species_tag = next(
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
None,
)
assert predicted_species_tag is not None
assert predicted_species_tag.score == 1
assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"
predicted_order_tag = next(
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
None,
)
assert predicted_order_tag is not None
assert predicted_order_tag.score == 1
assert predicted_order_tag.tag.value == "Chiroptera"
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
sample_audio_loader: AudioLoader,
build_from_config,
recording,
):
yaml_content = """
labels:
targets:
roi:
name: anchor_bbox
anchor: bottom-left
classes:
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
- name: myomyo
tags:
- key: species
value: Myotis myotis
roi:
name: anchor_bbox
anchor: top-left
generic_class:
- key: order
value: Chiroptera
preprocessing:
"""
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
encoded = generate_train_example(
clip_annotation,
sample_audio_loader,
preprocessor,
labeller,
)
predictions = postprocessor.get_predictions(
ModelOutput(
detection_probs=encoded.detection_heatmap.unsqueeze(0).unsqueeze(
0
),
size_preds=encoded.size_heatmap.unsqueeze(0),
class_probs=encoded.class_heatmap.unsqueeze(0),
features=encoded.spectrogram.unsqueeze(0).unsqueeze(0),
),
[clip],
)[0]
assert isinstance(predictions, data.ClipPrediction)
assert len(predictions.sound_events) == 1
recovered = predictions.sound_events[0]
assert recovered.sound_event.geometry is not None
assert isinstance(recovered.sound_event.geometry, data.BoundingBox)
start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
recovered.sound_event.geometry.coordinates
)
start_time_or, low_freq_or, end_time_or, high_freq_or = (
geometry.coordinates
)
assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)
assert len(recovered.tags) == 2
predicted_species_tag = next(
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
None,
)
assert predicted_species_tag is not None
assert predicted_species_tag.score == 1
assert predicted_species_tag.tag.value == "Myotis myotis"
predicted_order_tag = next(
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
None,
)
assert predicted_order_tag is not None
assert predicted_order_tag.score == 1
assert predicted_order_tag.tag.value == "Chiroptera"