From b7ae526071984eabc0f1a5ff57e8b78e4153fa89 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 8 Sep 2025 17:50:25 +0100 Subject: [PATCH] Big changes in data module --- docs/source/targets/tags_and_terms.md | 2 +- example_data/config.yaml | 134 ++-- example_data/dataset.yaml | 30 + pyproject.toml | 2 +- src/batdetect2/configs.py | 2 +- src/batdetect2/data/_core.py | 61 ++ src/batdetect2/data/annotations/__init__.py | 14 +- src/batdetect2/data/conditions.py | 287 +++++++ src/batdetect2/data/datasets.py | 119 ++- src/batdetect2/data/transforms.py | 250 +++++++ src/batdetect2/models/blocks.py | 38 +- src/batdetect2/models/bottleneck.py | 2 +- src/batdetect2/models/decoder.py | 8 +- src/batdetect2/models/encoder.py | 8 +- src/batdetect2/preprocess/spectrogram.py | 12 +- src/batdetect2/targets/__init__.py | 313 ++------ src/batdetect2/targets/classes.py | 706 +++++------------- src/batdetect2/targets/filtering.py | 293 -------- src/batdetect2/targets/terms.py | 65 +- src/batdetect2/targets/transform.py | 689 ----------------- src/batdetect2/train/augmentations.py | 30 +- src/batdetect2/typing/targets.py | 23 +- tests/conftest.py | 42 +- tests/test_data/test_transforms/__init__.py | 0 .../test_transforms/test_conditions.py | 516 +++++++++++++ tests/test_targets/test_classes.py | 299 +------- tests/test_targets/test_filtering.py | 210 ------ tests/test_targets/test_rois.py | 29 +- tests/test_targets/test_targets.py | 82 +- tests/test_targets/test_terms.py | 17 - tests/test_targets/test_transform.py | 360 --------- tests/test_train/test_labels.py | 3 +- 32 files changed, 1678 insertions(+), 2968 deletions(-) create mode 100644 src/batdetect2/data/_core.py create mode 100644 src/batdetect2/data/conditions.py create mode 100644 src/batdetect2/data/transforms.py delete mode 100644 src/batdetect2/targets/filtering.py delete mode 100644 src/batdetect2/targets/transform.py create mode 100644 tests/test_data/test_transforms/__init__.py create mode 100644 tests/test_data/test_transforms/test_conditions.py delete mode 100644 tests/test_targets/test_filtering.py delete mode 100644 tests/test_targets/test_terms.py delete mode 100644 tests/test_targets/test_transform.py diff --git a/docs/source/targets/tags_and_terms.md b/docs/source/targets/tags_and_terms.md index bd2f994..38201ed 100644 --- a/docs/source/targets/tags_and_terms.md +++ b/docs/source/targets/tags_and_terms.md @@ -133,7 +133,7 @@ When you need to specify a tag, you typically use a structure with two fields: **It defaults to `class`** if you omit it, which is common when defining the main target classes. - `value`: The specific _value_ of the tag (e.g., `Myotis daubentonii`, `Good`, `Rain`). -**Example YAML Configuration using TagInfo (e.g., inside a filter rule):** +**Example YAML Configuration (e.g., inside a filter rule):** ```yaml # ... inside a filtering configuration section ... diff --git a/example_data/config.yaml b/example_data/config.yaml index 42c3cea..083f873 100644 --- a/example_data/config.yaml +++ b/example_data/config.yaml @@ -1,44 +1,47 @@ targets: - classes: - classes: - - name: myomys - tags: - - value: Myotis mystacinus - - name: pippip - tags: - - value: Pipistrellus pipistrellus - - name: eptser - tags: - - value: Eptesicus serotinus - - name: rhifer - tags: - - value: Rhinolophus ferrumequinum - roi: - name: anchor_bbox - anchor: top-left - generic_class: + detection_target: + name: bat + match_if: + name: all_of + conditions: + - name: has_tag + tag: { key: event, value: Echolocation } + - name: not + condition: + name: has_tag + tag: { key: class, value: Unknown } + assign_tags: - key: class value: Bat - filtering: - rules: - - match_type: all - tags: - - key: event - value: Echolocation - - match_type: exclude - tags: - - key: class - value: Unknown + classification_targets: + - name: myomys + tags: + - key: class + value: Myotis mystacinus + - name: pippip + tags: + - key: class + value: Pipistrellus pipistrellus + - name: eptser + tags: + - key: class + value: Eptesicus serotinus + - name: rhifer + tags: + - key: class + value: Rhinolophus ferrumequinum + + roi: + name: anchor_bbox + anchor: top-left preprocess: audio: + samplerate: 256000 resample: - samplerate: 256000 + enabled: True method: "poly" - scale: false - center: true - duration: null spectrogram: stft: @@ -48,66 +51,63 @@ preprocess: frequencies: max_freq: 120000 min_freq: 10000 - pcen: - time_constant: 0.1 - gain: 0.98 - bias: 2 - power: 0.5 - scale: "amplitude" size: height: 128 resize_factor: 0.5 - spectral_mean_substraction: true - peak_normalize: false + transforms: + - name: pcen + time_constant: 0.1 + gain: 0.98 + bias: 2 + power: 0.5 + - name: spectral_mean_substraction postprocess: nms_kernel_size: 9 detection_threshold: 0.01 - min_freq: 10000 - max_freq: 120000 top_k_per_sec: 200 -labels: - sigma: 3 - model: input_height: 128 in_channels: 1 out_channels: 32 encoder: layers: - - block_type: FreqCoordConvDown + - name: FreqCoordConvDown out_channels: 32 - - block_type: FreqCoordConvDown + - name: FreqCoordConvDown out_channels: 64 - - block_type: LayerGroup + - name: LayerGroup layers: - - block_type: FreqCoordConvDown + - name: FreqCoordConvDown out_channels: 128 - - block_type: ConvBlock + - name: ConvBlock out_channels: 256 bottleneck: channels: 256 layers: - - block_type: SelfAttention + - name: SelfAttention attention_channels: 256 decoder: layers: - - block_type: FreqCoordConvUp + - name: FreqCoordConvUp out_channels: 64 - - block_type: FreqCoordConvUp + - name: FreqCoordConvUp out_channels: 32 - - block_type: LayerGroup + - name: LayerGroup layers: - - block_type: FreqCoordConvUp + - name: FreqCoordConvUp out_channels: 32 - - block_type: ConvBlock + - name: ConvBlock out_channels: 32 train: learning_rate: 0.001 t_max: 100 + labels: + sigma: 3 + dataloaders: train: batch_size: 8 @@ -133,33 +133,35 @@ train: weight: 0.1 logger: - logger_type: csv - save_dir: outputs/log/ - name: logs + logger_type: tensorboard + # save_dir: outputs/log/ + # name: logs augmentations: - steps: - - augmentation_type: mix_audio + enabled: true + audio: + - name: mix_audio probability: 0.2 min_weight: 0.3 max_weight: 0.7 - - augmentation_type: add_echo + - name: add_echo probability: 0.2 max_delay: 0.005 min_weight: 0.0 max_weight: 1.0 - - augmentation_type: scale_volume + spectrogram: + - name: scale_volume probability: 0.2 min_scaling: 0.0 max_scaling: 2.0 - - augmentation_type: warp + - name: warp probability: 0.2 delta: 0.04 - - augmentation_type: mask_time + - name: mask_time probability: 0.2 max_perc: 0.05 max_masks: 3 - - augmentation_type: mask_freq + - name: mask_freq probability: 0.2 max_perc: 0.10 max_masks: 3 diff --git a/example_data/dataset.yaml b/example_data/dataset.yaml index 790da8c..d96e0e5 100644 --- a/example_data/dataset.yaml +++ b/example_data/dataset.yaml @@ -6,3 +6,33 @@ sources: description: Examples included for testing batdetect2 annotations_dir: example_data/anns audio_dir: example_data/audio + +classes: + # Each class has a name + - name: rhihip + + # Can be specified by some tags + tags: + - key: species + value: Rhinolophus hipposideros + + - name: myotis + + # Or if needed by more complex conditions + condition: + name: has_any_tag + tag: + - key: species + value: Myotis myotis + - key: species + value: Myotis nattereri + - key: species + value: Myotis daubentonii + + # Specifies how to translate back the "class" into + # the tag system + # If tags are provided and output_tags are not then + # just use the provided tags + output_tags: + - key: genus + value: Myotis diff --git a/pyproject.toml b/pyproject.toml index 20a087b..771930c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ "torch>=1.13.1,<2.5.0", "torchaudio>=1.13.1,<2.5.0", "torchvision>=0.14.0", - "soundevent[audio,geometry,plot]>=2.8.1", + "soundevent[audio,geometry,plot]>=2.9.1", "click>=8.1.7", "netcdf4>=1.6.5", "tqdm>=4.66.2", diff --git a/src/batdetect2/configs.py b/src/batdetect2/configs.py index 7399d6e..c7ffcd3 100644 --- a/src/batdetect2/configs.py +++ b/src/batdetect2/configs.py @@ -27,7 +27,7 @@ class BaseConfig(BaseModel): and serialization capabilities. """ - model_config = ConfigDict(extra="ignore") + model_config = ConfigDict(extra="forbid") def to_yaml_string( self, diff --git a/src/batdetect2/data/_core.py b/src/batdetect2/data/_core.py new file mode 100644 index 0000000..4ff9f7d --- /dev/null +++ b/src/batdetect2/data/_core.py @@ -0,0 +1,61 @@ +from typing import Generic, Protocol, Type, TypeVar + +from pydantic import BaseModel + +__all__ = [ + "Registry", +] + +T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True) +T_Type = TypeVar("T_Type", covariant=True) + + +class LogicProtocol(Generic[T_Config, T_Type], Protocol): + """A generic protocol for the logic classes (conditions or transforms).""" + + @classmethod + def from_config(cls, config: T_Config) -> T_Type: ... + + +T_Proto = TypeVar("T_Proto", bound=LogicProtocol) + + +class Registry(Generic[T_Type]): + """A generic class to create and manage a registry of items.""" + + def __init__(self, name: str): + self._name = name + self._registry = {} + + def register(self, config_cls: Type[T_Config]): + """A decorator factory to register a new item.""" + fields = config_cls.model_fields + + if "name" not in fields: + raise ValueError("Configuration object must have a 'name' field.") + + name = fields["name"].default + + if not isinstance(name, str): + raise ValueError("'name' field must be a string literal.") + + def decorator(logic_cls: Type[T_Proto]) -> Type[T_Proto]: + self._registry[name] = logic_cls + return logic_cls + + return decorator + + def build(self, config: BaseModel) -> T_Type: + """Builds a logic instance from a config object.""" + + name = getattr(config, "name") # noqa: B009 + + if name is None: + raise ValueError("Config does not have a name field") + + if name not in self._registry: + raise NotImplementedError( + f"No {self._name} with name '{name}' is registered." + ) + + return self._registry[name].from_config(config) diff --git a/src/batdetect2/data/annotations/__init__.py b/src/batdetect2/data/annotations/__init__.py index 4a3a94a..e4dc183 100644 --- a/src/batdetect2/data/annotations/__init__.py +++ b/src/batdetect2/data/annotations/__init__.py @@ -14,8 +14,9 @@ format-specific loading function to retrieve the annotations as a standard """ from pathlib import Path -from typing import Optional, Union +from typing import Annotated, Optional, Union +from pydantic import Field from soundevent import data from batdetect2.data.annotations.aoef import ( @@ -42,10 +43,13 @@ __all__ = [ ] -AnnotationFormats = Union[ - BatDetect2MergedAnnotations, - BatDetect2FilesAnnotations, - AOEFAnnotations, +AnnotationFormats = Annotated[ + Union[ + BatDetect2MergedAnnotations, + BatDetect2FilesAnnotations, + AOEFAnnotations, + ], + Field(discriminator="format"), ] """Type Alias representing all supported data source configurations. diff --git a/src/batdetect2/data/conditions.py b/src/batdetect2/data/conditions.py new file mode 100644 index 0000000..42a59e9 --- /dev/null +++ b/src/batdetect2/data/conditions.py @@ -0,0 +1,287 @@ +from collections.abc import Callable +from typing import Annotated, List, Literal, Sequence, Union + +from pydantic import Field +from soundevent import data +from soundevent.geometry import compute_bounds + +from batdetect2.configs import BaseConfig +from batdetect2.data._core import Registry + +SoundEventCondition = Callable[[data.SoundEventAnnotation], bool] + +_conditions: Registry[SoundEventCondition] = Registry("condition") + + +class HasTagConfig(BaseConfig): + name: Literal["has_tag"] = "has_tag" + tag: data.Tag + + +@_conditions.register(HasTagConfig) +class HasTag: + def __init__(self, tag: data.Tag): + self.tag = tag + + def __call__( + self, sound_event_annotation: data.SoundEventAnnotation + ) -> bool: + return self.tag in sound_event_annotation.tags + + @classmethod + def from_config(cls, config: HasTagConfig): + return cls(tag=config.tag) + + +class HasAllTagsConfig(BaseConfig): + name: Literal["has_all_tags"] = "has_all_tags" + tags: List[data.Tag] + + +@_conditions.register(HasAllTagsConfig) +class HasAllTags: + def __init__(self, tags: List[data.Tag]): + if not tags: + raise ValueError("Need to specify at least one tag") + + self.tags = set(tags) + + def __call__( + self, sound_event_annotation: data.SoundEventAnnotation + ) -> bool: + return self.tags.issubset(sound_event_annotation.tags) + + @classmethod + def from_config(cls, config: HasAllTagsConfig): + return cls(tags=config.tags) + + +class HasAnyTagConfig(BaseConfig): + name: Literal["has_any_tag"] = "has_any_tag" + tags: List[data.Tag] + + +@_conditions.register(HasAnyTagConfig) +class HasAnyTag: + def __init__(self, tags: List[data.Tag]): + if not tags: + raise ValueError("Need to specify at least one tag") + + self.tags = set(tags) + + def __call__( + self, sound_event_annotation: data.SoundEventAnnotation + ) -> bool: + return bool(self.tags.intersection(sound_event_annotation.tags)) + + @classmethod + def from_config(cls, config: HasAnyTagConfig): + return cls(tags=config.tags) + + +Operator = Literal["gt", "gte", "lt", "lte", "eq"] + + +class DurationConfig(BaseConfig): + name: Literal["duration"] = "duration" + operator: Operator + seconds: float + + +def _build_comparator( + operator: Operator, value: float +) -> Callable[[float], bool]: + if operator == "gt": + return lambda x: x > value + + if operator == "gte": + return lambda x: x >= value + + if operator == "lt": + return lambda x: x < value + + if operator == "lte": + return lambda x: x <= value + + if operator == "eq": + return lambda x: x == value + + raise ValueError(f"Invalid operator {operator}") + + +@_conditions.register(DurationConfig) +class Duration: + def __init__(self, operator: Operator, seconds: float): + self.operator = operator + self.seconds = seconds + self._comparator = _build_comparator(self.operator, self.seconds) + + def __call__( + self, + sound_event_annotation: data.SoundEventAnnotation, + ) -> bool: + geometry = sound_event_annotation.sound_event.geometry + + if geometry is None: + return False + + start_time, _, end_time, _ = compute_bounds(geometry) + duration = end_time - start_time + + return self._comparator(duration) + + @classmethod + def from_config(cls, config: DurationConfig): + return cls(operator=config.operator, seconds=config.seconds) + + +class FrequencyConfig(BaseConfig): + name: Literal["frequency"] = "frequency" + boundary: Literal["low", "high"] + operator: Operator + hertz: float + + +@_conditions.register(FrequencyConfig) +class Frequency: + def __init__( + self, + operator: Operator, + boundary: Literal["low", "high"], + hertz: float, + ): + self.operator = operator + self.hertz = hertz + self.boundary = boundary + self._comparator = _build_comparator(self.operator, self.hertz) + + def __call__( + self, + sound_event_annotation: data.SoundEventAnnotation, + ) -> bool: + geometry = sound_event_annotation.sound_event.geometry + + if geometry is None: + return False + + # Automatically false if geometry does not have a frequency range + if isinstance(geometry, (data.TimeInterval, data.TimeStamp)): + return False + + _, low_freq, _, high_freq = compute_bounds(geometry) + + if self.boundary == "low": + return self._comparator(low_freq) + + return self._comparator(high_freq) + + @classmethod + def from_config(cls, config: FrequencyConfig): + return cls( + operator=config.operator, + boundary=config.boundary, + hertz=config.hertz, + ) + + +class AllOfConfig(BaseConfig): + name: Literal["all_of"] = "all_of" + conditions: Sequence["SoundEventConditionConfig"] + + +@_conditions.register(AllOfConfig) +class AllOf: + def __init__(self, conditions: List[SoundEventCondition]): + self.conditions = conditions + + def __call__( + self, sound_event_annotation: data.SoundEventAnnotation + ) -> bool: + return all(c(sound_event_annotation) for c in self.conditions) + + @classmethod + def from_config(cls, config: AllOfConfig): + conditions = [ + build_sound_event_condition(cond) for cond in config.conditions + ] + return cls(conditions) + + +class AnyOfConfig(BaseConfig): + name: Literal["any_of"] = "any_of" + conditions: List["SoundEventConditionConfig"] + + +@_conditions.register(AnyOfConfig) +class AnyOf: + def __init__(self, conditions: List[SoundEventCondition]): + self.conditions = conditions + + def __call__( + self, sound_event_annotation: data.SoundEventAnnotation + ) -> bool: + return any(c(sound_event_annotation) for c in self.conditions) + + @classmethod + def from_config(cls, config: AnyOfConfig): + conditions = [ + build_sound_event_condition(cond) for cond in config.conditions + ] + return cls(conditions) + + +class NotConfig(BaseConfig): + name: Literal["not"] = "not" + condition: "SoundEventConditionConfig" + + +@_conditions.register(NotConfig) +class Not: + def __init__(self, condition: SoundEventCondition): + self.condition = condition + + def __call__( + self, sound_event_annotation: data.SoundEventAnnotation + ) -> bool: + return not self.condition(sound_event_annotation) + + @classmethod + def from_config(cls, config: NotConfig): + condition = build_sound_event_condition(config.condition) + return cls(condition) + + +SoundEventConditionConfig = Annotated[ + Union[ + HasTagConfig, + HasAllTagsConfig, + HasAnyTagConfig, + DurationConfig, + FrequencyConfig, + AllOfConfig, + AnyOfConfig, + NotConfig, + ], + Field(discriminator="name"), +] + + +def build_sound_event_condition( + config: SoundEventConditionConfig, +) -> SoundEventCondition: + return _conditions.build(config) + + +def filter_clip_annotation( + clip_annotation: data.ClipAnnotation, + condition: SoundEventCondition, +) -> data.ClipAnnotation: + return clip_annotation.model_copy( + update=dict( + sound_events=[ + sound_event + for sound_event in clip_annotation.sound_events + if condition(sound_event) + ] + ) + ) diff --git a/src/batdetect2/data/datasets.py b/src/batdetect2/data/datasets.py index ab43b17..f1b5117 100644 --- a/src/batdetect2/data/datasets.py +++ b/src/batdetect2/data/datasets.py @@ -19,7 +19,7 @@ The core components are: """ from pathlib import Path -from typing import Annotated, List, Optional +from typing import List, Optional from loguru import logger from pydantic import Field @@ -31,6 +31,17 @@ from batdetect2.data.annotations import ( AnnotationFormats, load_annotated_dataset, ) +from batdetect2.data.conditions import ( + SoundEventConditionConfig, + build_sound_event_condition, + filter_clip_annotation, +) +from batdetect2.data.transforms import ( + ApplyAll, + SoundEventTransformConfig, + build_sound_event_transform, + transform_clip_annotation, +) from batdetect2.targets.terms import data_source __all__ = [ @@ -52,79 +63,68 @@ sources. class DatasetConfig(BaseConfig): - """Configuration model defining the structure of a BatDetect2 dataset. - - This class is typically loaded from a YAML file and describes the components - of the dataset, including metadata and a list of data sources. - - Attributes - ---------- - name : str - A descriptive name for the dataset (e.g., "UK_Bats_Project_2024"). - description : str - A longer description of the dataset's contents, origin, purpose, etc. - sources : List[AnnotationFormats] - A list defining the different data sources contributing to this - dataset. Each item in the list must conform to one of the Pydantic - models defined in the `AnnotationFormats` type union. The specific - model used for each source is determined by the mandatory `format` - field within the source's configuration, allowing BatDetect2 to use the - correct parser for different annotation styles. - """ + """Configuration model defining the structure of a BatDetect2 dataset.""" name: str description: str - sources: List[ - Annotated[AnnotationFormats, Field(..., discriminator="format")] - ] + sources: List[AnnotationFormats] + + sound_event_filter: Optional[SoundEventConditionConfig] = None + sound_event_transforms: List[SoundEventTransformConfig] = Field( + default_factory=list + ) def load_dataset( - dataset: DatasetConfig, + config: DatasetConfig, base_dir: Optional[Path] = None, ) -> Dataset: - """Load all clip annotations from the sources defined in a DatasetConfig. - - Iterates through each data source specified in the `dataset_config`, - delegates the loading and parsing of that source's annotations to - `batdetect2.data.annotations.load_annotated_dataset` (which handles - different data formats), and aggregates all resulting `ClipAnnotation` - objects into a single flat list. - - Parameters - ---------- - dataset_config : DatasetConfig - The configuration object describing the dataset and its sources. - base_dir : Path, optional - An optional base directory path. If provided, relative paths for - metadata files or data directories within the `dataset_config`'s - sources might be resolved relative to this directory. Defaults to None. - - Returns - ------- - Dataset (List[data.ClipAnnotation]) - A flat list containing all loaded `ClipAnnotation` metadata objects - from all specified sources. - - Raises - ------ - Exception - Can raise various exceptions during the delegated loading process - (`load_annotated_dataset`) if files are not found, cannot be parsed - according to the specified format, or other I/O errors occur. - """ + """Load all clip annotations from the sources defined in a DatasetConfig.""" clip_annotations = [] - for source in dataset.sources: + + condition = ( + build_sound_event_condition(config.sound_event_filter) + if config.sound_event_filter is not None + else None + ) + + transform = ( + ApplyAll( + [ + build_sound_event_transform(step) + for step in config.sound_event_transforms + ] + ) + if config.sound_event_transforms + else None + ) + + for source in config.sources: annotated_source = load_annotated_dataset(source, base_dir=base_dir) + logger.debug( "Loaded {num_examples} from dataset source '{source_name}'", num_examples=len(annotated_source.clip_annotations), source_name=source.name, ) - clip_annotations.extend( - insert_source_tag(clip_annotation, source) - for clip_annotation in annotated_source.clip_annotations - ) + + for clip_annotation in annotated_source.clip_annotations: + clip_annotation = insert_source_tag(clip_annotation, source) + + if condition is not None: + clip_annotation = filter_clip_annotation( + clip_annotation, + condition, + ) + + if transform is not None: + clip_annotation = transform_clip_annotation( + clip_annotation, + transform, + ) + + clip_annotations.append(clip_annotation) + return clip_annotations @@ -161,7 +161,6 @@ def insert_source_tag( ) -# TODO: add documentation def load_dataset_config(path: data.PathLike, field: Optional[str] = None): return load_config(path=path, schema=DatasetConfig, field=field) diff --git a/src/batdetect2/data/transforms.py b/src/batdetect2/data/transforms.py new file mode 100644 index 0000000..5dc7e84 --- /dev/null +++ b/src/batdetect2/data/transforms.py @@ -0,0 +1,250 @@ +from collections.abc import Callable +from typing import Annotated, Dict, List, Literal, Optional, Union + +from pydantic import Field +from soundevent import data + +from batdetect2.configs import BaseConfig +from batdetect2.data._core import Registry +from batdetect2.data.conditions import ( + SoundEventCondition, + SoundEventConditionConfig, + build_sound_event_condition, +) + +SoundEventTransform = Callable[ + [data.SoundEventAnnotation], + data.SoundEventAnnotation, +] + +_transforms: Registry[SoundEventTransform] = Registry("transform") + + +class SetFrequencyBoundConfig(BaseConfig): + name: Literal["set_frequency"] = "set_frequency" + boundary: Literal["low", "high"] = "low" + hertz: float + + +@_transforms.register(SetFrequencyBoundConfig) +class SetFrequencyBound: + def __init__(self, hertz: float, boundary: Literal["low", "high"] = "low"): + self.hertz = hertz + self.boundary = boundary + + def __call__( + self, + sound_event_annotation: data.SoundEventAnnotation, + ) -> data.SoundEventAnnotation: + sound_event = sound_event_annotation.sound_event + geometry = sound_event.geometry + + if geometry is None: + return sound_event_annotation + + if not isinstance(geometry, data.BoundingBox): + return sound_event_annotation + + start_time, low_freq, end_time, high_freq = geometry.coordinates + + if self.boundary == "low": + low_freq = self.hertz + high_freq = max(high_freq, low_freq) + + elif self.boundary == "high": + high_freq = self.hertz + low_freq = min(high_freq, low_freq) + + geometry = data.BoundingBox( + coordinates=[start_time, low_freq, end_time, high_freq], + ) + + sound_event = sound_event.model_copy(update=dict(geometry=geometry)) + return sound_event_annotation.model_copy( + update=dict(sound_event=sound_event) + ) + + @classmethod + def from_config(cls, config: SetFrequencyBoundConfig): + return cls(hertz=config.hertz, boundary=config.boundary) + + +class ApplyIfConfig(BaseConfig): + name: Literal["apply_if"] = "apply_if" + transform: "SoundEventTransformConfig" + condition: SoundEventConditionConfig + + +@_transforms.register(ApplyIfConfig) +class ApplyIf: + def __init__( + self, + condition: SoundEventCondition, + transform: SoundEventTransform, + ): + self.condition = condition + self.transform = transform + + def __call__( + self, + sound_event_annotation: data.SoundEventAnnotation, + ) -> data.SoundEventAnnotation: + if not self.condition(sound_event_annotation): + return sound_event_annotation + + return self.transform(sound_event_annotation) + + @classmethod + def from_config(cls, config: ApplyIfConfig): + transform = build_sound_event_transform(config.transform) + condition = build_sound_event_condition(config.condition) + return cls(condition=condition, transform=transform) + + +class ReplaceTagConfig(BaseConfig): + name: Literal["replace_tag"] = "replace_tag" + original: data.Tag + replacement: data.Tag + + +@_transforms.register(ReplaceTagConfig) +class ReplaceTag: + def __init__( + self, + original: data.Tag, + replacement: data.Tag, + ): + self.original = original + self.replacement = replacement + + def __call__( + self, + sound_event_annotation: data.SoundEventAnnotation, + ) -> data.SoundEventAnnotation: + tags = [] + + for tag in sound_event_annotation.tags: + if tag == self.original: + tags.append(self.replacement) + else: + tags.append(tag) + + return sound_event_annotation.model_copy(update=dict(tags=tags)) + + @classmethod + def from_config(cls, config: ReplaceTagConfig): + return cls(original=config.original, replacement=config.replacement) + + +class MapTagValueConfig(BaseConfig): + name: Literal["map_tag_value"] = "map_tag_value" + tag_key: str + value_mapping: Dict[str, str] + target_key: Optional[str] = None + + +@_transforms.register(MapTagValueConfig) +class MapTagValue: + def __init__( + self, + tag_key: str, + value_mapping: Dict[str, str], + target_key: Optional[str] = None, + ): + self.tag_key = tag_key + self.value_mapping = value_mapping + self.target_key = target_key + + def __call__( + self, + sound_event_annotation: data.SoundEventAnnotation, + ) -> data.SoundEventAnnotation: + tags = [] + + for tag in sound_event_annotation.tags: + if tag.key != self.tag_key: + tags.append(tag) + continue + + value = self.value_mapping.get(tag.value) + + if value is None: + tags.append(tag) + continue + + if self.target_key is None: + tags.append(tag.model_copy(update=dict(value=value))) + else: + tags.append( + data.Tag( + key=self.target_key, # type: ignore + value=value, + ) + ) + + return sound_event_annotation.model_copy(update=dict(tags=tags)) + + @classmethod + def from_config(cls, config: MapTagValueConfig): + return cls( + tag_key=config.tag_key, + value_mapping=config.value_mapping, + target_key=config.target_key, + ) + + +class ApplyAllConfig(BaseConfig): + name: Literal["apply_all"] = "apply_all" + steps: List["SoundEventTransformConfig"] = Field(default_factory=list) + + +@_transforms.register(ApplyAllConfig) +class ApplyAll: + def __init__(self, steps: List[SoundEventTransform]): + self.steps = steps + + def __call__( + self, + sound_event_annotation: data.SoundEventAnnotation, + ) -> data.SoundEventAnnotation: + for step in self.steps: + sound_event_annotation = step(sound_event_annotation) + + return sound_event_annotation + + @classmethod + def from_config(cls, config: ApplyAllConfig): + steps = [build_sound_event_transform(step) for step in config.steps] + return cls(steps) + + +SoundEventTransformConfig = Annotated[ + Union[ + SetFrequencyBoundConfig, + ReplaceTagConfig, + MapTagValueConfig, + ApplyIfConfig, + ApplyAllConfig, + ], + Field(discriminator="name"), +] + + +def build_sound_event_transform( + config: SoundEventTransformConfig, +) -> SoundEventTransform: + return _transforms.build(config) + + +def transform_clip_annotation( + clip_annotation: data.ClipAnnotation, + transform: SoundEventTransform, +) -> data.ClipAnnotation: + return clip_annotation.model_copy( + update=dict( + sound_events=[ + transform(sound_event) + for sound_event in clip_annotation.sound_events + ] + ) + ) diff --git a/src/batdetect2/models/blocks.py b/src/batdetect2/models/blocks.py index b1bbec1..6e64542 100644 --- a/src/batdetect2/models/blocks.py +++ b/src/batdetect2/models/blocks.py @@ -56,7 +56,7 @@ __all__ = [ class SelfAttentionConfig(BaseConfig): - block_type: Literal["SelfAttention"] = "SelfAttention" + name: Literal["SelfAttention"] = "SelfAttention" attention_channels: int temperature: float = 1 @@ -178,7 +178,7 @@ class SelfAttention(nn.Module): class ConvConfig(BaseConfig): """Configuration for a basic ConvBlock.""" - block_type: Literal["ConvBlock"] = "ConvBlock" + name: Literal["ConvBlock"] = "ConvBlock" """Discriminator field indicating the block type.""" out_channels: int @@ -300,7 +300,7 @@ class VerticalConv(nn.Module): class FreqCoordConvDownConfig(BaseConfig): """Configuration for a FreqCoordConvDownBlock.""" - block_type: Literal["FreqCoordConvDown"] = "FreqCoordConvDown" + name: Literal["FreqCoordConvDown"] = "FreqCoordConvDown" """Discriminator field indicating the block type.""" out_channels: int @@ -390,7 +390,7 @@ class FreqCoordConvDownBlock(nn.Module): class StandardConvDownConfig(BaseConfig): """Configuration for a StandardConvDownBlock.""" - block_type: Literal["StandardConvDown"] = "StandardConvDown" + name: Literal["StandardConvDown"] = "StandardConvDown" """Discriminator field indicating the block type.""" out_channels: int @@ -460,7 +460,7 @@ class StandardConvDownBlock(nn.Module): class FreqCoordConvUpConfig(BaseConfig): """Configuration for a FreqCoordConvUpBlock.""" - block_type: Literal["FreqCoordConvUp"] = "FreqCoordConvUp" + name: Literal["FreqCoordConvUp"] = "FreqCoordConvUp" """Discriminator field indicating the block type.""" out_channels: int @@ -569,7 +569,7 @@ class FreqCoordConvUpBlock(nn.Module): class StandardConvUpConfig(BaseConfig): """Configuration for a StandardConvUpBlock.""" - block_type: Literal["StandardConvUp"] = "StandardConvUp" + name: Literal["StandardConvUp"] = "StandardConvUp" """Discriminator field indicating the block type.""" out_channels: int @@ -664,13 +664,13 @@ LayerConfig = Annotated[ SelfAttentionConfig, "LayerGroupConfig", ], - Field(discriminator="block_type"), + Field(discriminator="name"), ] """Type alias for the discriminated union of block configuration models.""" class LayerGroupConfig(BaseConfig): - block_type: Literal["LayerGroup"] = "LayerGroup" + name: Literal["LayerGroup"] = "LayerGroup" layers: List[LayerConfig] @@ -686,7 +686,7 @@ def build_layer_from_config( parameters derived from the config and the current pipeline state (`input_height`, `in_channels`). - It uses the `block_type` field within the `config` object to determine + It uses the `name` field within the `config` object to determine which block class to instantiate. Parameters @@ -698,7 +698,7 @@ def build_layer_from_config( config : LayerConfig A Pydantic configuration object for the desired block (e.g., an instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified - by its `block_type` field. + by its `name` field. Returns ------- @@ -711,11 +711,11 @@ def build_layer_from_config( Raises ------ NotImplementedError - If the `config.block_type` does not correspond to a known block type. + If the `config.name` does not correspond to a known block type. ValueError If parameters derived from the config are invalid for the block. """ - if config.block_type == "ConvBlock": + if config.name == "ConvBlock": return ( ConvBlock( in_channels=in_channels, @@ -727,7 +727,7 @@ def build_layer_from_config( input_height, ) - if config.block_type == "FreqCoordConvDown": + if config.name == "FreqCoordConvDown": return ( FreqCoordConvDownBlock( in_channels=in_channels, @@ -740,7 +740,7 @@ def build_layer_from_config( input_height // 2, ) - if config.block_type == "StandardConvDown": + if config.name == "StandardConvDown": return ( StandardConvDownBlock( in_channels=in_channels, @@ -752,7 +752,7 @@ def build_layer_from_config( input_height // 2, ) - if config.block_type == "FreqCoordConvUp": + if config.name == "FreqCoordConvUp": return ( FreqCoordConvUpBlock( in_channels=in_channels, @@ -765,7 +765,7 @@ def build_layer_from_config( input_height * 2, ) - if config.block_type == "StandardConvUp": + if config.name == "StandardConvUp": return ( StandardConvUpBlock( in_channels=in_channels, @@ -777,7 +777,7 @@ def build_layer_from_config( input_height * 2, ) - if config.block_type == "SelfAttention": + if config.name == "SelfAttention": return ( SelfAttention( in_channels=in_channels, @@ -788,7 +788,7 @@ def build_layer_from_config( input_height, ) - if config.block_type == "LayerGroup": + if config.name == "LayerGroup": current_channels = in_channels current_height = input_height @@ -804,4 +804,4 @@ def build_layer_from_config( return nn.Sequential(*blocks), current_channels, current_height - raise NotImplementedError(f"Unknown block type {config.block_type}") + raise NotImplementedError(f"Unknown block type {config.name}") diff --git a/src/batdetect2/models/bottleneck.py b/src/batdetect2/models/bottleneck.py index c72f66a..22d1647 100644 --- a/src/batdetect2/models/bottleneck.py +++ b/src/batdetect2/models/bottleneck.py @@ -128,7 +128,7 @@ class Bottleneck(nn.Module): BottleneckLayerConfig = Annotated[ Union[SelfAttentionConfig,], - Field(discriminator="block_type"), + Field(discriminator="name"), ] """Type alias for the discriminated union of block configs usable in Decoder.""" diff --git a/src/batdetect2/models/decoder.py b/src/batdetect2/models/decoder.py index 270fc9d..18133ac 100644 --- a/src/batdetect2/models/decoder.py +++ b/src/batdetect2/models/decoder.py @@ -47,7 +47,7 @@ DecoderLayerConfig = Annotated[ StandardConvUpConfig, LayerGroupConfig, ], - Field(discriminator="block_type"), + Field(discriminator="name"), ] """Type alias for the discriminated union of block configs usable in Decoder.""" @@ -63,7 +63,7 @@ class DecoderConfig(BaseConfig): layers : List[DecoderLayerConfig] An ordered list of configuration objects, each defining one layer or block in the decoder sequence. Each item must be a valid block - config including a `block_type` field and necessary parameters like + config including a `name` field and necessary parameters like `out_channels`. Input channels for each layer are inferred sequentially. The list must contain at least one layer. """ @@ -249,9 +249,9 @@ def build_decoder( ------ ValueError If `in_channels` or `input_height` are not positive, or if the layer - configuration is invalid (e.g., empty list, unknown `block_type`). + configuration is invalid (e.g., empty list, unknown `name`). NotImplementedError - If `build_layer_from_config` encounters an unknown `block_type`. + If `build_layer_from_config` encounters an unknown `name`. """ config = config or DEFAULT_DECODER_CONFIG diff --git a/src/batdetect2/models/encoder.py b/src/batdetect2/models/encoder.py index 9bb13e5..27b8853 100644 --- a/src/batdetect2/models/encoder.py +++ b/src/batdetect2/models/encoder.py @@ -49,7 +49,7 @@ EncoderLayerConfig = Annotated[ StandardConvDownConfig, LayerGroupConfig, ], - Field(discriminator="block_type"), + Field(discriminator="name"), ] """Type alias for the discriminated union of block configs usable in Encoder.""" @@ -66,7 +66,7 @@ class EncoderConfig(BaseConfig): An ordered list of configuration objects, each defining one layer or block in the encoder sequence. Each item must be a valid block config (e.g., `ConvConfig`, `FreqCoordConvDownConfig`, - `StandardConvDownConfig`) including a `block_type` field and necessary + `StandardConvDownConfig`) including a `name` field and necessary parameters like `out_channels`. Input channels for each layer are inferred sequentially. The list must contain at least one layer. """ @@ -287,9 +287,9 @@ def build_encoder( ------ ValueError If `in_channels` or `input_height` are not positive, or if the layer - configuration is invalid (e.g., empty list, unknown `block_type`). + configuration is invalid (e.g., empty list, unknown `name`). NotImplementedError - If `build_layer_from_config` encounters an unknown `block_type`. + If `build_layer_from_config` encounters an unknown `name`. """ if in_channels <= 0 or input_height <= 0: raise ValueError("in_channels and input_height must be positive.") diff --git a/src/batdetect2/preprocess/spectrogram.py b/src/batdetect2/preprocess/spectrogram.py index 067a965..02bad20 100644 --- a/src/batdetect2/preprocess/spectrogram.py +++ b/src/batdetect2/preprocess/spectrogram.py @@ -1,6 +1,14 @@ """Computes spectrograms from audio waveforms with configurable parameters.""" -from typing import Annotated, Callable, List, Literal, Optional, Union +from typing import ( + Annotated, + Callable, + List, + Literal, + Optional, + Sequence, + Union, +) import numpy as np import torch @@ -306,7 +314,7 @@ class SpectrogramConfig(BaseConfig): stft: STFTConfig = Field(default_factory=STFTConfig) frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig) size: ResizeConfig = Field(default_factory=ResizeConfig) - transforms: List[SpectrogramTransform] = Field( + transforms: Sequence[SpectrogramTransform] = Field( default_factory=lambda: [ PcenConfig(), SpectralMeanSubstractionConfig(), diff --git a/src/batdetect2/targets/__init__.py b/src/batdetect2/targets/__init__.py index a635d51..b6eb193 100644 --- a/src/batdetect2/targets/__init__.py +++ b/src/batdetect2/targets/__init__.py @@ -1,53 +1,26 @@ -"""Main entry point for the BatDetect2 Target Definition subsystem. - -This package (`batdetect2.targets`) provides the tools and configurations -necessary to define precisely what the BatDetect2 model should learn to detect, -classify, and localize from audio data. It involves several conceptual steps, -managed through configuration files and culminating in an executable pipeline: - -1. **Terms (`.terms`)**: Defining vocabulary for annotation tags. -2. **Filtering (`.filtering`)**: Selecting relevant sound event annotations. -3. **Transformation (`.transform`)**: Modifying tags (standardization, - derivation). -4. **ROI Mapping (`.roi`)**: Defining how annotation geometry (ROIs) maps to - target position and size representations, and back. -5. **Class Definition (`.classes`)**: Mapping tags to target class names - (encoding) and mapping predicted names back to tags (decoding). - -This module exposes the key components for users to configure and utilize this -target definition pipeline, primarily through the `TargetConfig` data structure -and the `Targets` class (implementing `TargetProtocol`), which encapsulates the -configured processing steps. The main way to create a functional `Targets` -object is via the `build_targets` or `load_targets` functions. -""" +"""BatDetect2 Target Definition system.""" +from collections import Counter from typing import Iterable, List, Optional, Tuple from loguru import logger -from pydantic import Field +from pydantic import Field, field_validator from soundevent import data from batdetect2.configs import BaseConfig, load_config +from batdetect2.data.conditions import ( + SoundEventCondition, + build_sound_event_condition, +) from batdetect2.targets.classes import ( - ClassesConfig, + DEFAULT_CLASSES, + DEFAULT_GENERIC_CLASS, SoundEventDecoder, SoundEventEncoder, - TargetClass, - build_generic_class_tags, + TargetClassConfig, build_sound_event_decoder, build_sound_event_encoder, get_class_names_from_config, - load_classes_config, - load_decoder_from_config, - load_encoder_from_config, -) -from batdetect2.targets.filtering import ( - FilterConfig, - FilterRule, - SoundEventFilter, - build_sound_event_filter, - load_filter_config, - load_filter_from_config, ) from batdetect2.targets.rois import ( AnchorBBoxMapperConfig, @@ -55,106 +28,53 @@ from batdetect2.targets.rois import ( ROITargetMapper, build_roi_mapper, ) -from batdetect2.targets.terms import ( - TagInfo, - call_type, - get_tag_from_info, - individual, -) -from batdetect2.targets.transform import ( - DerivationRegistry, - DeriveTagRule, - MapValueRule, - ReplaceRule, - SoundEventTransformation, - TransformConfig, - build_transformation_from_config, - default_derivation_registry, - get_derivation, - load_transformation_config, - load_transformation_from_config, - register_derivation, -) +from batdetect2.targets.terms import call_type, individual from batdetect2.typing.targets import Position, Size, TargetProtocol __all__ = [ - "ClassesConfig", "DEFAULT_TARGET_CONFIG", - "DeriveTagRule", - "FilterConfig", - "FilterRule", - "MapValueRule", "AnchorBBoxMapperConfig", "ROITargetMapper", - "ReplaceRule", "SoundEventDecoder", "SoundEventEncoder", - "SoundEventFilter", - "SoundEventTransformation", - "TagInfo", - "TargetClass", + "TargetClassConfig", "TargetConfig", "Targets", - "TransformConfig", - "build_generic_class_tags", "build_roi_mapper", "build_sound_event_decoder", "build_sound_event_encoder", - "build_sound_event_filter", - "build_transformation_from_config", "call_type", "get_class_names_from_config", - "get_derivation", - "get_tag_from_info", "individual", - "load_classes_config", - "load_decoder_from_config", - "load_encoder_from_config", - "load_filter_config", - "load_filter_from_config", "load_target_config", - "load_transformation_config", - "load_transformation_from_config", - "register_derivation", ] class TargetConfig(BaseConfig): - """Unified configuration for the entire target definition pipeline. + detection_target: TargetClassConfig = Field(default=DEFAULT_GENERIC_CLASS) - This model aggregates the configurations for semantic processing (filtering, - transformation, class definition) and geometric processing (ROI mapping). - It serves as the primary input for building a complete `Targets` object - via `build_targets` or `load_targets`. - - Attributes - ---------- - filtering : FilterConfig, optional - Configuration for filtering sound event annotations based on tags. - If None or omitted, no filtering is applied. - transforms : TransformConfig, optional - Configuration for transforming annotation tags - (mapping, derivation, etc.). If None or omitted, no tag transformations - are applied. - classes : ClassesConfig - Configuration defining the specific target classes, their tag matching - rules for encoding, their representative tags for decoding - (`output_tags`), and the definition of the generic class tags. - This section is mandatory. - roi : ROIConfig, optional - Configuration defining how geometric ROIs (e.g., bounding boxes) are - mapped to target representations (reference point, scaled size). - Controls `position`, `time_scale`, `frequency_scale`. If None or - omitted, default ROI mapping settings are used. - """ - - filtering: FilterConfig = Field(default_factory=FilterConfig) - transforms: TransformConfig = Field(default_factory=TransformConfig) - classes: ClassesConfig = Field( - default_factory=lambda: DEFAULT_CLASSES_CONFIG + classification_targets: List[TargetClassConfig] = Field( + default_factory=lambda: DEFAULT_CLASSES ) + roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig) + @field_validator("classification_targets") + def check_unique_class_names(cls, v: List[TargetClassConfig]): + """Ensure all defined class names are unique.""" + names = [c.name for c in v] + + if len(names) != len(set(names)): + name_counts = Counter(names) + duplicates = [ + name for name, count in name_counts.items() if count > 1 + ] + raise ValueError( + "Class names must be unique. Found duplicates: " + f"{', '.join(duplicates)}" + ) + return v + def load_target_config( path: data.PathLike, @@ -230,8 +150,7 @@ class Targets(TargetProtocol): roi_mapper: ROITargetMapper, class_names: list[str], generic_class_tags: List[data.Tag], - filter_fn: Optional[SoundEventFilter] = None, - transform_fn: Optional[SoundEventTransformation] = None, + filter_fn: Optional[SoundEventCondition] = None, roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None, ): """Initialize the Targets object. @@ -264,7 +183,6 @@ class Targets(TargetProtocol): self._filter_fn = filter_fn self._encode_fn = encode_fn self._decode_fn = decode_fn - self._transform_fn = transform_fn self._roi_mapper_overrides = roi_mapper_overrides or {} for class_name in self._roi_mapper_overrides: @@ -336,27 +254,6 @@ class Targets(TargetProtocol): """ return self._decode_fn(class_label) - def transform( - self, sound_event: data.SoundEventAnnotation - ) -> data.SoundEventAnnotation: - """Apply the configured tag transformations to an annotation. - - Parameters - ---------- - sound_event : data.SoundEventAnnotation - The annotation whose tags should be transformed. - - Returns - ------- - data.SoundEventAnnotation - A new annotation object with the transformed tags. If no - transformations were configured, the original annotation object is - returned. - """ - if self._transform_fn: - return self._transform_fn(sound_event) - return sound_event - def encode_roi( self, sound_event: data.SoundEventAnnotation ) -> tuple[Position, Size]: @@ -422,112 +319,14 @@ class Targets(TargetProtocol): return self._roi_mapper.decode(position, size) -DEFAULT_CLASSES = [ - TargetClass( - tags=[TagInfo(value="Myotis mystacinus")], - name="myomys", - ), - TargetClass( - tags=[TagInfo(value="Myotis alcathoe")], - name="myoalc", - ), - TargetClass( - tags=[TagInfo(value="Eptesicus serotinus")], - name="eptser", - ), - TargetClass( - tags=[TagInfo(value="Pipistrellus nathusii")], - name="pipnat", - ), - TargetClass( - tags=[TagInfo(value="Barbastellus barbastellus")], - name="barbar", - ), - TargetClass( - tags=[TagInfo(value="Myotis nattereri")], - name="myonat", - ), - TargetClass( - tags=[TagInfo(value="Myotis daubentonii")], - name="myodau", - ), - TargetClass( - tags=[TagInfo(value="Myotis brandtii")], - name="myobra", - ), - TargetClass( - tags=[TagInfo(value="Pipistrellus pipistrellus")], - name="pippip", - ), - TargetClass( - tags=[TagInfo(value="Myotis bechsteinii")], - name="myobec", - ), - TargetClass( - tags=[TagInfo(value="Pipistrellus pygmaeus")], - name="pippyg", - ), - TargetClass( - tags=[TagInfo(value="Rhinolophus hipposideros")], - name="rhihip", - ), - TargetClass( - tags=[TagInfo(value="Nyctalus leisleri")], - name="nyclei", - roi=AnchorBBoxMapperConfig(anchor="top-left"), - ), - TargetClass( - tags=[TagInfo(value="Rhinolophus ferrumequinum")], - name="rhifer", - roi=AnchorBBoxMapperConfig(anchor="top-left"), - ), - TargetClass( - tags=[TagInfo(value="Plecotus auritus")], - name="pleaur", - ), - TargetClass( - tags=[TagInfo(value="Nyctalus noctula")], - name="nycnoc", - ), - TargetClass( - tags=[TagInfo(value="Plecotus austriacus")], - name="pleaus", - ), -] - - -DEFAULT_CLASSES_CONFIG: ClassesConfig = ClassesConfig( - classes=DEFAULT_CLASSES, - generic_class=[TagInfo(value="Bat")], -) - - DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig( - filtering=FilterConfig( - rules=[ - FilterRule( - match_type="all", - tags=[TagInfo(key="event", value="Echolocation")], - ), - FilterRule( - match_type="exclude", - tags=[ - TagInfo(key="event", value="Feeding"), - TagInfo(key="event", value="Unknown"), - TagInfo(key="event", value="Not Bat"), - ], - ), - ] - ), - classes=DEFAULT_CLASSES_CONFIG, + classification_targets=DEFAULT_CLASSES, + detection_target=DEFAULT_GENERIC_CLASS, roi=AnchorBBoxMapperConfig(), ) -def build_targets( - config: Optional[TargetConfig] = None, - derivation_registry: DerivationRegistry = default_derivation_registry, -) -> Targets: +def build_targets(config: Optional[TargetConfig] = None) -> Targets: """Build a Targets object from a loaded TargetConfig. This factory function takes the unified `TargetConfig` and constructs all @@ -541,10 +340,6 @@ def build_targets( ---------- config : TargetConfig The loaded and validated unified target configuration object. - derivation_registry : DerivationRegistry, optional - The DerivationRegistry instance to use for resolving derivation - function names. Defaults to the global - `batdetect2.targets.transform.derivation_registry`. Returns ------- @@ -565,27 +360,18 @@ def build_targets( lambda: config.to_yaml_string(), ) - filter_fn = ( - build_sound_event_filter(config.filtering) - if config.filtering - else None - ) - encode_fn = build_sound_event_encoder(config.classes) - decode_fn = build_sound_event_decoder(config.classes) - transform_fn = ( - build_transformation_from_config( - config.transforms, - derivation_registry=derivation_registry, - ) - if config.transforms - else None - ) + filter_fn = build_sound_event_condition(config.detection_target.match_if) + encode_fn = build_sound_event_encoder(config.classification_targets) + decode_fn = build_sound_event_decoder(config.classification_targets) + roi_mapper = build_roi_mapper(config.roi) - class_names = get_class_names_from_config(config.classes) - generic_class_tags = build_generic_class_tags(config.classes) + class_names = get_class_names_from_config(config.classification_targets) + + generic_class_tags = config.detection_target.assign_tags + roi_overrides = { class_config.name: build_roi_mapper(class_config.roi) - for class_config in config.classes.classes + for class_config in config.classification_targets if class_config.roi is not None } @@ -596,7 +382,6 @@ def build_targets( class_names=class_names, roi_mapper=roi_mapper, generic_class_tags=generic_class_tags, - transform_fn=transform_fn, roi_mapper_overrides=roi_overrides, ) @@ -604,7 +389,6 @@ def build_targets( def load_targets( config_path: data.PathLike, field: Optional[str] = None, - derivation_registry: DerivationRegistry = default_derivation_registry, ) -> Targets: """Load a Targets object directly from a configuration file. @@ -619,9 +403,6 @@ def load_targets( field : str, optional Dot-separated path to a nested section within the file containing the target configuration. If None, the entire file content is used. - derivation_registry : DerivationRegistry, optional - The DerivationRegistry instance to use. Defaults to the global - default. Returns ------- @@ -642,7 +423,7 @@ def load_targets( config_path, field=field, ) - return build_targets(config, derivation_registry=derivation_registry) + return build_targets(config) def iterate_encoded_sound_events( @@ -658,8 +439,6 @@ def iterate_encoded_sound_events( if geometry is None: continue - sound_event = targets.transform(sound_event) - class_name = targets.encode_class(sound_event) position, size = targets.encode_roi(sound_event) diff --git a/src/batdetect2/targets/classes.py b/src/batdetect2/targets/classes.py index fed7170..96c78f6 100644 --- a/src/batdetect2/targets/classes.py +++ b/src/batdetect2/targets/classes.py @@ -1,251 +1,167 @@ -from collections import Counter -from functools import partial -from typing import Callable, Dict, List, Literal, Optional, Set, Tuple +from typing import Dict, List, Optional -from pydantic import Field, field_validator +from pydantic import Field, PrivateAttr, computed_field, model_validator from soundevent import data -from batdetect2.configs import BaseConfig, load_config -from batdetect2.targets.rois import ROIMapperConfig -from batdetect2.targets.terms import ( - GENERIC_CLASS_KEY, - TagInfo, - get_tag_from_info, +from batdetect2.configs import BaseConfig +from batdetect2.data.conditions import ( + AllOfConfig, + HasAllTagsConfig, + HasAnyTagConfig, + HasTagConfig, + NotConfig, + SoundEventCondition, + SoundEventConditionConfig, + build_sound_event_condition, ) +from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder __all__ = [ - "DEFAULT_SPECIES_LIST", - "build_generic_class_tags", "build_sound_event_decoder", "build_sound_event_encoder", "get_class_names_from_config", - "load_classes_config", - "load_decoder_from_config", - "load_encoder_from_config", ] -DEFAULT_SPECIES_LIST = [ - "Barbastella barbastellus", - "Eptesicus serotinus", - "Myotis alcathoe", - "Myotis bechsteinii", - "Myotis brandtii", - "Myotis daubentonii", - "Myotis mystacinus", - "Myotis nattereri", - "Nyctalus leisleri", - "Nyctalus noctula", - "Pipistrellus nathusii", - "Pipistrellus pipistrellus", - "Pipistrellus pygmaeus", - "Plecotus auritus", - "Plecotus austriacus", - "Rhinolophus ferrumequinum", - "Rhinolophus hipposideros", -] -"""A default list of common bat species names found in the UK.""" - - -class TargetClass(BaseConfig): - """Defines criteria for encoding annotations and decoding predictions. - - Each instance represents one potential output class for the classification - model. It specifies: - 1. A unique `name` for the class. - 2. The tag conditions (`tags` and `match_type`) an annotation must meet to - be assigned this class name during training data preparation (encoding). - 3. An optional, alternative set of tags (`output_tags`) to be used when - converting a model's prediction of this class name back into annotation - tags (decoding). - - Attributes - ---------- - name : str - The unique name assigned to this target class (e.g., 'pippip', - 'myodau', 'noise'). This name is used as the label during model - training and is the expected output from the model's prediction. - Should be unique across all TargetClass definitions in a configuration. - tags : List[TagInfo] - A list of one or more tags (defined using `TagInfo`) used to identify - if an existing annotation belongs to this class during encoding (data - preparation for training). The `match_type` attribute determines how - these tags are evaluated. - match_type : Literal["all", "any"], default="all" - Determines how the `tags` list is evaluated during encoding: - - "all": The annotation must have *all* the tags listed to match. - - "any": The annotation must have *at least one* of the tags listed - to match. - output_tags: Optional[List[TagInfo]], default=None - An optional list of tags (defined using `TagInfo`) to be assigned to a - new annotation when the model predicts this class `name`. If `None` - (default), the tags listed in the `tags` field will be used for - decoding. If provided, this list overrides the `tags` field for the - purpose of decoding predictions back into meaningful annotation tags. - This allows, for example, training on broader categories but decoding - to more specific representative tags. - """ +class TargetClassConfig(BaseConfig): + """Defines a target class of sound events.""" name: str - tags: List[TagInfo] = Field(min_length=1) - match_type: Literal["all", "any"] = Field(default="all") - output_tags: Optional[List[TagInfo]] = None + + condition_input: Optional[SoundEventConditionConfig] = Field( + alias="match_if", + default=None, + ) + tags: Optional[List[data.Tag]] = None + + assign_tags: List[data.Tag] = Field(default_factory=list) + roi: Optional[ROIMapperConfig] = None + _match_if: SoundEventConditionConfig = PrivateAttr() -def _get_default_classes() -> List[TargetClass]: - """Generate a list of default target classes. + @model_validator(mode="after") + def _process_shorthands(self) -> "TargetClassConfig": + if self.tags and self.condition_input: + raise ValueError("Use either 'tags' or 'match_if', not both.") - Returns - ------- - List[TargetClass] - A list of TargetClass objects, one for each species in - DEFAULT_SPECIES_LIST. The class names are simplified versions of the - species names. - """ - return [ - TargetClass( - name=_get_default_class_name(value), - tags=[TagInfo(key=GENERIC_CLASS_KEY, value=value)], - ) - for value in DEFAULT_SPECIES_LIST - ] - - -def _get_default_class_name(species: str) -> str: - """Generate a default class name from a species name. - - Parameters - ---------- - species : str - The species name (e.g., "Myotis daubentonii"). - - Returns - ------- - str - A simplified class name (e.g., "myodau"). - The genus and species names are converted to lowercase, - the first three letters of each are taken, and concatenated. - """ - genus, species = species.strip().split(" ") - return f"{genus.lower()[:3]}{species.lower()[:3]}" - - -def _get_default_generic_class() -> List[TagInfo]: - """Generate the default list of TagInfo objects for the generic class. - - Provides a default set of tags used to represent the generic "Bat" category - when decoding predictions that didn't match a specific class. - - Returns - ------- - List[TagInfo] - A list containing default TagInfo objects, typically representing - `call_type: Echolocation` and `order: Chiroptera`. - """ - return [ - TagInfo(key="call_type", value="Echolocation"), - TagInfo(key="order", value="Chiroptera"), - ] - - -class ClassesConfig(BaseConfig): - """Configuration defining target classes and the generic fallback category. - - Holds the ordered list of specific target class definitions (`TargetClass`) - and defines the tags representing the generic category for sounds that pass - filtering but do not match any specific class. - - The order of `TargetClass` objects in the `classes` list defines the - priority for classification during encoding. The system checks annotations - against these definitions sequentially and assigns the name of the *first* - matching class. - - Attributes - ---------- - classes : List[TargetClass] - An ordered list of specific target class definitions. The order - determines matching priority (first match wins). Defaults to a - standard set of classes via `get_default_classes`. - generic_class : List[TagInfo] - A list of tags defining the "generic" or "unclassified but relevant" - category (e.g., representing a generic 'Bat' call that wasn't - assigned to a specific species). These tags are typically assigned - during decoding when a sound event was detected and passed filtering - but did not match any specific class rule defined in the `classes` list. - Defaults to a standard set of tags via `get_default_generic_class`. - - Raises - ------ - ValueError - If validation fails (e.g., non-unique class names in the `classes` - list). - - Notes - ----- - - It is crucial that the `name` attribute of each `TargetClass` in the - `classes` list is unique. This configuration includes a validator to - enforce this uniqueness. - - The `generic_class` tags provide a baseline identity for relevant sounds - that don't fit into more specific defined categories. - """ - - classes: List[TargetClass] = Field(default_factory=_get_default_classes) - - generic_class: List[TagInfo] = Field( - default_factory=_get_default_generic_class - ) - - @field_validator("classes") - def check_unique_class_names(cls, v: List[TargetClass]): - """Ensure all defined class names are unique.""" - names = [c.name for c in v] - - if len(names) != len(set(names)): - name_counts = Counter(names) - duplicates = [ - name for name, count in name_counts.items() if count > 1 - ] + if self.condition_input: + final_condition = self.condition_input + elif self.tags: + final_condition = HasAllTagsConfig(tags=self.tags) + else: raise ValueError( - "Class names must be unique. Found duplicates: " - f"{', '.join(duplicates)}" + f"Class '{self.name}' must have a 'tags' or 'match_if' rule." ) - return v + + self._match_if = final_condition + return self + + @computed_field + @property + def match_if(self) -> SoundEventConditionConfig: + return self._match_if -def is_target_class( - sound_event_annotation: data.SoundEventAnnotation, - tags: Set[data.Tag], - match_all: bool = True, -) -> bool: - """Check if a sound event annotation matches a set of required tags. - - Parameters - ---------- - sound_event_annotation : data.SoundEventAnnotation - The annotation to check. - required_tags : Set[data.Tag] - A set of `soundevent.data.Tag` objects that define the class criteria. - match_all : bool, default=True - If True, checks if *all* `required_tags` are present in the - annotation's tags (subset check). If False, checks if *at least one* - of the `required_tags` is present (intersection check). - - Returns - ------- - bool - True if the annotation meets the tag criteria, False otherwise. - """ - annotation_tags = set(sound_event_annotation.tags) - - if match_all: - return tags <= annotation_tags - - return bool(tags & annotation_tags) +DEFAULT_GENERIC_CLASS = TargetClassConfig( + name="bat", + match_if=AllOfConfig( + conditions=[ + HasTagConfig(tag=data.Tag(key="event", value="Echolocation")), + NotConfig( + condition=HasAnyTagConfig( + tags=[ + data.Tag(key="event", value="Feeding"), + data.Tag(key="event", value="Unknown"), + data.Tag(key="event", value="Not Bat"), + ] + ) + ), + ] + ), + assign_tags=[ + data.Tag(key="call_type", value="Echolocation"), + data.Tag(key="order", value="Chiroptera"), + ], +) -def get_class_names_from_config(config: ClassesConfig) -> List[str]: +DEFAULT_CLASSES = [ + TargetClassConfig( + name="myomys", + tags=[data.Tag(key="class", value="Myotis mystacinus")], + ), + TargetClassConfig( + name="myoalc", + tags=[data.Tag(key="class", value="Myotis alcathoe")], + ), + TargetClassConfig( + name="eptser", + tags=[data.Tag(key="class", value="Eptesicus serotinus")], + ), + TargetClassConfig( + name="pipnat", + tags=[data.Tag(key="class", value="Pipistrellus nathusii")], + ), + TargetClassConfig( + name="barbar", + tags=[data.Tag(key="class", value="Barbastellus barbastellus")], + ), + TargetClassConfig( + name="myonat", + tags=[data.Tag(key="class", value="Myotis nattereri")], + ), + TargetClassConfig( + name="myodau", + tags=[data.Tag(key="class", value="Myotis daubentonii")], + ), + TargetClassConfig( + name="myobra", + tags=[data.Tag(key="class", value="Myotis brandtii")], + ), + TargetClassConfig( + name="pippip", + tags=[data.Tag(key="class", value="Pipistrellus pipistrellus")], + ), + TargetClassConfig( + name="myobec", + tags=[data.Tag(key="class", value="Myotis bechsteinii")], + ), + TargetClassConfig( + name="pippyg", + tags=[data.Tag(key="class", value="Pipistrellus pygmaeus")], + ), + TargetClassConfig( + name="rhihip", + tags=[data.Tag(key="class", value="Rhinolophus hipposideros")], + roi=AnchorBBoxMapperConfig(anchor="top-left"), + ), + TargetClassConfig( + name="nyclei", + tags=[data.Tag(key="class", value="Nyctalus leisleri")], + ), + TargetClassConfig( + name="rhifer", + tags=[data.Tag(key="class", value="Rhinolophus ferrumequinum")], + roi=AnchorBBoxMapperConfig(anchor="top-left"), + ), + TargetClassConfig( + name="pleaur", + tags=[data.Tag(key="class", value="Plecotus auritus")], + ), + TargetClassConfig( + name="nycnoc", + tags=[data.Tag(key="class", value="Nyctalus noctula")], + ), + TargetClassConfig( + name="pleaus", + tags=[data.Tag(key="class", value="Plecotus austriacus")], + ), +] + + +def get_class_names_from_config(configs: List[TargetClassConfig]) -> List[str]: """Extract the list of class names from a ClassesConfig object. Parameters @@ -258,324 +174,60 @@ def get_class_names_from_config(config: ClassesConfig) -> List[str]: List[str] An ordered list of unique class names defined in the configuration. """ - return [class_info.name for class_info in config.classes] + return [class_info.name for class_info in configs] -def _encode_with_multiple_classifiers( - sound_event_annotation: data.SoundEventAnnotation, - classifiers: List[Tuple[str, Callable[[data.SoundEventAnnotation], bool]]], -) -> Optional[str]: - """Encode an annotation by checking against a list of classifiers. +def build_sound_event_encoder( + configs: List[TargetClassConfig], +) -> SoundEventEncoder: + """Build a sound event encoder function from the classes configuration.""" + conditions = { + class_config.name: build_sound_event_condition(class_config.match_if) + for class_config in configs + } - Internal helper function used by the `SoundEventEncoder`. It iterates - through the provided list of (class_name, classifier_function) pairs. - Returns the name associated with the first classifier function that - returns True for the given annotation. - - Parameters - ---------- - sound_event_annotation : data.SoundEventAnnotation - The annotation to encode. - classifiers : List[Tuple[str, Callable[[data.SoundEventAnnotation], bool]]] - An ordered list where each tuple contains a class name and a function - that returns True if the annotation matches that class. The order - determines priority. - - Returns - ------- - str or None - The name of the first matching class, or None if no classifier matches. - """ - for class_name, classifier in classifiers: - if classifier(sound_event_annotation): - return class_name - - return None + return SoundEventClassifier(conditions) -def build_sound_event_encoder(config: ClassesConfig) -> SoundEventEncoder: - """Build a sound event encoder function from the classes configuration. +class SoundEventClassifier: + def __init__(self, mapping: Dict[str, SoundEventCondition]): + self.mapping = mapping - The returned encoder function iterates through the class definitions in the - order specified in the config. It assigns an annotation the name of the - first class definition it matches. - - Parameters - ---------- - config : ClassesConfig - The loaded and validated classes configuration object. - term_registry : TermRegistry, optional - The TermRegistry instance used to look up term keys specified in the - `TagInfo` objects within the configuration. Defaults to the global - `batdetect2.targets.terms.registry`. - - Returns - ------- - SoundEventEncoder - A callable function that takes a `SoundEventAnnotation` and returns - an optional string representing the matched class name, or None if no - class matches. - - Raises - ------ - KeyError - If a term key specified in the configuration is not found in the - provided `term_registry`. - """ - binary_classifiers = [ - ( - class_info.name, - partial( - is_target_class, - tags={ - get_tag_from_info(tag_info) for tag_info in class_info.tags - }, - match_all=class_info.match_type == "all", - ), - ) - for class_info in config.classes - ] - - return partial( - _encode_with_multiple_classifiers, - classifiers=binary_classifiers, - ) - - -def _decode_class( - name: str, - mapping: Dict[str, List[data.Tag]], - raise_on_error: bool = True, -) -> List[data.Tag]: - """Decode a class name into a list of representative tags using a mapping. - - Internal helper function used by the `SoundEventDecoder`. Looks up the - provided class `name` in the `mapping` dictionary. - - Parameters - ---------- - name : str - The class name to decode. - mapping : Dict[str, List[data.Tag]] - A dictionary mapping class names to lists of `soundevent.data.Tag` - objects. - raise_on_error : bool, default=True - If True, raises a ValueError if the `name` is not found in the - `mapping`. If False, returns an empty list if the `name` is not found. - - Returns - ------- - List[data.Tag] - The list of tags associated with the class name, or an empty list if - not found and `raise_on_error` is False. - - Raises - ------ - ValueError - If `name` is not found in `mapping` and `raise_on_error` is True. - """ - if name not in mapping and raise_on_error: - raise ValueError(f"Class {name} not found in mapping.") - - if name not in mapping: - return [] - - return mapping[name] + def __call__( + self, sound_event_annotation: data.SoundEventAnnotation + ) -> Optional[str]: + for name, condition in self.mapping.items(): + if condition(sound_event_annotation): + return name def build_sound_event_decoder( - config: ClassesConfig, + configs: List[TargetClassConfig], raise_on_unmapped: bool = False, ) -> SoundEventDecoder: - """Build a sound event decoder function from the classes configuration. - - Creates a callable `SoundEventDecoder` that maps a class name string - back to a list of representative `soundevent.data.Tag` objects based on - the `ClassesConfig`. It uses the `output_tags` field if provided in a - `TargetClass`, otherwise falls back to the `tags` field. - - Parameters - ---------- - config : ClassesConfig - The loaded and validated classes configuration object. - term_registry : TermRegistry, optional - The TermRegistry instance used to look up term keys. Defaults to the - global `batdetect2.targets.terms.registry`. - raise_on_unmapped : bool, default=False - If True, the returned decoder function will raise a ValueError if asked - to decode a class name that is not in the configuration. If False, it - will return an empty list for unmapped names. - - Returns - ------- - SoundEventDecoder - A callable function that takes a class name string and returns a list - of `soundevent.data.Tag` objects. - - Raises - ------ - KeyError - If a term key specified in the configuration (`output_tags`, `tags`, or - `generic_class`) is not found in the provided `term_registry`. - """ - mapping = {} - for class_info in config.classes: - tags_to_use = ( - class_info.output_tags - if class_info.output_tags is not None - else class_info.tags - ) - mapping[class_info.name] = [ - get_tag_from_info(tag_info) for tag_info in tags_to_use - ] - - return partial( - _decode_class, - mapping=mapping, - raise_on_error=raise_on_unmapped, - ) + """Build a sound event decoder function from the classes configuration.""" + mapping = { + class_config.name: class_config.assign_tags for class_config in configs + } + return TagDecoder(mapping, raise_on_unknown=raise_on_unmapped) -def build_generic_class_tags(config: ClassesConfig) -> List[data.Tag]: - """Extract and build the list of tags for the generic class from config. +class TagDecoder: + def __init__( + self, + mapping: Dict[str, List[data.Tag]], + raise_on_unknown: bool = True, + ): + self.mapping = mapping + self.raise_on_unknown = raise_on_unknown - Converts the list of `TagInfo` objects defined in `config.generic_class` - into a list of `soundevent.data.Tag` objects using the term registry. + def __call__(self, class_name: str) -> List[data.Tag]: + tags = self.mapping.get(class_name) - Parameters - ---------- - config : ClassesConfig - The loaded classes configuration object. - term_registry : TermRegistry, optional - The TermRegistry instance for term lookups. Defaults to the global - `batdetect2.targets.terms.registry`. + if tags is None: + if self.raise_on_unknown: + raise ValueError("Invalid class name") - Returns - ------- - List[data.Tag] - The list of fully constructed tags representing the generic class. + tags = [] - Raises - ------ - KeyError - If a term key specified in `config.generic_class` is not found in the - provided `term_registry`. - """ - return [get_tag_from_info(tag_info) for tag_info in config.generic_class] - - -def load_classes_config( - path: data.PathLike, - field: Optional[str] = None, -) -> ClassesConfig: - """Load the target classes configuration from a file. - - Parameters - ---------- - path : data.PathLike - Path to the configuration file (YAML). - field : str, optional - If the classes configuration is nested under a specific key in the - file, specify the key here. Defaults to None. - - Returns - ------- - ClassesConfig - The loaded and validated classes configuration object. - - Raises - ------ - FileNotFoundError - If the config file path does not exist. - pydantic.ValidationError - If the config file structure does not match the ClassesConfig schema - or if class names are not unique. - """ - return load_config(path, schema=ClassesConfig, field=field) - - -def load_encoder_from_config( - path: data.PathLike, field: Optional[str] = None -) -> SoundEventEncoder: - """Load a class encoder function directly from a configuration file. - - This is a convenience function that combines loading the `ClassesConfig` - from a file and building the final `SoundEventEncoder` function. - - Parameters - ---------- - path : data.PathLike - Path to the configuration file (e.g., YAML). - field : str, optional - If the classes configuration is nested under a specific key in the - file, specify the key here. Defaults to None. - term_registry : TermRegistry, optional - The TermRegistry instance used for term lookups. Defaults to the - global `batdetect2.targets.terms.registry`. - - Returns - ------- - SoundEventEncoder - The final encoder function ready to classify annotations. - - Raises - ------ - FileNotFoundError - If the config file path does not exist. - pydantic.ValidationError - If the config file structure does not match the ClassesConfig schema - or if class names are not unique. - KeyError - If a term key specified in the configuration is not found in the - provided `term_registry` during the build process. - """ - config = load_classes_config(path, field=field) - return build_sound_event_encoder(config) - - -def load_decoder_from_config( - path: data.PathLike, - field: Optional[str] = None, - raise_on_unmapped: bool = False, -) -> SoundEventDecoder: - """Load a class decoder function directly from a configuration file. - - This is a convenience function that combines loading the `ClassesConfig` - from a file and building the final `SoundEventDecoder` function. - - Parameters - ---------- - path : data.PathLike - Path to the configuration file (e.g., YAML). - field : str, optional - If the classes configuration is nested under a specific key in the - file, specify the key here. Defaults to None. - term_registry : TermRegistry, optional - The TermRegistry instance used for term lookups. Defaults to the - global `batdetect2.targets.terms.registry`. - raise_on_unmapped : bool, default=False - If True, the returned decoder function will raise a ValueError if asked - to decode a class name that is not in the configuration. If False, it - will return an empty list for unmapped names. - - Returns - ------- - SoundEventDecoder - The final decoder function ready to convert class names back into tags. - - Raises - ------ - FileNotFoundError - If the config file path does not exist. - pydantic.ValidationError - If the config file structure does not match the ClassesConfig schema - or if class names are not unique. - KeyError - If a term key specified in the configuration is not found in the - provided `term_registry` during the build process. - """ - config = load_classes_config(path, field=field) - return build_sound_event_decoder( - config, - raise_on_unmapped=raise_on_unmapped, - ) + return tags diff --git a/src/batdetect2/targets/filtering.py b/src/batdetect2/targets/filtering.py deleted file mode 100644 index 462f7e4..0000000 --- a/src/batdetect2/targets/filtering.py +++ /dev/null @@ -1,293 +0,0 @@ -import logging -from functools import partial -from typing import List, Literal, Optional, Set - -from pydantic import Field -from soundevent import data - -from batdetect2.configs import BaseConfig, load_config -from batdetect2.targets.terms import ( - TagInfo, - get_tag_from_info, -) -from batdetect2.typing.targets import SoundEventFilter - -__all__ = [ - "FilterConfig", - "FilterRule", - "build_sound_event_filter", - "build_filter_from_rule", - "load_filter_config", - "load_filter_from_config", -] - - -logger = logging.getLogger(__name__) - - -class FilterRule(BaseConfig): - """Defines a single rule for filtering sound event annotations. - - Based on the `match_type`, this rule checks if the tags associated with a - sound event annotation meet certain criteria relative to the `tags` list - defined in this rule. - - Attributes - ---------- - match_type : Literal["any", "all", "exclude", "equal"] - Determines how the `tags` list is used: - - "any": Pass if the annotation has at least one tag from the list. - - "all": Pass if the annotation has all tags from the list (it can - have others too). - - "exclude": Pass if the annotation has none of the tags from the list. - - "equal": Pass if the annotation's tags are exactly the same set as - provided in the list. - tags : List[TagInfo] - A list of tags (defined using TagInfo for configuration) that this - rule operates on. - """ - - match_type: Literal["any", "all", "exclude", "equal"] - tags: List[TagInfo] - - -def has_any_tag( - sound_event_annotation: data.SoundEventAnnotation, - tags: Set[data.Tag], -) -> bool: - """Check if the annotation has at least one of the specified tags. - - Parameters - ---------- - sound_event_annotation : data.SoundEventAnnotation - The annotation to check. - tags : Set[data.Tag] - The set of tags to look for. - - Returns - ------- - bool - True if the annotation has one or more tags from the specified set, - False otherwise. - """ - sound_event_tags = set(sound_event_annotation.tags) - return bool(tags & sound_event_tags) - - -def contains_tags( - sound_event_annotation: data.SoundEventAnnotation, - tags: Set[data.Tag], -) -> bool: - """Check if the annotation contains all of the specified tags. - - The annotation may have additional tags beyond those specified. - - Parameters - ---------- - sound_event_annotation : data.SoundEventAnnotation - The annotation to check. - tags : Set[data.Tag] - The set of tags that must all be present in the annotation. - - Returns - ------- - bool - True if the annotation's tags are a superset of the specified tags, - False otherwise. - """ - sound_event_tags = set(sound_event_annotation.tags) - return tags <= sound_event_tags - - -def does_not_have_tags( - sound_event_annotation: data.SoundEventAnnotation, - tags: Set[data.Tag], -): - """Check if the annotation has none of the specified tags. - - Parameters - ---------- - sound_event_annotation : data.SoundEventAnnotation - The annotation to check. - tags : Set[data.Tag] - The set of tags that must *not* be present in the annotation. - - Returns - ------- - bool - True if the annotation has zero tags in common with the specified set, - False otherwise. - """ - return not has_any_tag(sound_event_annotation, tags) - - -def equal_tags( - sound_event_annotation: data.SoundEventAnnotation, - tags: Set[data.Tag], -) -> bool: - """Check if the annotation's tags are exactly equal to the specified set. - - Parameters - ---------- - sound_event_annotation : data.SoundEventAnnotation - The annotation to check. - tags : Set[data.Tag] - The exact set of tags the annotation must have. - - Returns - ------- - bool - True if the annotation's tags set is identical to the specified set, - False otherwise. - """ - sound_event_tags = set(sound_event_annotation.tags) - return tags == sound_event_tags - - -def build_filter_from_rule(rule: FilterRule) -> SoundEventFilter: - """Creates a callable filter function from a single FilterRule. - - Parameters - ---------- - rule : FilterRule - The filter rule configuration object. - - Returns - ------- - SoundEventFilter - A function that takes a SoundEventAnnotation and returns True if it - passes the rule, False otherwise. - - Raises - ------ - ValueError - If the rule contains an invalid `match_type`. - """ - tag_set = {get_tag_from_info(tag_info) for tag_info in rule.tags} - - if rule.match_type == "any": - return partial(has_any_tag, tags=tag_set) - - if rule.match_type == "all": - return partial(contains_tags, tags=tag_set) - - if rule.match_type == "exclude": - return partial(does_not_have_tags, tags=tag_set) - - if rule.match_type == "equal": - return partial(equal_tags, tags=tag_set) - - raise ValueError( - f"Invalid match type {rule.match_type}. Valid types " - "are: 'any', 'all', 'exclude' and 'equal'" - ) - - -def _passes_all_filters( - sound_event_annotation: data.SoundEventAnnotation, - filters: List[SoundEventFilter], -) -> bool: - """Check if the annotation passes all provided filters. - - Parameters - ---------- - sound_event_annotation : data.SoundEventAnnotation - The annotation to check. - filters : List[SoundEventFilter] - A list of filter functions to apply. - - Returns - ------- - bool - True if the annotation passes all filters, False otherwise. - """ - for filter_fn in filters: - if not filter_fn(sound_event_annotation): - logging.debug( - f"Sound event annotation {sound_event_annotation.uuid} " - f"excluded due to rule {filter_fn}", - ) - return False - - return True - - -class FilterConfig(BaseConfig): - """Configuration model for defining a list of filter rules. - - Attributes - ---------- - rules : List[FilterRule] - A list of FilterRule objects. An annotation must pass all rules in - this list to be considered valid by the filter built from this config. - """ - - rules: List[FilterRule] = Field(default_factory=list) - - -def build_sound_event_filter( - config: FilterConfig, -) -> SoundEventFilter: - """Builds a merged filter function from a FilterConfig object. - - Creates individual filter functions for each rule in the configuration - and merges them using AND logic. - - Parameters - ---------- - config : FilterConfig - The configuration object containing the list of filter rules. - - Returns - ------- - SoundEventFilter - A single callable filter function that applies all defined rules. - """ - filters = [build_filter_from_rule(rule) for rule in config.rules] - return partial(_passes_all_filters, filters=filters) - - -def load_filter_config( - path: data.PathLike, field: Optional[str] = None -) -> FilterConfig: - """Loads the filter configuration from a file. - - Parameters - ---------- - path : data.PathLike - Path to the configuration file (YAML). - field : Optional[str], optional - If the filter configuration is nested under a specific key in the - file, specify the key here. Defaults to None. - - Returns - ------- - FilterConfig - The loaded and validated filter configuration object. - """ - return load_config(path, schema=FilterConfig, field=field) - - -def load_filter_from_config( - path: data.PathLike, field: Optional[str] = None, -) -> SoundEventFilter: - """Loads filter configuration from a file and builds the filter function. - - This is a convenience function that combines loading the configuration - and building the final callable filter function. - - Parameters - ---------- - path : data.PathLike - Path to the configuration file (YAML). - field : Optional[str], optional - If the filter configuration is nested under a specific key in the - file, specify the key here. Defaults to None. - - Returns - ------- - SoundEventFilter - The final merged filter function ready to be used. - """ - config = load_filter_config(path=path, field=field) - return build_sound_event_filter(config) diff --git a/src/batdetect2/targets/terms.py b/src/batdetect2/targets/terms.py index 88b3576..7904b53 100644 --- a/src/batdetect2/targets/terms.py +++ b/src/batdetect2/targets/terms.py @@ -1,23 +1,11 @@ -"""Manages the vocabulary for defining training targets. +"""Manages the vocabulary for defining training targets.""" -This module provides the necessary tools to declare, register, and manage the -set of `soundevent.data.Term` objects used throughout the `batdetect2.targets` -sub-package. It establishes a consistent vocabulary for filtering, -transforming, and classifying sound events based on their annotations (Tags). - -Terms can be pre-defined, loaded from the `soundevent.terms` library or defined -programmatically. -""" - -from pydantic import BaseModel from soundevent import data, terms __all__ = [ "call_type", "individual", "data_source", - "get_tag_from_info", - "TagInfo", ] # The default key used to reference the 'generic_class' term. @@ -85,54 +73,3 @@ terms.register_term_set( ), override_existing=True, ) - - -class TagInfo(BaseModel): - """Represents information needed to define a specific Tag. - - This model is typically used in configuration files (e.g., YAML) to - specify tags used for filtering, target class definition, or associating - tags with output classes. It links a tag value to a term definition - via the term's registry key. - - Attributes - ---------- - value : str - The value of the tag (e.g., "Myotis myotis", "Echolocation"). - key : str, default="class" - The key (alias) of the term associated with this tag. Defaults to - "class", implying it represents a classification target label by - default. - """ - - value: str - key: str = GENERIC_CLASS_KEY - - -def get_tag_from_info(tag_info: TagInfo) -> data.Tag: - """Creates a soundevent.data.Tag object from TagInfo data. - - Looks up the term using the key in the provided `tag_info` and constructs a - Tag object. - - Parameters - ---------- - tag_info : TagInfo - The TagInfo object containing the value and term key. - - Returns - ------- - soundevent.data.Tag - A soundevent.data.Tag object corresponding to the input info. - - Raises - ------ - KeyError - If the term key specified in `tag_info.key` is not found. - """ - term = terms.get_term(tag_info.key) - - if not term: - raise KeyError(f"Key {tag_info.key} not found") - - return data.Tag(term=term, value=tag_info.value) diff --git a/src/batdetect2/targets/transform.py b/src/batdetect2/targets/transform.py deleted file mode 100644 index b71e658..0000000 --- a/src/batdetect2/targets/transform.py +++ /dev/null @@ -1,689 +0,0 @@ -import importlib -from functools import partial -from typing import ( - Annotated, - Callable, - Dict, - List, - Literal, - Mapping, - Optional, - Union, -) - -from pydantic import Field -from soundevent import data, terms - -from batdetect2.configs import BaseConfig, load_config -from batdetect2.targets.terms import TagInfo, get_tag_from_info - -__all__ = [ - "DerivationRegistry", - "DeriveTagRule", - "MapValueRule", - "ReplaceRule", - "SoundEventTransformation", - "TransformConfig", - "build_transform_from_rule", - "build_transformation_from_config", - "default_derivation_registry", - "get_derivation", - "load_transformation_config", - "load_transformation_from_config", - "register_derivation", -] - -SoundEventTransformation = Callable[ - [data.SoundEventAnnotation], data.SoundEventAnnotation -] -"""Type alias for a sound event transformation function. - -A function that accepts a sound event annotation object and returns a -(potentially) modified sound event annotation object. Transformations -should generally return a copy of the annotation rather than modifying -it in place. -""" - - -Derivation = Callable[[str], str] -"""Type alias for a derivation function. - -A function that accepts a single string (typically a tag value) and returns -a new string (the derived value). -""" - - -class MapValueRule(BaseConfig): - """Configuration for mapping specific values of a source term. - - This rule replaces tags matching a specific term and one of the - original values with a new tag (potentially having a different term) - containing the corresponding replacement value. Useful for standardizing - or grouping tag values. - - Attributes - ---------- - rule_type : Literal["map_value"] - Discriminator field identifying this rule type. - source_term_key : str - The key (registered in `TermRegistry`) of the term whose tags' values - should be checked against the `value_mapping`. - value_mapping : Dict[str, str] - A dictionary mapping original string values to replacement string - values. Only tags whose value is a key in this dictionary will be - affected. - target_term_key : str, optional - The key (registered in `TermRegistry`) for the term of the *output* - tag. If None (default), the output tag uses the same term as the - source (`source_term_key`). If provided, the term of the affected - tag is changed to this target term upon replacement. - """ - - rule_type: Literal["map_value"] = "map_value" - source_term_key: str - value_mapping: Dict[str, str] - target_term_key: Optional[str] = None - - -def map_value_transform( - sound_event_annotation: data.SoundEventAnnotation, - source_term: data.Term, - target_term: data.Term, - mapping: Dict[str, str], -) -> data.SoundEventAnnotation: - """Apply a value mapping transformation to an annotation's tags. - - Iterates through the annotation's tags. If a tag matches the `source_term` - and its value is found in the `mapping`, it is replaced by a new tag with - the `target_term` and the mapped value. Other tags are kept unchanged. - - Parameters - ---------- - sound_event_annotation : data.SoundEventAnnotation - The annotation to transform. - source_term : data.Term - The term of tags whose values should be mapped. - target_term : data.Term - The term to use for the newly created tags after mapping. - mapping : Dict[str, str] - The dictionary mapping original values to new values. - - Returns - ------- - data.SoundEventAnnotation - A new annotation object with the transformed tags. - """ - tags = [] - - for tag in sound_event_annotation.tags: - if tag.term != source_term or tag.value not in mapping: - tags.append(tag) - continue - - new_value = mapping[tag.value] - tags.append(data.Tag(term=target_term, value=new_value)) - - return sound_event_annotation.model_copy(update=dict(tags=tags)) - - -class DeriveTagRule(BaseConfig): - """Configuration for deriving a new tag from an existing tag's value. - - This rule applies a specified function (`derivation_function`) to the - value of tags matching the `source_term_key`. It then adds a new tag - with the `target_term_key` and the derived value. - - Attributes - ---------- - rule_type : Literal["derive_tag"] - Discriminator field identifying this rule type. - source_term_key : str - The key (registered in `TermRegistry`) of the term whose tag values - will be used as input to the derivation function. - derivation_function : str - The name/key identifying the derivation function to use. This can be - a key registered in the `DerivationRegistry` or, if - `import_derivation` is True, a full Python path like - `'my_module.my_submodule.my_function'`. - target_term_key : str, optional - The key (registered in `TermRegistry`) for the term of the new tag - that will be created with the derived value. If None (default), the - derived tag uses the same term as the source (`source_term_key`), - effectively performing an in-place value transformation. - import_derivation : bool, default=False - If True, treat `derivation_function` as a Python import path and - attempt to dynamically import it if not found in the registry. - Requires the function to be accessible in the Python environment. - keep_source : bool, default=True - If True, the original source tag (whose value was used for derivation) - is kept in the annotation's tag list alongside the newly derived tag. - If False, the original source tag is removed. - """ - - rule_type: Literal["derive_tag"] = "derive_tag" - source_term_key: str - derivation_function: str - target_term_key: Optional[str] = None - import_derivation: bool = False - keep_source: bool = True - - -def derivation_tag_transform( - sound_event_annotation: data.SoundEventAnnotation, - source_term: data.Term, - target_term: data.Term, - derivation: Derivation, - keep_source: bool = True, -) -> data.SoundEventAnnotation: - """Apply a derivation transformation to an annotation's tags. - - Iterates through the annotation's tags. For each tag matching the - `source_term`, its value is passed to the `derivation` function. - A new tag is created with the `target_term` and the derived value, - and added to the output tag list. The original source tag is kept - or discarded based on `keep_source`. Other tags are kept unchanged. - - Parameters - ---------- - sound_event_annotation : data.SoundEventAnnotation - The annotation to transform. - source_term : data.Term - The term of tags whose values serve as input for the derivation. - target_term : data.Term - The term to use for the newly created derived tags. - derivation : Derivation - The function to apply to the source tag's value. - keep_source : bool, default=True - Whether to keep the original source tag in the output. - - Returns - ------- - data.SoundEventAnnotation - A new annotation object with the transformed tags (including derived - ones). - """ - tags = [] - - for tag in sound_event_annotation.tags: - if tag.term != source_term: - tags.append(tag) - continue - - if keep_source: - tags.append(tag) - - new_value = derivation(tag.value) - tags.append(data.Tag(term=target_term, value=new_value)) - - return sound_event_annotation.model_copy(update=dict(tags=tags)) - - -class ReplaceRule(BaseConfig): - """Configuration for exactly replacing one specific tag with another. - - This rule looks for an exact match of the `original` tag (both term and - value) and replaces it with the specified `replacement` tag. - - Attributes - ---------- - rule_type : Literal["replace"] - Discriminator field identifying this rule type. - original : TagInfo - The exact tag to search for, defined using its value and term key. - replacement : TagInfo - The tag to substitute in place of the original tag, defined using - its value and term key. - """ - - rule_type: Literal["replace"] = "replace" - original: TagInfo - replacement: TagInfo - - -def replace_tag_transform( - sound_event_annotation: data.SoundEventAnnotation, - source: data.Tag, - target: data.Tag, -) -> data.SoundEventAnnotation: - """Apply an exact tag replacement transformation. - - Iterates through the annotation's tags. If a tag exactly matches the - `source` tag, it is replaced by the `target` tag. Other tags are kept - unchanged. - - Parameters - ---------- - sound_event_annotation : data.SoundEventAnnotation - The annotation to transform. - source : data.Tag - The exact tag to find and replace. - target : data.Tag - The tag to replace the source tag with. - - Returns - ------- - data.SoundEventAnnotation - A new annotation object with the replaced tag (if found). - """ - tags = [] - - for tag in sound_event_annotation.tags: - if tag == source: - tags.append(target) - else: - tags.append(tag) - - return sound_event_annotation.model_copy(update=dict(tags=tags)) - - -class TransformConfig(BaseConfig): - """Configuration model for defining a sequence of transformation rules. - - Attributes - ---------- - rules : List[Union[ReplaceRule, MapValueRule, DeriveTagRule]] - A list of transformation rules to apply. The rules are applied - sequentially in the order they appear in the list. The output of - one rule becomes the input for the next. The `rule_type` field - discriminates between the different rule models. - """ - - rules: List[ - Annotated[ - Union[ReplaceRule, MapValueRule, DeriveTagRule], - Field(discriminator="rule_type"), - ] - ] = Field( - default_factory=list, - ) - - -class DerivationRegistry(Mapping[str, Derivation]): - """A registry for managing named derivation functions. - - Derivation functions are callables that take a string value and return - a transformed string value, used by `DeriveTagRule`. This registry - allows functions to be registered with a key and retrieved later. - """ - - def __init__(self): - """Initialize an empty DerivationRegistry.""" - self._derivations: Dict[str, Derivation] = {} - - def __getitem__(self, key: str) -> Derivation: - """Retrieve a derivation function by key.""" - return self._derivations[key] - - def __len__(self) -> int: - """Return the number of registered derivation functions.""" - return len(self._derivations) - - def __iter__(self): - """Return an iterator over the keys of registered functions.""" - return iter(self._derivations) - - def register(self, key: str, derivation: Derivation) -> None: - """Register a derivation function with a unique key. - - Parameters - ---------- - key : str - The unique key to associate with the derivation function. - derivation : Derivation - The callable derivation function (takes str, returns str). - - Raises - ------ - KeyError - If a derivation function with the same key is already registered. - """ - if key in self._derivations: - raise KeyError( - f"A derivation with the provided key {key} already exists" - ) - - self._derivations[key] = derivation - - def get_derivation(self, key: str) -> Derivation: - """Retrieve a derivation function by its registered key. - - Parameters - ---------- - key : str - The key of the derivation function to retrieve. - - Returns - ------- - Derivation - The requested derivation function. - - Raises - ------ - KeyError - If no derivation function with the specified key is registered. - """ - try: - return self._derivations[key] - except KeyError as err: - raise KeyError( - f"No derivation with key {key} is registered." - ) from err - - def get_keys(self) -> List[str]: - """Get a list of all registered derivation function keys. - - Returns - ------- - List[str] - The keys of all registered functions. - """ - return list(self._derivations.keys()) - - def get_derivations(self) -> List[Derivation]: - """Get a list of all registered derivation functions. - - Returns - ------- - List[Derivation] - The registered derivation function objects. - """ - return list(self._derivations.values()) - - -default_derivation_registry = DerivationRegistry() -"""Global instance of the DerivationRegistry. - -Register custom derivation functions here to make them available by key -in `DeriveTagRule` configuration. -""" - - -def get_derivation( - key: str, - import_derivation: bool = False, - registry: Optional[DerivationRegistry] = None, -): - """Retrieve a derivation function by key, optionally importing it. - - First attempts to find the function in the provided `registry`. - If not found and `import_derivation` is True, attempts to dynamically - import the function using the `key` as a full Python path - (e.g., 'my_module.submodule.my_func'). - - Parameters - ---------- - key : str - The key or Python path of the derivation function. - import_derivation : bool, default=False - If True, attempt dynamic import if key is not in the registry. - registry : DerivationRegistry, optional - The registry instance to check first. Defaults to the global - `derivation_registry`. - - Returns - ------- - Derivation - The requested derivation function. - - Raises - ------ - KeyError - If the key is not found in the registry and either - `import_derivation` is False or the dynamic import fails. - ImportError - If dynamic import fails specifically due to module not found. - AttributeError - If dynamic import fails because the function name isn't in the module. - """ - registry = registry or default_derivation_registry - - if not import_derivation or key in registry: - return registry.get_derivation(key) - - try: - module_path, func_name = key.rsplit(".", 1) - module = importlib.import_module(module_path) - func = getattr(module, func_name) - return func - except ImportError as err: - raise KeyError( - f"Unable to load derivation '{key}'. Check the path and ensure " - "it points to a valid callable function in an importable module." - ) from err - - -TranformationRule = Annotated[ - Union[ReplaceRule, MapValueRule, DeriveTagRule], - Field(discriminator="rule_type"), -] - - -def build_transform_from_rule( - rule: TranformationRule, - derivation_registry: Optional[DerivationRegistry] = None, -) -> SoundEventTransformation: - """Build a specific SoundEventTransformation function from a rule config. - - Selects the appropriate transformation logic based on the rule's - `rule_type`, fetches necessary terms and derivation functions, and - returns a partially applied function ready to transform an annotation. - - Parameters - ---------- - rule : Union[ReplaceRule, MapValueRule, DeriveTagRule] - The configuration object for a single transformation rule. - registry : DerivationRegistry, optional - The derivation registry to use for `DeriveTagRule`. Defaults to the - global `derivation_registry`. - - Returns - ------- - SoundEventTransformation - A callable that applies the specified rule to a SoundEventAnnotation. - - Raises - ------ - KeyError - If required term keys or derivation keys are not found. - ValueError - If the rule has an unknown `rule_type`. - ImportError, AttributeError, TypeError - If dynamic import of a derivation function fails. - """ - if rule.rule_type == "replace": - source = get_tag_from_info(rule.original) - target = get_tag_from_info(rule.replacement) - return partial(replace_tag_transform, source=source, target=target) - - if rule.rule_type == "derive_tag": - source_term = terms.get_term(rule.source_term_key) - target_term = ( - terms.get_term(rule.target_term_key) - if rule.target_term_key - else source_term - ) - - if source_term is None or target_term is None: - raise KeyError("Terms not found") - - derivation = get_derivation( - key=rule.derivation_function, - import_derivation=rule.import_derivation, - registry=derivation_registry, - ) - return partial( - derivation_tag_transform, - source_term=source_term, - target_term=target_term, - derivation=derivation, - keep_source=rule.keep_source, - ) - - if rule.rule_type == "map_value": - source_term = terms.get_term(rule.source_term_key) - target_term = ( - terms.get_term(rule.target_term_key) - if rule.target_term_key - else source_term - ) - - if source_term is None or target_term is None: - raise KeyError("Terms not found") - - return partial( - map_value_transform, - source_term=source_term, - target_term=target_term, - mapping=rule.value_mapping, - ) - - # Handle unknown rule type - valid_options = ["replace", "derive_tag", "map_value"] - - # Should be caught by Pydantic validation, but good practice - raise ValueError( - f"Invalid transform rule type '{getattr(rule, 'rule_type', 'N/A')}'. " - f"Valid options are: {valid_options}" - ) - - -def build_transformation_from_config( - config: TransformConfig, - derivation_registry: Optional[DerivationRegistry] = None, -) -> SoundEventTransformation: - """Build a composite transformation function from a TransformConfig. - - Creates a sequence of individual transformation functions based on the - rules defined in the configuration. Returns a single function that - applies these transformations sequentially to an annotation. - - Parameters - ---------- - config : TransformConfig - The configuration object containing the list of transformation rules. - derivation_reg : DerivationRegistry, optional - The derivation registry to use when building `DeriveTagRule` - transformations. Defaults to the global `derivation_registry`. - - Returns - ------- - SoundEventTransformation - A single function that applies all configured transformations in order. - """ - - transforms = [ - build_transform_from_rule( - rule, - derivation_registry=derivation_registry, - ) - for rule in config.rules - ] - - return partial(apply_sequence_of_transforms, transforms=transforms) - - -def apply_sequence_of_transforms( - sound_event_annotation: data.SoundEventAnnotation, - transforms: list[SoundEventTransformation], -) -> data.SoundEventAnnotation: - for transform in transforms: - sound_event_annotation = transform(sound_event_annotation) - return sound_event_annotation - - -def load_transformation_config( - path: data.PathLike, field: Optional[str] = None -) -> TransformConfig: - """Load the transformation configuration from a file. - - Parameters - ---------- - path : data.PathLike - Path to the configuration file (YAML). - field : str, optional - If the transformation configuration is nested under a specific key - in the file, specify the key here. Defaults to None. - - Returns - ------- - TransformConfig - The loaded and validated transformation configuration object. - - Raises - ------ - FileNotFoundError - If the config file path does not exist. - pydantic.ValidationError - If the config file structure does not match the TransformConfig schema. - """ - return load_config(path=path, schema=TransformConfig, field=field) - - -def load_transformation_from_config( - path: data.PathLike, - field: Optional[str] = None, - derivation_registry: Optional[DerivationRegistry] = None, -) -> SoundEventTransformation: - """Load transformation config from a file and build the final function. - - This is a convenience function that combines loading the configuration - and building the final callable transformation function that applies - all rules sequentially. - - Parameters - ---------- - path : data.PathLike - Path to the configuration file (YAML). - field : str, optional - If the transformation configuration is nested under a specific key - in the file, specify the key here. Defaults to None. - - Returns - ------- - SoundEventTransformation - The final composite transformation function ready to be used. - - Raises - ------ - FileNotFoundError - If the config file path does not exist. - pydantic.ValidationError - If the config file structure does not match the TransformConfig schema. - KeyError - If required term keys or derivation keys specified in the config - are not found during the build process. - ImportError, AttributeError, TypeError - If dynamic import of a derivation function specified in the config - fails. - """ - config = load_transformation_config(path=path, field=field) - return build_transformation_from_config( - config, - derivation_registry=derivation_registry, - ) - - -def register_derivation( - key: str, - derivation: Derivation, - derivation_registry: Optional[DerivationRegistry] = None, -) -> None: - """Register a new derivation function in the global registry. - - Parameters - ---------- - key : str - The unique key to associate with the derivation function. - derivation : Derivation - The callable derivation function (takes str, returns str). - derivation_registry : DerivationRegistry, optional - The registry instance to register the derivation function with. - Defaults to the global `derivation_registry`. - - Raises - ------ - KeyError - If a derivation function with the same key is already registered. - """ - derivation_registry = derivation_registry or default_derivation_registry - derivation_registry.register(key, derivation) diff --git a/src/batdetect2/train/augmentations.py b/src/batdetect2/train/augmentations.py index 3bcdf93..5ec9437 100644 --- a/src/batdetect2/train/augmentations.py +++ b/src/batdetect2/train/augmentations.py @@ -44,7 +44,7 @@ AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]] class MixAugmentationConfig(BaseConfig): """Configuration for MixUp augmentation (mixing two examples).""" - augmentation_type: Literal["mix_audio"] = "mix_audio" + name: Literal["mix_audio"] = "mix_audio" probability: float = 0.2 """Probability of applying this augmentation to an example.""" @@ -140,7 +140,7 @@ def combine_clip_annotations( class EchoAugmentationConfig(BaseConfig): """Configuration for adding synthetic echo/reverb.""" - augmentation_type: Literal["add_echo"] = "add_echo" + name: Literal["add_echo"] = "add_echo" probability: float = 0.2 max_delay: float = 0.005 min_weight: float = 0.0 @@ -187,7 +187,7 @@ def add_echo( class VolumeAugmentationConfig(BaseConfig): """Configuration for random volume scaling of the spectrogram.""" - augmentation_type: Literal["scale_volume"] = "scale_volume" + name: Literal["scale_volume"] = "scale_volume" probability: float = 0.2 min_scaling: float = 0.0 max_scaling: float = 2.0 @@ -214,7 +214,7 @@ def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor: class WarpAugmentationConfig(BaseConfig): - augmentation_type: Literal["warp"] = "warp" + name: Literal["warp"] = "warp" probability: float = 0.2 delta: float = 0.04 @@ -296,7 +296,7 @@ def warp_spectrogram( class TimeMaskAugmentationConfig(BaseConfig): - augmentation_type: Literal["mask_time"] = "mask_time" + name: Literal["mask_time"] = "mask_time" probability: float = 0.2 max_perc: float = 0.05 max_masks: int = 3 @@ -353,7 +353,7 @@ def mask_time( class FrequencyMaskAugmentationConfig(BaseConfig): - augmentation_type: Literal["mask_freq"] = "mask_freq" + name: Literal["mask_freq"] = "mask_freq" probability: float = 0.2 max_perc: float = 0.10 max_masks: int = 3 @@ -414,7 +414,7 @@ AudioAugmentationConfig = Annotated[ MixAugmentationConfig, EchoAugmentationConfig, ], - Field(discriminator="augmentation_type"), + Field(discriminator="name"), ] @@ -425,7 +425,7 @@ SpectrogramAugmentationConfig = Annotated[ FrequencyMaskAugmentationConfig, TimeMaskAugmentationConfig, ], - Field(discriminator="augmentation_type"), + Field(discriminator="name"), ] AugmentationConfig = Annotated[ @@ -437,7 +437,7 @@ AugmentationConfig = Annotated[ FrequencyMaskAugmentationConfig, TimeMaskAugmentationConfig, ], - Field(discriminator="augmentation_type"), + Field(discriminator="name"), ] """Type alias for the discriminated union of individual augmentation config.""" @@ -485,7 +485,7 @@ def build_augmentation_from_config( audio_source: Optional[AudioSource] = None, ) -> Optional[Augmentation]: """Factory function to build a single augmentation from its config.""" - if config.augmentation_type == "mix_audio": + if config.name == "mix_audio": if audio_source is None: warnings.warn( "Mix audio augmentation ('mix_audio') requires an " @@ -500,31 +500,31 @@ def build_augmentation_from_config( max_weight=config.max_weight, ) - if config.augmentation_type == "add_echo": + if config.name == "add_echo": return AddEcho( max_delay=int(config.max_delay * samplerate), min_weight=config.min_weight, max_weight=config.max_weight, ) - if config.augmentation_type == "scale_volume": + if config.name == "scale_volume": return ScaleVolume( max_scaling=config.max_scaling, min_scaling=config.min_scaling, ) - if config.augmentation_type == "warp": + if config.name == "warp": return WarpSpectrogram( delta=config.delta, ) - if config.augmentation_type == "mask_time": + if config.name == "mask_time": return MaskTime( max_perc=config.max_perc, max_masks=config.max_masks, ) - if config.augmentation_type == "mask_freq": + if config.name == "mask_freq": return MaskFrequency( max_perc=config.max_perc, max_masks=config.max_masks, diff --git a/src/batdetect2/typing/targets.py b/src/batdetect2/typing/targets.py index 2846a0e..22b4780 100644 --- a/src/batdetect2/typing/targets.py +++ b/src/batdetect2/typing/targets.py @@ -67,8 +67,7 @@ class TargetProtocol(Protocol): This protocol outlines the standard attributes and methods for an object that encapsulates the complete, configured process for handling sound event annotations (both tags and geometry). It defines how to: - - Filter relevant annotations. - - Transform annotation tags. + - Select relevant annotations. - Encode an annotation into a specific target class name. - Decode a class name back into representative tags. - Extract a target reference position from an annotation's geometry (ROI). @@ -121,26 +120,6 @@ class TargetProtocol(Protocol): """ ... - def transform( - self, - sound_event: data.SoundEventAnnotation, - ) -> data.SoundEventAnnotation: - """Apply tag transformations to an annotation. - - Parameters - ---------- - sound_event : data.SoundEventAnnotation - The annotation whose tags should be transformed. - - Returns - ------- - data.SoundEventAnnotation - A new annotation object with the transformed tags. Implementations - should return the original annotation object if no transformations - were configured. - """ - ... - def encode_class( self, sound_event: data.SoundEventAnnotation, diff --git a/tests/conftest.py b/tests/conftest.py index 36c5c9a..5f5e32b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,9 +18,7 @@ from batdetect2.targets import ( build_targets, call_type, ) -from batdetect2.targets.classes import ClassesConfig, TargetClass -from batdetect2.targets.filtering import FilterConfig, FilterRule -from batdetect2.targets.terms import TagInfo +from batdetect2.targets.classes import TargetClassConfig from batdetect2.train.clips import build_clipper from batdetect2.train.labels import build_clip_labeler from batdetect2.typing import ( @@ -365,43 +363,37 @@ def sample_audio_loader() -> AudioLoader: @pytest.fixture -def bat_tag() -> TagInfo: - return TagInfo(key="class", value="bat") +def bat_tag() -> data.Tag: + return data.Tag(key="class", value="bat") @pytest.fixture -def noise_tag() -> TagInfo: - return TagInfo(key="class", value="noise") +def noise_tag() -> data.Tag: + return data.Tag(key="class", value="noise") @pytest.fixture -def myomyo_tag() -> TagInfo: - return TagInfo(key="species", value="Myotis myotis") +def myomyo_tag() -> data.Tag: + return data.Tag(key="species", value="Myotis myotis") @pytest.fixture -def pippip_tag() -> TagInfo: - return TagInfo(key="species", value="Pipistrellus pipistrellus") +def pippip_tag() -> data.Tag: + return data.Tag(key="species", value="Pipistrellus pipistrellus") @pytest.fixture def sample_target_config( - bat_tag: TagInfo, - noise_tag: TagInfo, - myomyo_tag: TagInfo, - pippip_tag: TagInfo, + bat_tag: data.Tag, + myomyo_tag: data.Tag, + pippip_tag: data.Tag, ) -> TargetConfig: return TargetConfig( - filtering=FilterConfig( - rules=[FilterRule(match_type="exclude", tags=[noise_tag])] - ), - classes=ClassesConfig( - classes=[ - TargetClass(name="pippip", tags=[pippip_tag]), - TargetClass(name="myomyo", tags=[myomyo_tag]), - ], - generic_class=[bat_tag], - ), + detection_target=TargetClassConfig(name="bat", tags=[bat_tag]), + classification_targets=[ + TargetClassConfig(name="pippip", tags=[pippip_tag]), + TargetClassConfig(name="myomyo", tags=[myomyo_tag]), + ], ) diff --git a/tests/test_data/test_transforms/__init__.py b/tests/test_data/test_transforms/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_data/test_transforms/test_conditions.py b/tests/test_data/test_transforms/test_conditions.py new file mode 100644 index 0000000..6d5068e --- /dev/null +++ b/tests/test_data/test_transforms/test_conditions.py @@ -0,0 +1,516 @@ +import textwrap + +import pytest +import yaml +from pydantic import TypeAdapter +from soundevent import data + +from batdetect2.data.conditions import ( + SoundEventConditionConfig, + build_sound_event_condition, +) + + +def build_condition_from_str(content): + content = textwrap.dedent(content) + content = yaml.safe_load(content) + config = TypeAdapter(SoundEventConditionConfig).validate_python(content) + return build_sound_event_condition(config) + + +def test_has_tag(sound_event: data.SoundEvent): + condition = build_condition_from_str(""" + name: has_tag + tag: + key: species + value: Myotis myotis + """) + + sound_event_annotation = data.SoundEventAnnotation( + sound_event=sound_event, + tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore + ) + assert condition(sound_event_annotation) + + sound_event_annotation = data.SoundEventAnnotation( + sound_event=sound_event, + tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore + ) + assert not condition(sound_event_annotation) + + +def test_has_all_tags(sound_event: data.SoundEvent): + condition = build_condition_from_str(""" + name: has_all_tags + tags: + - key: species + value: Myotis myotis + - key: event + value: Echolocation + """) + + sound_event_annotation = data.SoundEventAnnotation( + sound_event=sound_event, + tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore + ) + assert not condition(sound_event_annotation) + + sound_event_annotation = data.SoundEventAnnotation( + sound_event=sound_event, + tags=[ + data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore + data.Tag(key="event", value="Echolocation"), # type: ignore + ], + ) + assert not condition(sound_event_annotation) + + sound_event_annotation = data.SoundEventAnnotation( + sound_event=sound_event, + tags=[ + data.Tag(key="species", value="Myotis myotis"), # type: ignore + data.Tag(key="event", value="Echolocation"), # type: ignore + ], + ) + assert condition(sound_event_annotation) + + sound_event_annotation = data.SoundEventAnnotation( + sound_event=sound_event, + tags=[ + data.Tag(key="species", value="Myotis myotis"), # type: ignore + data.Tag(key="event", value="Echolocation"), # type: ignore + data.Tag(key="sex", value="Female"), # type: ignore + ], + ) + assert condition(sound_event_annotation) + + +def test_has_any_tags(sound_event: data.SoundEvent): + condition = build_condition_from_str(""" + name: has_any_tag + tags: + - key: species + value: Myotis myotis + - key: event + value: Echolocation + """) + + sound_event_annotation = data.SoundEventAnnotation( + sound_event=sound_event, + tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore + ) + assert condition(sound_event_annotation) + + sound_event_annotation = data.SoundEventAnnotation( + sound_event=sound_event, + tags=[ + data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore + data.Tag(key="event", value="Echolocation"), # type: ignore + ], + ) + assert condition(sound_event_annotation) + + sound_event_annotation = data.SoundEventAnnotation( + sound_event=sound_event, + tags=[ + data.Tag(key="species", value="Myotis myotis"), # type: ignore + data.Tag(key="event", value="Echolocation"), # type: ignore + ], + ) + assert condition(sound_event_annotation) + + sound_event_annotation = data.SoundEventAnnotation( + sound_event=sound_event, + tags=[ + data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore + data.Tag(key="event", value="Social"), # type: ignore + ], + ) + assert not condition(sound_event_annotation) + + +def test_not(sound_event: data.SoundEvent): + condition = build_condition_from_str(""" + name: not + condition: + name: has_tag + tag: + key: species + value: Myotis myotis + """) + + sound_event_annotation = data.SoundEventAnnotation( + sound_event=sound_event, + tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore + ) + assert not condition(sound_event_annotation) + + sound_event_annotation = data.SoundEventAnnotation( + sound_event=sound_event, + tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore + ) + assert condition(sound_event_annotation) + + sound_event_annotation = data.SoundEventAnnotation( + sound_event=sound_event, + tags=[ + data.Tag(key="species", value="Myotis myotis"), # type: ignore + data.Tag(key="event", value="Echolocation"), # type: ignore + ], + ) + assert not condition(sound_event_annotation) + + +def test_duration(recording: data.Recording): + se1 = data.SoundEventAnnotation( + sound_event=data.SoundEvent( + recording=recording, geometry=data.TimeInterval(coordinates=[0, 1]) + ), + ) + se2 = data.SoundEventAnnotation( + sound_event=data.SoundEvent( + recording=recording, geometry=data.TimeInterval(coordinates=[0, 2]) + ), + ) + se3 = data.SoundEventAnnotation( + sound_event=data.SoundEvent( + recording=recording, geometry=data.TimeInterval(coordinates=[0, 3]) + ), + ) + + condition = build_condition_from_str(""" + name: duration + operator: lt + seconds: 2 + """) + assert condition(se1) + assert not condition(se2) + assert not condition(se3) + + condition = build_condition_from_str(""" + name: duration + operator: lte + seconds: 2 + """) + + assert condition(se1) + assert condition(se2) + assert not condition(se3) + + condition = build_condition_from_str(""" + name: duration + operator: gt + seconds: 2 + """) + + assert not condition(se1) + assert not condition(se2) + assert condition(se3) + + condition = build_condition_from_str(""" + name: duration + operator: gte + seconds: 2 + """) + + assert not condition(se1) + assert condition(se2) + assert condition(se3) + + condition = build_condition_from_str(""" + name: duration + operator: eq + seconds: 2 + """) + + assert not condition(se1) + assert condition(se2) + assert not condition(se3) + + +def test_frequency(recording: data.Recording): + se12 = data.SoundEventAnnotation( + sound_event=data.SoundEvent( + recording=recording, + geometry=data.BoundingBox(coordinates=[0, 100, 1, 200]), + ), + ) + se13 = data.SoundEventAnnotation( + sound_event=data.SoundEvent( + recording=recording, + geometry=data.BoundingBox(coordinates=[0, 100, 2, 300]), + ), + ) + se14 = data.SoundEventAnnotation( + sound_event=data.SoundEvent( + recording=recording, + geometry=data.BoundingBox(coordinates=[0, 100, 3, 400]), + ), + ) + se24 = data.SoundEventAnnotation( + sound_event=data.SoundEvent( + recording=recording, + geometry=data.BoundingBox(coordinates=[0, 200, 3, 400]), + ), + ) + se34 = data.SoundEventAnnotation( + sound_event=data.SoundEvent( + recording=recording, + geometry=data.BoundingBox(coordinates=[0, 300, 3, 400]), + ), + ) + + condition = build_condition_from_str(""" + name: frequency + boundary: high + operator: lt + hertz: 300 + """) + assert condition(se12) + assert not condition(se13) + assert not condition(se14) + + condition = build_condition_from_str(""" + name: frequency + boundary: high + operator: lte + hertz: 300 + """) + + assert condition(se12) + assert condition(se13) + assert not condition(se14) + + condition = build_condition_from_str(""" + name: frequency + boundary: high + operator: gt + hertz: 300 + """) + + assert not condition(se12) + assert not condition(se13) + assert condition(se14) + + condition = build_condition_from_str(""" + name: frequency + boundary: high + operator: gte + hertz: 300 + """) + + assert not condition(se12) + assert condition(se13) + assert condition(se14) + + condition = build_condition_from_str(""" + name: frequency + boundary: high + operator: eq + hertz: 300 + """) + + assert not condition(se12) + assert condition(se13) + assert not condition(se14) + + # LOW + + condition = build_condition_from_str(""" + name: frequency + boundary: low + operator: lt + hertz: 200 + """) + assert condition(se14) + assert not condition(se24) + assert not condition(se34) + + condition = build_condition_from_str(""" + name: frequency + boundary: low + operator: lte + hertz: 200 + """) + + assert condition(se14) + assert condition(se24) + assert not condition(se34) + + condition = build_condition_from_str(""" + name: frequency + boundary: low + operator: gt + hertz: 200 + """) + + assert not condition(se14) + assert not condition(se24) + assert condition(se34) + + condition = build_condition_from_str(""" + name: frequency + boundary: low + operator: gte + hertz: 200 + """) + + assert not condition(se14) + assert condition(se24) + assert condition(se34) + + condition = build_condition_from_str(""" + name: frequency + boundary: low + operator: eq + hertz: 200 + """) + + assert not condition(se14) + assert condition(se24) + assert not condition(se34) + + +def test_frequency_is_false_for_temporal_geometries(recording: data.Recording): + condition = build_condition_from_str(""" + name: frequency + boundary: low + operator: eq + hertz: 200 + """) + se = data.SoundEventAnnotation( + sound_event=data.SoundEvent( + geometry=data.TimeInterval(coordinates=[0, 3]), + recording=recording, + ) + ) + assert not condition(se) + + se = data.SoundEventAnnotation( + sound_event=data.SoundEvent( + geometry=data.TimeStamp(coordinates=3), + recording=recording, + ) + ) + assert not condition(se) + + +def test_has_tags_fails_if_empty(): + with pytest.raises(ValueError): + build_condition_from_str(""" + name: has_tags + tags: [] + """) + + +def test_frequency_is_false_if_no_geometry(recording: data.Recording): + condition = build_condition_from_str(""" + name: frequency + boundary: low + operator: eq + hertz: 200 + """) + se = data.SoundEventAnnotation( + sound_event=data.SoundEvent(geometry=None, recording=recording) + ) + assert not condition(se) + + +def test_duration_is_false_if_no_geometry(recording: data.Recording): + condition = build_condition_from_str(""" + name: duration + operator: eq + seconds: 1 + """) + se = data.SoundEventAnnotation( + sound_event=data.SoundEvent(geometry=None, recording=recording) + ) + assert not condition(se) + + +def test_all_of(recording: data.Recording): + condition = build_condition_from_str(""" + name: all_of + conditions: + - name: has_tag + tag: + key: species + value: Myotis myotis + - name: duration + operator: lt + seconds: 1 + """) + se = data.SoundEventAnnotation( + sound_event=data.SoundEvent( + geometry=data.TimeInterval(coordinates=[0, 0.5]), + recording=recording, + ), + tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore + ) + assert condition(se) + + se = data.SoundEventAnnotation( + sound_event=data.SoundEvent( + geometry=data.TimeInterval(coordinates=[0, 2]), + recording=recording, + ), + tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore + ) + assert not condition(se) + + se = data.SoundEventAnnotation( + sound_event=data.SoundEvent( + geometry=data.TimeInterval(coordinates=[0, 0.5]), + recording=recording, + ), + tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore + ) + assert not condition(se) + + +def test_any_of(recording: data.Recording): + condition = build_condition_from_str(""" + name: any_of + conditions: + - name: has_tag + tag: + key: species + value: Myotis myotis + - name: duration + operator: lt + seconds: 1 + """) + se = data.SoundEventAnnotation( + sound_event=data.SoundEvent( + geometry=data.TimeInterval(coordinates=[0, 2]), + recording=recording, + ), + tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore + ) + assert not condition(se) + + se = data.SoundEventAnnotation( + sound_event=data.SoundEvent( + geometry=data.TimeInterval(coordinates=[0, 0.5]), + recording=recording, + ), + tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore + ) + assert condition(se) + + se = data.SoundEventAnnotation( + sound_event=data.SoundEvent( + geometry=data.TimeInterval(coordinates=[0, 2]), + recording=recording, + ), + tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore + ) + assert condition(se) + + se = data.SoundEventAnnotation( + sound_event=data.SoundEvent( + geometry=data.TimeInterval(coordinates=[0, 0.5]), + recording=recording, + ), + tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore + ) + assert condition(se) diff --git a/tests/test_targets/test_classes.py b/tests/test_targets/test_classes.py index 587b502..c80c5ec 100644 --- a/tests/test_targets/test_classes.py +++ b/tests/test_targets/test_classes.py @@ -3,26 +3,15 @@ from typing import Callable from uuid import uuid4 import pytest -from pydantic import ValidationError from soundevent import data from soundevent.terms import get_term from batdetect2.targets.classes import ( - DEFAULT_SPECIES_LIST, - ClassesConfig, - TargetClass, - _get_default_class_name, - _get_default_classes, - build_generic_class_tags, + TargetClassConfig, build_sound_event_decoder, build_sound_event_encoder, get_class_names_from_config, - is_target_class, - load_classes_config, - load_decoder_from_config, - load_encoder_from_config, ) -from batdetect2.targets.terms import TagInfo @pytest.fixture @@ -33,8 +22,8 @@ def sample_annotation( return data.SoundEventAnnotation( sound_event=sound_event, tags=[ - data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore - data.Tag(key="quality", value="Good"), # type: ignore + data.Tag(key="species", value="Pipistrellus pipistrellus"), + data.Tag(key="quality", value="Good"), ], ) @@ -51,291 +40,71 @@ def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]: return factory -def test_target_class_creation(): - target_class = TargetClass( - name="pippip", - tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")], - ) - assert target_class.name == "pippip" - assert target_class.tags[0].key == "species" - assert target_class.tags[0].value == "Pipistrellus pipistrellus" - assert target_class.match_type == "all" - - -def test_classes_config_creation(): - target_class = TargetClass( - name="pippip", - tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")], - ) - config = ClassesConfig(classes=[target_class]) - assert len(config.classes) == 1 - assert config.classes[0].name == "pippip" - - -def test_classes_config_unique_names(): - target_class1 = TargetClass( - name="pippip", - tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")], - ) - target_class2 = TargetClass( - name="myodau", - tags=[TagInfo(key="species", value="Myotis daubentonii")], - ) - ClassesConfig(classes=[target_class1, target_class2]) # No error - - -def test_classes_config_non_unique_names(): - target_class1 = TargetClass( - name="pippip", - tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")], - ) - target_class2 = TargetClass( - name="pippip", - tags=[TagInfo(key="species", value="Myotis daubentonii")], - ) - with pytest.raises(ValidationError): - ClassesConfig(classes=[target_class1, target_class2]) - - -def test_load_classes_config_valid(create_temp_yaml: Callable[[str], Path]): - yaml_content = """ - classes: - - name: pippip - tags: - - key: species - value: Pipistrellus pipistrellus - """ - temp_yaml_path = create_temp_yaml(yaml_content) - config = load_classes_config(temp_yaml_path) - assert len(config.classes) == 1 - assert config.classes[0].name == "pippip" - - -def test_load_classes_config_invalid(create_temp_yaml: Callable[[str], Path]): - yaml_content = """ - classes: - - name: pippip - tags: - - key: species - value: Pipistrellus pipistrellus - - name: pippip - tags: - - key: species - value: Myotis daubentonii - """ - temp_yaml_path = create_temp_yaml(yaml_content) - with pytest.raises(ValidationError): - load_classes_config(temp_yaml_path) - - -def test_is_target_class_match_all( - sample_annotation: data.SoundEventAnnotation, -): - tags = { - data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore - data.Tag(key="quality", value="Good"), # type: ignore - } - assert is_target_class(sample_annotation, tags, match_all=True) is True - - tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore - assert is_target_class(sample_annotation, tags, match_all=True) is True - - tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore - assert is_target_class(sample_annotation, tags, match_all=True) is False - - -def test_is_target_class_match_any( - sample_annotation: data.SoundEventAnnotation, -): - tags = { - data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore - data.Tag(key="quality", value="Good"), # type: ignore - } - assert is_target_class(sample_annotation, tags, match_all=False) is True - - tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore - assert is_target_class(sample_annotation, tags, match_all=False) is True - - tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore - assert is_target_class(sample_annotation, tags, match_all=False) is False - - def test_get_class_names_from_config(): - target_class1 = TargetClass( + target_class1 = TargetClassConfig( name="pippip", - tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")], + tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")], ) - target_class2 = TargetClass( + target_class2 = TargetClassConfig( name="myodau", - tags=[TagInfo(key="species", value="Myotis daubentonii")], + tags=[data.Tag(key="species", value="Myotis daubentonii")], ) - config = ClassesConfig(classes=[target_class1, target_class2]) - names = get_class_names_from_config(config) + names = get_class_names_from_config([target_class1, target_class2]) assert names == ["pippip", "myodau"] def test_build_encoder_from_config( sample_annotation: data.SoundEventAnnotation, ): - config = ClassesConfig( - classes=[ - TargetClass( - name="pippip", - tags=[ - TagInfo(key="species", value="Pipistrellus pipistrellus") - ], - ) - ] - ) - encoder = build_sound_event_encoder(config) + classes = [ + TargetClassConfig( + name="pippip", + tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")], + ) + ] + encoder = build_sound_event_encoder(classes) result = encoder(sample_annotation) assert result == "pippip" - config = ClassesConfig(classes=[]) - encoder = build_sound_event_encoder(config) + classes = [] + encoder = build_sound_event_encoder(classes) result = encoder(sample_annotation) assert result is None -def test_load_encoder_from_config_valid( - sample_annotation: data.SoundEventAnnotation, - create_temp_yaml: Callable[[str], Path], -): - yaml_content = """ - classes: - - name: pippip - tags: - - key: species - value: Pipistrellus pipistrellus - """ - temp_yaml_path = create_temp_yaml(yaml_content) - encoder = load_encoder_from_config(temp_yaml_path) - # We cannot directly compare the function, so we test it. - result = encoder(sample_annotation) # type: ignore - assert result == "pippip" - - -def test_load_encoder_from_config_invalid( - create_temp_yaml: Callable[[str], Path], -): - yaml_content = """ - classes: - - name: pippip - tags: - - key: invalid_key - value: Pipistrellus pipistrellus - """ - temp_yaml_path = create_temp_yaml(yaml_content) - with pytest.raises(KeyError): - load_encoder_from_config(temp_yaml_path) - - -def test_get_default_class_name(): - assert _get_default_class_name("Myotis daubentonii") == "myodau" - - -def test_get_default_classes(): - default_classes = _get_default_classes() - assert len(default_classes) == len(DEFAULT_SPECIES_LIST) - first_class = default_classes[0] - assert isinstance(first_class, TargetClass) - assert first_class.name == _get_default_class_name(DEFAULT_SPECIES_LIST[0]) - assert first_class.tags[0].key == "class" - assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0] - - def test_build_decoder_from_config(): - config = ClassesConfig( - classes=[ - TargetClass( - name="pippip", - tags=[ - TagInfo(key="species", value="Pipistrellus pipistrellus") - ], - output_tags=[TagInfo(key="call_type", value="Echolocation")], - ) - ], - generic_class=[TagInfo(key="order", value="Chiroptera")], - ) - decoder = build_sound_event_decoder(config) + classes = [ + TargetClassConfig( + name="pippip", + tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")], + assign_tags=[data.Tag(key="call_type", value="Echolocation")], + ) + ] + decoder = build_sound_event_decoder(classes) tags = decoder("pippip") assert len(tags) == 1 assert tags[0].term == get_term("event") assert tags[0].value == "Echolocation" # Test when output_tags is None, should fall back to tags - config = ClassesConfig( - classes=[ - TargetClass( - name="pippip", - tags=[ - TagInfo(key="species", value="Pipistrellus pipistrellus") - ], - ) - ], - generic_class=[TagInfo(key="order", value="Chiroptera")], - ) - decoder = build_sound_event_decoder(config) + classes = [ + TargetClassConfig( + name="pippip", + tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")], + ) + ] + decoder = build_sound_event_decoder(classes) tags = decoder("pippip") assert len(tags) == 1 assert tags[0].term == get_term("species") assert tags[0].value == "Pipistrellus pipistrellus" # Test raise_on_unmapped=True - decoder = build_sound_event_decoder(config, raise_on_unmapped=True) + decoder = build_sound_event_decoder(classes, raise_on_unmapped=True) with pytest.raises(ValueError): decoder("unknown_class") # Test raise_on_unmapped=False - decoder = build_sound_event_decoder(config, raise_on_unmapped=False) + decoder = build_sound_event_decoder(classes, raise_on_unmapped=False) tags = decoder("unknown_class") assert len(tags) == 0 - - -def test_load_decoder_from_config_valid( - create_temp_yaml: Callable[[str], Path], -): - yaml_content = """ - classes: - - name: pippip - tags: - - key: species - value: Pipistrellus pipistrellus - output_tags: - - key: call_type - value: Echolocation - generic_class: - - key: order - value: Chiroptera - """ - temp_yaml_path = create_temp_yaml(yaml_content) - decoder = load_decoder_from_config( - temp_yaml_path, - ) - tags = decoder("pippip") - assert len(tags) == 1 - assert tags[0].term == get_term("call_type") - assert tags[0].value == "Echolocation" - - -def test_build_generic_class_tags_from_config(): - config = ClassesConfig( - classes=[ - TargetClass( - name="pippip", - tags=[ - TagInfo(key="species", value="Pipistrellus pipistrellus") - ], - ) - ], - generic_class=[ - TagInfo(key="order", value="Chiroptera"), - TagInfo(key="call_type", value="Echolocation"), - ], - ) - generic_tags = build_generic_class_tags(config) - assert len(generic_tags) == 2 - assert generic_tags[0].term == get_term("order") - assert generic_tags[0].value == "Chiroptera" - assert generic_tags[1].term == get_term("call_type") - assert generic_tags[1].value == "Echolocation" diff --git a/tests/test_targets/test_filtering.py b/tests/test_targets/test_filtering.py deleted file mode 100644 index 3036143..0000000 --- a/tests/test_targets/test_filtering.py +++ /dev/null @@ -1,210 +0,0 @@ -from pathlib import Path -from typing import Callable, List, Set - -import pytest -from soundevent import data - -from batdetect2.targets import build_targets -from batdetect2.targets.filtering import ( - FilterConfig, - FilterRule, - build_filter_from_rule, - build_sound_event_filter, - contains_tags, - does_not_have_tags, - equal_tags, - has_any_tag, - load_filter_config, - load_filter_from_config, -) -from batdetect2.targets.terms import TagInfo, generic_class - - -@pytest.fixture -def create_annotation( - sound_event: data.SoundEvent, -) -> Callable[[List[str]], data.SoundEventAnnotation]: - """Helper function to create a SoundEventAnnotation with given tags.""" - - def factory(tags: List[str]) -> data.SoundEventAnnotation: - return data.SoundEventAnnotation( - sound_event=sound_event, - tags=[ - data.Tag( - term=generic_class, - value=tag, - ) - for tag in tags - ], - ) - - return factory - - -def create_tag_set(tags: List[str]) -> Set[data.Tag]: - """Helper function to create a set of data.Tag objects from a list of strings.""" - return { - data.Tag( - term=generic_class, - value=tag, - ) - for tag in tags - } - - -def test_has_any_tag(create_annotation): - annotation = create_annotation(["tag1", "tag2"]) - tags = create_tag_set(["tag1", "tag3"]) - assert has_any_tag(annotation, tags) is True - - annotation = create_annotation(["tag2", "tag4"]) - tags = create_tag_set(["tag1", "tag3"]) - assert has_any_tag(annotation, tags) is False - - -def test_contains_tags(create_annotation): - annotation = create_annotation(["tag1", "tag2", "tag3"]) - tags = create_tag_set(["tag1", "tag2"]) - assert contains_tags(annotation, tags) is True - - annotation = create_annotation(["tag1", "tag2"]) - tags = create_tag_set(["tag1", "tag2", "tag3"]) - assert contains_tags(annotation, tags) is False - - -def test_does_not_have_tags(create_annotation): - annotation = create_annotation(["tag1", "tag2"]) - tags = create_tag_set(["tag3", "tag4"]) - assert does_not_have_tags(annotation, tags) is True - - annotation = create_annotation(["tag1", "tag2"]) - tags = create_tag_set(["tag1", "tag3"]) - assert does_not_have_tags(annotation, tags) is False - - -def test_equal_tags(create_annotation): - annotation = create_annotation(["tag1", "tag2"]) - tags = create_tag_set(["tag1", "tag2"]) - assert equal_tags(annotation, tags) is True - - annotation = create_annotation(["tag1", "tag2", "tag3"]) - tags = create_tag_set(["tag1", "tag2"]) - assert equal_tags(annotation, tags) is False - - -def test_build_filter_from_rule(): - rule_any = FilterRule(match_type="any", tags=[TagInfo(value="tag1")]) - build_filter_from_rule(rule_any) - - rule_all = FilterRule(match_type="all", tags=[TagInfo(value="tag1")]) - build_filter_from_rule(rule_all) - - rule_exclude = FilterRule( - match_type="exclude", tags=[TagInfo(value="tag1")] - ) - build_filter_from_rule(rule_exclude) - - rule_equal = FilterRule(match_type="equal", tags=[TagInfo(value="tag1")]) - build_filter_from_rule(rule_equal) - - with pytest.raises(ValueError): - FilterRule(match_type="invalid", tags=[TagInfo(value="tag1")]) # type: ignore - build_filter_from_rule( - FilterRule(match_type="invalid", tags=[TagInfo(value="tag1")]) # type: ignore - ) - - -def test_build_filter_from_config(create_annotation): - config = FilterConfig( - rules=[ - FilterRule(match_type="any", tags=[TagInfo(value="tag1")]), - FilterRule(match_type="any", tags=[TagInfo(value="tag2")]), - ] - ) - filter_from_config = build_sound_event_filter(config) - - annotation_pass = create_annotation(["tag1", "tag2"]) - assert filter_from_config(annotation_pass) - - annotation_fail = create_annotation(["tag1"]) - assert not filter_from_config(annotation_fail) - - -def test_load_filter_config(tmp_path: Path): - test_config_path = tmp_path / "filtering.yaml" - test_config_path.write_text( - """ -rules: - - match_type: any - tags: - - value: tag1 - """ - ) - config = load_filter_config(test_config_path) - assert isinstance(config, FilterConfig) - assert len(config.rules) == 1 - rule = config.rules[0] - assert rule.match_type == "any" - assert len(rule.tags) == 1 - assert rule.tags[0].value == "tag1" - - -def test_load_filter_from_config(tmp_path: Path, create_annotation): - test_config_path = tmp_path / "filtering.yaml" - test_config_path.write_text( - """ -rules: - - match_type: any - tags: - - value: tag1 - """ - ) - - filter_result = load_filter_from_config(test_config_path) - annotation = create_annotation(["tag1", "tag3"]) - assert filter_result(annotation) - - test_config_path = tmp_path / "filtering.yaml" - test_config_path.write_text( - """ -rules: - - match_type: any - tags: - - value: tag2 - """ - ) - - filter_result = load_filter_from_config(test_config_path) - annotation = create_annotation(["tag1", "tag3"]) - assert filter_result(annotation) is False - - -def test_default_filtering_over_example_dataset( - example_annotations: List[data.ClipAnnotation], -): - targets = build_targets() - - clip1 = example_annotations[0] - clip2 = example_annotations[1] - clip3 = example_annotations[2] - - assert ( - sum( - [targets.filter(sound_event) for sound_event in clip1.sound_events] - ) - == 9 - ) - - assert ( - sum( - [targets.filter(sound_event) for sound_event in clip2.sound_events] - ) - == 15 - ) - - assert ( - sum( - [targets.filter(sound_event) for sound_event in clip3.sound_events] - ) - == 20 - ) diff --git a/tests/test_targets/test_rois.py b/tests/test_targets/test_rois.py index 8abf18f..66e307a 100644 --- a/tests/test_targets/test_rois.py +++ b/tests/test_targets/test_rois.py @@ -8,6 +8,11 @@ from batdetect2.preprocess import ( build_preprocessor, ) from batdetect2.preprocess.audio import build_audio_loader +from batdetect2.preprocess.spectrogram import ( + ScaleAmplitudeConfig, + SpectralMeanSubstractionConfig, + SpectrogramConfig, +) from batdetect2.targets.rois import ( DEFAULT_ANCHOR, DEFAULT_FREQUENCY_SCALE, @@ -548,14 +553,7 @@ def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle): # Instantiate the mapper. preprocessor = build_preprocessor( - PreprocessingConfig.model_validate( - { - "spectrogram": { - "pcen": None, - "spectral_mean_substraction": False, - } - } - ) + PreprocessingConfig(spectrogram=SpectrogramConfig(transforms=[])) ) audio_loader = build_audio_loader() mapper = PeakEnergyBBoxMapper( @@ -597,14 +595,13 @@ def test_build_roi_mapper_for_anchor_bbox(): def test_build_roi_mapper_for_peak_energy_bbox(): # Given - preproc_config = PreprocessingConfig.model_validate( - { - "spectrogram": { - "pcen": None, - "spectral_mean_substraction": True, - "scale": "dB", - } - } + preproc_config = PreprocessingConfig( + spectrogram=SpectrogramConfig( + transforms=[ + ScaleAmplitudeConfig(scale="db"), + SpectralMeanSubstractionConfig(), + ] + ), ) config = PeakEnergyBBoxMapperConfig( loading_buffer=0.99, diff --git a/tests/test_targets/test_targets.py b/tests/test_targets/test_targets.py index 8324807..fad195b 100644 --- a/tests/test_targets/test_targets.py +++ b/tests/test_targets/test_targets.py @@ -11,25 +11,34 @@ def test_can_override_default_roi_mapper_per_class( recording: data.Recording, ): yaml_content = """ + detection_target: + name: bat + match_if: + name: has_tag + tag: + key: order + value: Chiroptera + assign_tags: + - key: order + value: Chiroptera + + classification_targets: + - name: pippip + tags: + - key: species + value: Pipistrellus pipistrellus + + - name: myomyo + tags: + - key: species + value: Myotis myotis + roi: + name: anchor_bbox + anchor: top-left + roi: name: anchor_bbox anchor: bottom-left - classes: - classes: - - name: pippip - tags: - - key: species - value: Pipistrellus pipistrellus - - name: myomyo - tags: - - key: species - value: Myotis myotis - roi: - name: anchor_bbox - anchor: top-left - generic_class: - - key: order - value: Chiroptera """ config_path = create_temp_yaml(yaml_content) @@ -65,25 +74,34 @@ def test_roi_is_recovered_roundtrip_even_with_overriders( recording, ): yaml_content = """ + detection_target: + name: bat + match_if: + name: has_tag + tag: + key: order + value: Chiroptera + assign_tags: + - key: order + value: Chiroptera + + classification_targets: + - name: pippip + tags: + - key: species + value: Pipistrellus pipistrellus + + - name: myomyo + tags: + - key: species + value: Myotis myotis + roi: + name: anchor_bbox + anchor: top-left + roi: name: anchor_bbox anchor: bottom-left - classes: - classes: - - name: pippip - tags: - - key: species - value: Pipistrellus pipistrellus - - name: myomyo - tags: - - key: species - value: Myotis myotis - roi: - name: anchor_bbox - anchor: top-left - generic_class: - - key: order - value: Chiroptera """ config_path = create_temp_yaml(yaml_content) diff --git a/tests/test_targets/test_terms.py b/tests/test_targets/test_terms.py deleted file mode 100644 index 74fa927..0000000 --- a/tests/test_targets/test_terms.py +++ /dev/null @@ -1,17 +0,0 @@ -import pytest - -from batdetect2.targets import terms -from batdetect2.targets.terms import TagInfo - - -def test_tag_info_and_get_tag_from_info(): - tag_info = TagInfo(value="Myotis myotis", key="event") - tag = terms.get_tag_from_info(tag_info) - assert tag.value == "Myotis myotis" - assert tag.term == terms.call_type - - -def test_get_tag_from_info_key_not_found(): - tag_info = TagInfo(value="test", key="non_existent_key") - with pytest.raises(KeyError): - terms.get_tag_from_info(tag_info) diff --git a/tests/test_targets/test_transform.py b/tests/test_targets/test_transform.py deleted file mode 100644 index 92c6698..0000000 --- a/tests/test_targets/test_transform.py +++ /dev/null @@ -1,360 +0,0 @@ -from pathlib import Path - -import pytest -from soundevent import data, terms - -from batdetect2.targets import ( - DeriveTagRule, - MapValueRule, - ReplaceRule, - TagInfo, - TransformConfig, - build_transformation_from_config, -) -from batdetect2.targets.transform import ( - DerivationRegistry, - build_transform_from_rule, -) - - -@pytest.fixture -def derivation_registry(): - return DerivationRegistry() - - -@pytest.fixture -def term1() -> data.Term: - term = data.Term(label="Term 1", definition="unknown", name="test:term1") - terms.add_term(term, key="term1", force=True) - return term - - -@pytest.fixture -def term2() -> data.Term: - term = data.Term(label="Term 2", definition="unknown", name="test:term2") - terms.add_term(term, key="term2", force=True) - return term - - -@pytest.fixture -def term3() -> data.Term: - term = data.Term(label="Term 3", definition="unknown", name="test:term3") - terms.add_term(term, key="term3", force=True) - return term - - -@pytest.fixture -def annotation( - sound_event: data.SoundEvent, - term1: data.Term, -) -> data.SoundEventAnnotation: - return data.SoundEventAnnotation( - sound_event=sound_event, tags=[data.Tag(term=term1, value="value1")] - ) - -@pytest.fixture -def annotation2( - sound_event: data.SoundEvent, - term2: data.Term, -) -> data.SoundEventAnnotation: - return data.SoundEventAnnotation( - sound_event=sound_event, tags=[data.Tag(term=term2, value="value2")] - ) - - -def test_map_value_rule(annotation: data.SoundEventAnnotation): - rule = MapValueRule( - rule_type="map_value", - source_term_key="term1", - value_mapping={"value1": "value2"}, - ) - transform_fn = build_transform_from_rule(rule) - transformed_annotation = transform_fn(annotation) - assert transformed_annotation.tags[0].value == "value2" - - -def test_map_value_rule_no_match(annotation: data.SoundEventAnnotation): - rule = MapValueRule( - rule_type="map_value", - source_term_key="term1", - value_mapping={"other_value": "value2"}, - ) - transform_fn = build_transform_from_rule(rule) - transformed_annotation = transform_fn(annotation) - assert transformed_annotation.tags[0].value == "value1" - - -def test_replace_rule(annotation: data.SoundEventAnnotation, term2: data.Term): - rule = ReplaceRule( - rule_type="replace", - original=TagInfo(key="term1", value="value1"), - replacement=TagInfo(key="term2", value="value2"), - ) - transform_fn = build_transform_from_rule(rule) - transformed_annotation = transform_fn(annotation) - assert transformed_annotation.tags[0].term == term2 - assert transformed_annotation.tags[0].value == "value2" - - -def test_replace_rule_no_match( - annotation: data.SoundEventAnnotation, - term1: data.Term, - term2: data.Term, -): - rule = ReplaceRule( - rule_type="replace", - original=TagInfo(key="term1", value="wrong_value"), - replacement=TagInfo(key="term2", value="value2"), - ) - transform_fn = build_transform_from_rule(rule) - transformed_annotation = transform_fn(annotation) - assert transformed_annotation.tags[0].term == term1 - assert transformed_annotation.tags[0].term != term2 - assert transformed_annotation.tags[0].value == "value1" - - -def test_build_transformation_from_config( - annotation: data.SoundEventAnnotation, - annotation2: data.SoundEventAnnotation, - term1: data.Term, - term2: data.Term, - term3: data.Term, -): - config = TransformConfig( - rules=[ - MapValueRule( - rule_type="map_value", - source_term_key="term1", - value_mapping={"value1": "value2"}, - ), - ReplaceRule( - rule_type="replace", - original=TagInfo(key="term2", value="value2"), - replacement=TagInfo(key="term3", value="value3"), - ), - ] - ) - transform = build_transformation_from_config(config) - - transformed_annotation = transform(annotation) - assert transformed_annotation.tags[0].term == term1 - assert transformed_annotation.tags[0].term != term2 - assert transformed_annotation.tags[0].value == "value2" - - transformed_annotation = transform(annotation2) - assert transformed_annotation.tags[0].term == term3 - assert transformed_annotation.tags[0].value == "value3" - - -def test_derive_tag_rule( - annotation: data.SoundEventAnnotation, - derivation_registry: DerivationRegistry, - term1: data.Term, -): - def derivation_func(x: str) -> str: - return x + "_derived" - - derivation_registry.register("my_derivation", derivation_func) - - rule = DeriveTagRule( - rule_type="derive_tag", - source_term_key="term1", - derivation_function="my_derivation", - ) - transform_fn = build_transform_from_rule( - rule, - derivation_registry=derivation_registry, - ) - transformed_annotation = transform_fn(annotation) - - assert len(transformed_annotation.tags) == 2 - assert transformed_annotation.tags[0].term == term1 - assert transformed_annotation.tags[0].value == "value1" - assert transformed_annotation.tags[1].term == term1 - assert transformed_annotation.tags[1].value == "value1_derived" - - -def test_derive_tag_rule_keep_source_false( - annotation: data.SoundEventAnnotation, - derivation_registry: DerivationRegistry, - term1: data.Term, -): - def derivation_func(x: str) -> str: - return x + "_derived" - - derivation_registry.register("my_derivation", derivation_func) - - rule = DeriveTagRule( - rule_type="derive_tag", - source_term_key="term1", - derivation_function="my_derivation", - keep_source=False, - ) - transform_fn = build_transform_from_rule( - rule, - derivation_registry=derivation_registry, - ) - transformed_annotation = transform_fn(annotation) - - assert len(transformed_annotation.tags) == 1 - assert transformed_annotation.tags[0].term == term1 - assert transformed_annotation.tags[0].value == "value1_derived" - - -def test_derive_tag_rule_target_term( - annotation: data.SoundEventAnnotation, - derivation_registry: DerivationRegistry, - term1: data.Term, - term2: data.Term, -): - def derivation_func(x: str) -> str: - return x + "_derived" - - derivation_registry.register("my_derivation", derivation_func) - - rule = DeriveTagRule( - rule_type="derive_tag", - source_term_key="term1", - derivation_function="my_derivation", - target_term_key="term2", - ) - transform_fn = build_transform_from_rule( - rule, - derivation_registry=derivation_registry, - ) - transformed_annotation = transform_fn(annotation) - - assert len(transformed_annotation.tags) == 2 - assert transformed_annotation.tags[0].term == term1 - assert transformed_annotation.tags[0].value == "value1" - assert transformed_annotation.tags[1].term == term2 - assert transformed_annotation.tags[1].value == "value1_derived" - - -def test_derive_tag_rule_import_derivation( - annotation: data.SoundEventAnnotation, - term1: data.Term, - tmp_path: Path, -): - # Create a dummy derivation function in a temporary file - derivation_module_path = ( - tmp_path / "temp_derivation.py" - ) # Changed to /tmp since /home/santiago is not writable - derivation_module_path.write_text( - """ -def my_imported_derivation(x: str) -> str: - return x + "_imported" -""" - ) - # Ensure the temporary file is importable by adding its directory to sys.path - import sys - - sys.path.insert(0, str(tmp_path)) - - rule = DeriveTagRule( - rule_type="derive_tag", - source_term_key="term1", - derivation_function="temp_derivation.my_imported_derivation", - import_derivation=True, - ) - transform_fn = build_transform_from_rule(rule) - transformed_annotation = transform_fn(annotation) - - assert len(transformed_annotation.tags) == 2 - assert transformed_annotation.tags[0].term == term1 - assert transformed_annotation.tags[0].value == "value1" - assert transformed_annotation.tags[1].term == term1 - assert transformed_annotation.tags[1].value == "value1_imported" - - # Clean up the temporary file and sys.path - sys.path.remove(str(tmp_path)) - - -def test_derive_tag_rule_invalid_derivation(): - rule = DeriveTagRule( - rule_type="derive_tag", - source_term_key="term1", - derivation_function="nonexistent_derivation", - ) - with pytest.raises(KeyError): - build_transform_from_rule(rule) - - -def test_build_transform_from_rule_invalid_rule_type(): - class InvalidRule: - rule_type = "invalid" - - rule = InvalidRule() # type: ignore - - with pytest.raises(ValueError): - build_transform_from_rule(rule) # type: ignore - - -def test_map_value_rule_target_term( - annotation: data.SoundEventAnnotation, - term2: data.Term, -): - rule = MapValueRule( - rule_type="map_value", - source_term_key="term1", - value_mapping={"value1": "value2"}, - target_term_key="term2", - ) - transform_fn = build_transform_from_rule(rule) - transformed_annotation = transform_fn(annotation) - assert transformed_annotation.tags[0].term == term2 - assert transformed_annotation.tags[0].value == "value2" - - -def test_map_value_rule_target_term_none( - annotation: data.SoundEventAnnotation, - term1: data.Term, -): - rule = MapValueRule( - rule_type="map_value", - source_term_key="term1", - value_mapping={"value1": "value2"}, - target_term_key=None, - ) - transform_fn = build_transform_from_rule(rule) - transformed_annotation = transform_fn(annotation) - assert transformed_annotation.tags[0].term == term1 - assert transformed_annotation.tags[0].value == "value2" - - -def test_derive_tag_rule_target_term_none( - annotation: data.SoundEventAnnotation, - derivation_registry: DerivationRegistry, - term1: data.Term, -): - def derivation_func(x: str) -> str: - return x + "_derived" - - derivation_registry.register("my_derivation", derivation_func) - - rule = DeriveTagRule( - rule_type="derive_tag", - source_term_key="term1", - derivation_function="my_derivation", - target_term_key=None, - ) - transform_fn = build_transform_from_rule( - rule, - derivation_registry=derivation_registry, - ) - transformed_annotation = transform_fn(annotation) - - assert len(transformed_annotation.tags) == 2 - assert transformed_annotation.tags[0].term == term1 - assert transformed_annotation.tags[0].value == "value1" - assert transformed_annotation.tags[1].term == term1 - assert transformed_annotation.tags[1].value == "value1_derived" - - -def test_build_transformation_from_config_empty( - annotation: data.SoundEventAnnotation, -): - config = TransformConfig(rules=[]) - transform = build_transformation_from_config(config) - transformed_annotation = transform(annotation) - assert transformed_annotation == annotation diff --git a/tests/test_train/test_labels.py b/tests/test_train/test_labels.py index 15e15e9..9c27941 100644 --- a/tests/test_train/test_labels.py +++ b/tests/test_train/test_labels.py @@ -5,7 +5,6 @@ from soundevent import data from batdetect2.targets import TargetConfig, build_targets from batdetect2.targets.rois import AnchorBBoxMapperConfig -from batdetect2.targets.terms import TagInfo from batdetect2.train.labels import generate_heatmaps recording = data.Recording( @@ -26,7 +25,7 @@ clip = data.Clip( def test_generated_heatmap_are_non_zero_at_correct_positions( sample_target_config: TargetConfig, - pippip_tag: TagInfo, + pippip_tag: data.Tag, ): config = sample_target_config.model_copy( update=dict(