mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Remove stale tests
This commit is contained in:
parent
709b6355c2
commit
cf6d0d1ccc
@ -8,8 +8,6 @@ from typing import Callable, List, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.targets import get_term_from_key
|
||||
|
||||
PathLike = Union[Path, str, os.PathLike]
|
||||
|
||||
__all__ = []
|
||||
@ -92,15 +90,15 @@ def annotation_to_sound_event(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=get_term_from_key(label_key),
|
||||
key=label_key, # type: ignore
|
||||
value=annotation.label,
|
||||
),
|
||||
data.Tag(
|
||||
term=get_term_from_key(event_key),
|
||||
key=event_key, # type: ignore
|
||||
value=annotation.event,
|
||||
),
|
||||
data.Tag(
|
||||
term=get_term_from_key(individual_key),
|
||||
key=individual_key, # type: ignore
|
||||
value=str(annotation.individual),
|
||||
),
|
||||
],
|
||||
@ -125,7 +123,7 @@ def file_annotation_to_clip(
|
||||
time_expansion=file_annotation.time_exp,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=get_term_from_key(label_key),
|
||||
key=label_key, # type: ignore
|
||||
value=file_annotation.label,
|
||||
)
|
||||
],
|
||||
@ -157,7 +155,8 @@ def file_annotation_to_clip_annotation(
|
||||
notes=notes,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=get_term_from_key(label_key), value=file_annotation.label
|
||||
key=label_key, # type: ignore
|
||||
value=file_annotation.label,
|
||||
)
|
||||
],
|
||||
sound_events=[
|
||||
|
||||
@ -68,7 +68,10 @@ from batdetect2.postprocess import PostprocessConfig, build_postprocessor
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.targets import TargetConfig, build_targets
|
||||
from batdetect2.typing.models import DetectionModel
|
||||
from batdetect2.typing.postprocess import DetectionsArray, PostprocessorProtocol
|
||||
from batdetect2.typing.postprocess import (
|
||||
DetectionsTensor,
|
||||
PostprocessorProtocol,
|
||||
)
|
||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
@ -122,7 +125,7 @@ class Model(LightningModule):
|
||||
self.targets = targets
|
||||
self.save_hyperparameters()
|
||||
|
||||
def forward(self, wav: torch.Tensor) -> List[DetectionsArray]:
|
||||
def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
|
||||
spec = self.preprocessor(wav)
|
||||
outputs = self.detector(spec)
|
||||
return self.postprocessor(outputs)
|
||||
|
||||
@ -57,14 +57,9 @@ from batdetect2.targets.rois import (
|
||||
)
|
||||
from batdetect2.targets.terms import (
|
||||
TagInfo,
|
||||
TermInfo,
|
||||
TermRegistry,
|
||||
call_type,
|
||||
default_term_registry,
|
||||
get_tag_from_info,
|
||||
get_term_from_key,
|
||||
individual,
|
||||
register_term,
|
||||
)
|
||||
from batdetect2.targets.transform import (
|
||||
DerivationRegistry,
|
||||
@ -100,7 +95,6 @@ __all__ = [
|
||||
"TargetClass",
|
||||
"TargetConfig",
|
||||
"Targets",
|
||||
"TermInfo",
|
||||
"TransformConfig",
|
||||
"build_generic_class_tags",
|
||||
"build_roi_mapper",
|
||||
@ -112,7 +106,6 @@ __all__ = [
|
||||
"get_class_names_from_config",
|
||||
"get_derivation",
|
||||
"get_tag_from_info",
|
||||
"get_term_from_key",
|
||||
"individual",
|
||||
"load_classes_config",
|
||||
"load_decoder_from_config",
|
||||
@ -123,7 +116,6 @@ __all__ = [
|
||||
"load_transformation_config",
|
||||
"load_transformation_from_config",
|
||||
"register_derivation",
|
||||
"register_term",
|
||||
]
|
||||
|
||||
|
||||
@ -534,7 +526,6 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
||||
|
||||
def build_targets(
|
||||
config: Optional[TargetConfig] = None,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
derivation_registry: DerivationRegistry = default_derivation_registry,
|
||||
) -> Targets:
|
||||
"""Build a Targets object from a loaded TargetConfig.
|
||||
@ -550,9 +541,6 @@ def build_targets(
|
||||
----------
|
||||
config : TargetConfig
|
||||
The loaded and validated unified target configuration object.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance to use for resolving term keys. Defaults
|
||||
to the global `batdetect2.targets.terms.term_registry`.
|
||||
derivation_registry : DerivationRegistry, optional
|
||||
The DerivationRegistry instance to use for resolving derivation
|
||||
function names. Defaults to the global
|
||||
@ -578,25 +566,15 @@ def build_targets(
|
||||
)
|
||||
|
||||
filter_fn = (
|
||||
build_sound_event_filter(
|
||||
config.filtering,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
build_sound_event_filter(config.filtering)
|
||||
if config.filtering
|
||||
else None
|
||||
)
|
||||
encode_fn = build_sound_event_encoder(
|
||||
config.classes,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
decode_fn = build_sound_event_decoder(
|
||||
config.classes,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
encode_fn = build_sound_event_encoder(config.classes)
|
||||
decode_fn = build_sound_event_decoder(config.classes)
|
||||
transform_fn = (
|
||||
build_transformation_from_config(
|
||||
config.transforms,
|
||||
term_registry=term_registry,
|
||||
derivation_registry=derivation_registry,
|
||||
)
|
||||
if config.transforms
|
||||
@ -604,10 +582,7 @@ def build_targets(
|
||||
)
|
||||
roi_mapper = build_roi_mapper(config.roi)
|
||||
class_names = get_class_names_from_config(config.classes)
|
||||
generic_class_tags = build_generic_class_tags(
|
||||
config.classes,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
generic_class_tags = build_generic_class_tags(config.classes)
|
||||
roi_overrides = {
|
||||
class_config.name: build_roi_mapper(class_config.roi)
|
||||
for class_config in config.classes.classes
|
||||
@ -629,7 +604,6 @@ def build_targets(
|
||||
def load_targets(
|
||||
config_path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
derivation_registry: DerivationRegistry = default_derivation_registry,
|
||||
) -> Targets:
|
||||
"""Load a Targets object directly from a configuration file.
|
||||
@ -645,8 +619,6 @@ def load_targets(
|
||||
field : str, optional
|
||||
Dot-separated path to a nested section within the file containing
|
||||
the target configuration. If None, the entire file content is used.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance to use. Defaults to the global default.
|
||||
derivation_registry : DerivationRegistry, optional
|
||||
The DerivationRegistry instance to use. Defaults to the global
|
||||
default.
|
||||
@ -670,11 +642,7 @@ def load_targets(
|
||||
config_path,
|
||||
field=field,
|
||||
)
|
||||
return build_targets(
|
||||
config,
|
||||
term_registry=term_registry,
|
||||
derivation_registry=derivation_registry,
|
||||
)
|
||||
return build_targets(config, derivation_registry=derivation_registry)
|
||||
|
||||
|
||||
def iterate_encoded_sound_events(
|
||||
|
||||
@ -10,8 +10,6 @@ from batdetect2.targets.rois import ROIMapperConfig
|
||||
from batdetect2.targets.terms import (
|
||||
GENERIC_CLASS_KEY,
|
||||
TagInfo,
|
||||
TermRegistry,
|
||||
default_term_registry,
|
||||
get_tag_from_info,
|
||||
)
|
||||
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
|
||||
@ -295,10 +293,7 @@ def _encode_with_multiple_classifiers(
|
||||
return None
|
||||
|
||||
|
||||
def build_sound_event_encoder(
|
||||
config: ClassesConfig,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> SoundEventEncoder:
|
||||
def build_sound_event_encoder(config: ClassesConfig) -> SoundEventEncoder:
|
||||
"""Build a sound event encoder function from the classes configuration.
|
||||
|
||||
The returned encoder function iterates through the class definitions in the
|
||||
@ -333,8 +328,7 @@ def build_sound_event_encoder(
|
||||
partial(
|
||||
is_target_class,
|
||||
tags={
|
||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||
for tag_info in class_info.tags
|
||||
get_tag_from_info(tag_info) for tag_info in class_info.tags
|
||||
},
|
||||
match_all=class_info.match_type == "all",
|
||||
),
|
||||
@ -391,7 +385,6 @@ def _decode_class(
|
||||
|
||||
def build_sound_event_decoder(
|
||||
config: ClassesConfig,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
raise_on_unmapped: bool = False,
|
||||
) -> SoundEventDecoder:
|
||||
"""Build a sound event decoder function from the classes configuration.
|
||||
@ -433,8 +426,7 @@ def build_sound_event_decoder(
|
||||
else class_info.tags
|
||||
)
|
||||
mapping[class_info.name] = [
|
||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||
for tag_info in tags_to_use
|
||||
get_tag_from_info(tag_info) for tag_info in tags_to_use
|
||||
]
|
||||
|
||||
return partial(
|
||||
@ -444,10 +436,7 @@ def build_sound_event_decoder(
|
||||
)
|
||||
|
||||
|
||||
def build_generic_class_tags(
|
||||
config: ClassesConfig,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> List[data.Tag]:
|
||||
def build_generic_class_tags(config: ClassesConfig) -> List[data.Tag]:
|
||||
"""Extract and build the list of tags for the generic class from config.
|
||||
|
||||
Converts the list of `TagInfo` objects defined in `config.generic_class`
|
||||
@ -472,10 +461,7 @@ def build_generic_class_tags(
|
||||
If a term key specified in `config.generic_class` is not found in the
|
||||
provided `term_registry`.
|
||||
"""
|
||||
return [
|
||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||
for tag_info in config.generic_class
|
||||
]
|
||||
return [get_tag_from_info(tag_info) for tag_info in config.generic_class]
|
||||
|
||||
|
||||
def load_classes_config(
|
||||
@ -509,9 +495,7 @@ def load_classes_config(
|
||||
|
||||
|
||||
def load_encoder_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
path: data.PathLike, field: Optional[str] = None
|
||||
) -> SoundEventEncoder:
|
||||
"""Load a class encoder function directly from a configuration file.
|
||||
|
||||
@ -546,13 +530,12 @@ def load_encoder_from_config(
|
||||
provided `term_registry` during the build process.
|
||||
"""
|
||||
config = load_classes_config(path, field=field)
|
||||
return build_sound_event_encoder(config, term_registry=term_registry)
|
||||
return build_sound_event_encoder(config)
|
||||
|
||||
|
||||
def load_decoder_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
raise_on_unmapped: bool = False,
|
||||
) -> SoundEventDecoder:
|
||||
"""Load a class decoder function directly from a configuration file.
|
||||
@ -594,6 +577,5 @@ def load_decoder_from_config(
|
||||
config = load_classes_config(path, field=field)
|
||||
return build_sound_event_decoder(
|
||||
config,
|
||||
term_registry=term_registry,
|
||||
raise_on_unmapped=raise_on_unmapped,
|
||||
)
|
||||
|
||||
@ -8,8 +8,6 @@ from soundevent import data
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.targets.terms import (
|
||||
TagInfo,
|
||||
TermRegistry,
|
||||
default_term_registry,
|
||||
get_tag_from_info,
|
||||
)
|
||||
from batdetect2.typing.targets import SoundEventFilter
|
||||
@ -146,10 +144,7 @@ def equal_tags(
|
||||
return tags == sound_event_tags
|
||||
|
||||
|
||||
def build_filter_from_rule(
|
||||
rule: FilterRule,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> SoundEventFilter:
|
||||
def build_filter_from_rule(rule: FilterRule) -> SoundEventFilter:
|
||||
"""Creates a callable filter function from a single FilterRule.
|
||||
|
||||
Parameters
|
||||
@ -168,10 +163,7 @@ def build_filter_from_rule(
|
||||
ValueError
|
||||
If the rule contains an invalid `match_type`.
|
||||
"""
|
||||
tag_set = {
|
||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||
for tag_info in rule.tags
|
||||
}
|
||||
tag_set = {get_tag_from_info(tag_info) for tag_info in rule.tags}
|
||||
|
||||
if rule.match_type == "any":
|
||||
return partial(has_any_tag, tags=tag_set)
|
||||
@ -235,7 +227,6 @@ class FilterConfig(BaseConfig):
|
||||
|
||||
def build_sound_event_filter(
|
||||
config: FilterConfig,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> SoundEventFilter:
|
||||
"""Builds a merged filter function from a FilterConfig object.
|
||||
|
||||
@ -252,10 +243,7 @@ def build_sound_event_filter(
|
||||
SoundEventFilter
|
||||
A single callable filter function that applies all defined rules.
|
||||
"""
|
||||
filters = [
|
||||
build_filter_from_rule(rule, term_registry=term_registry)
|
||||
for rule in config.rules
|
||||
]
|
||||
filters = [build_filter_from_rule(rule) for rule in config.rules]
|
||||
return partial(_passes_all_filters, filters=filters)
|
||||
|
||||
|
||||
@ -281,9 +269,7 @@ def load_filter_config(
|
||||
|
||||
|
||||
def load_filter_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
path: data.PathLike, field: Optional[str] = None,
|
||||
) -> SoundEventFilter:
|
||||
"""Loads filter configuration from a file and builds the filter function.
|
||||
|
||||
@ -304,4 +290,4 @@ def load_filter_from_config(
|
||||
The final merged filter function ready to be used.
|
||||
"""
|
||||
config = load_filter_config(path=path, field=field)
|
||||
return build_sound_event_filter(config, term_registry=term_registry)
|
||||
return build_sound_event_filter(config)
|
||||
|
||||
@ -1,33 +1,22 @@
|
||||
"""Manages the vocabulary (Terms and Tags) for defining training targets.
|
||||
"""Manages the vocabulary for defining training targets.
|
||||
|
||||
This module provides the necessary tools to declare, register, and manage the
|
||||
set of `soundevent.data.Term` objects used throughout the `batdetect2.targets`
|
||||
sub-package. It establishes a consistent vocabulary for filtering,
|
||||
transforming, and classifying sound events based on their annotations (Tags).
|
||||
|
||||
The core component is the `TermRegistry`, which maps unique string keys
|
||||
(aliases) to specific `Term` definitions. This allows users to refer to complex
|
||||
terms using simple, consistent keys in configuration files and code.
|
||||
|
||||
Terms can be pre-defined, loaded from the `soundevent.terms` library, defined
|
||||
programmatically, or loaded from external configuration files (e.g., YAML).
|
||||
Terms can be pre-defined, loaded from the `soundevent.terms` library or defined
|
||||
programmatically.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from inspect import getmembers
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from soundevent import data, terms
|
||||
|
||||
from batdetect2.configs import load_config
|
||||
|
||||
__all__ = [
|
||||
"call_type",
|
||||
"individual",
|
||||
"data_source",
|
||||
"get_tag_from_info",
|
||||
"TermInfo",
|
||||
"TagInfo",
|
||||
]
|
||||
|
||||
@ -98,255 +87,6 @@ terms.register_term_set(
|
||||
)
|
||||
|
||||
|
||||
class TermRegistry(Mapping[str, data.Term]):
|
||||
"""Manages a registry mapping unique keys to Term definitions.
|
||||
|
||||
This class acts as the central repository for the vocabulary of terms
|
||||
used within the target definition process. It allows registering terms
|
||||
with simple string keys and retrieving them consistently.
|
||||
"""
|
||||
|
||||
def __init__(self, terms: Optional[Dict[str, data.Term]] = None):
|
||||
"""Initializes the TermRegistry.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
terms : dict[str, soundevent.data.Term], optional
|
||||
An optional dictionary of initial key-to-Term mappings
|
||||
to populate the registry with. Defaults to an empty registry.
|
||||
"""
|
||||
self._terms: Dict[str, data.Term] = terms or {}
|
||||
|
||||
def __getitem__(self, key: str) -> data.Term:
|
||||
return self._terms[key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._terms)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._terms)
|
||||
|
||||
def add_term(self, key: str, term: data.Term) -> None:
|
||||
"""Adds a Term object to the registry with the specified key.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The unique string key to associate with the term.
|
||||
term : soundevent.data.Term
|
||||
The soundevent.data.Term object to register.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If a term with the provided key already exists in the
|
||||
registry.
|
||||
"""
|
||||
if key in self._terms:
|
||||
raise KeyError("A term with the provided key already exists.")
|
||||
|
||||
self._terms[key] = term
|
||||
|
||||
def get_term(self, key: str) -> data.Term:
|
||||
"""Retrieves a registered term by its unique key.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The unique string key of the term to retrieve.
|
||||
|
||||
Returns
|
||||
-------
|
||||
soundevent.data.Term
|
||||
The corresponding soundevent.data.Term object.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If no term with the specified key is found, with a
|
||||
helpful message suggesting listing available keys.
|
||||
"""
|
||||
try:
|
||||
return self._terms[key]
|
||||
except KeyError as err:
|
||||
raise KeyError(
|
||||
"No term found for key "
|
||||
f"'{key}'. Ensure it is registered or loaded. "
|
||||
f"Available keys: {', '.join(self.get_keys())}"
|
||||
) from err
|
||||
|
||||
def add_custom_term(
|
||||
self,
|
||||
key: str,
|
||||
name: Optional[str] = None,
|
||||
uri: Optional[str] = None,
|
||||
label: Optional[str] = None,
|
||||
definition: Optional[str] = None,
|
||||
) -> data.Term:
|
||||
"""Creates a new Term from attributes and adds it to the registry.
|
||||
|
||||
This is useful for defining terms directly in code or when loading
|
||||
from configuration files where only attributes are provided.
|
||||
|
||||
If optional fields (`name`, `label`, `definition`) are not provided,
|
||||
reasonable defaults are used (`key` for name/label, "Unknown" for
|
||||
definition).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The unique string key for the new term.
|
||||
name : str, optional
|
||||
The name for the new term (defaults to `key`).
|
||||
uri : str, optional
|
||||
The URI for the new term (optional).
|
||||
label : str, optional
|
||||
The display label for the new term (defaults to `key`).
|
||||
definition : str, optional
|
||||
The definition for the new term (defaults to "Unknown").
|
||||
|
||||
Returns
|
||||
-------
|
||||
soundevent.data.Term
|
||||
The newly created and registered soundevent.data.Term object.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If a term with the provided key already exists.
|
||||
"""
|
||||
term = data.Term(
|
||||
name=name or key,
|
||||
label=label or key,
|
||||
uri=uri,
|
||||
definition=definition or "Unknown",
|
||||
)
|
||||
self.add_term(key, term)
|
||||
return term
|
||||
|
||||
def get_keys(self) -> List[str]:
|
||||
"""Returns a list of all keys currently registered.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[str]
|
||||
A list of strings representing the keys of all registered terms.
|
||||
"""
|
||||
return list(self._terms.keys())
|
||||
|
||||
def get_terms(self) -> List[data.Term]:
|
||||
"""Returns a list of all registered terms.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[soundevent.data.Term]
|
||||
A list containing all registered Term objects.
|
||||
"""
|
||||
return list(self._terms.values())
|
||||
|
||||
def remove_key(self, key: str) -> None:
|
||||
del self._terms[key]
|
||||
|
||||
|
||||
default_term_registry = TermRegistry(
|
||||
terms=dict(
|
||||
[
|
||||
*getmembers(terms, lambda x: isinstance(x, data.Term)),
|
||||
("event", call_type),
|
||||
("species", terms.scientific_name),
|
||||
("individual", individual),
|
||||
("data_source", data_source),
|
||||
(GENERIC_CLASS_KEY, generic_class),
|
||||
]
|
||||
)
|
||||
)
|
||||
"""The default, globally accessible TermRegistry instance.
|
||||
|
||||
It is pre-populated with standard terms from `soundevent.terms` and common
|
||||
terms defined in this module (`call_type`, `individual`, `generic_class`).
|
||||
Functions in this module use this registry by default unless another instance
|
||||
is explicitly passed.
|
||||
"""
|
||||
|
||||
|
||||
def get_term_from_key(
|
||||
key: str,
|
||||
term_registry: Optional[TermRegistry] = None,
|
||||
) -> data.Term:
|
||||
"""Convenience function to retrieve a term by key from a registry.
|
||||
|
||||
Uses the global default registry unless a specific `term_registry`
|
||||
instance is provided.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The unique key of the term to retrieve.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance to search in. Defaults to the global
|
||||
`registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
soundevent.data.Term
|
||||
The corresponding soundevent.data.Term object.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If the key is not found in the specified registry.
|
||||
"""
|
||||
term = terms.get_term(key)
|
||||
|
||||
if term:
|
||||
return term
|
||||
|
||||
term_registry = term_registry or default_term_registry
|
||||
return term_registry.get_term(key)
|
||||
|
||||
|
||||
def get_term_keys(
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> List[str]:
|
||||
"""Convenience function to get all registered keys from a registry.
|
||||
|
||||
Uses the global default registry unless a specific `term_registry`
|
||||
instance is provided.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance to query. Defaults to the global `registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[str]
|
||||
A list of strings representing the keys of all registered terms.
|
||||
"""
|
||||
return term_registry.get_keys()
|
||||
|
||||
|
||||
def get_terms(
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> List[data.Term]:
|
||||
"""Convenience function to get all registered terms from a registry.
|
||||
|
||||
Uses the global default registry unless a specific `term_registry`
|
||||
instance is provided.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance to query. Defaults to the global `registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[soundevent.data.Term]
|
||||
A list containing all registered Term objects.
|
||||
"""
|
||||
return term_registry.get_terms()
|
||||
|
||||
|
||||
class TagInfo(BaseModel):
|
||||
"""Represents information needed to define a specific Tag.
|
||||
|
||||
@ -360,31 +100,25 @@ class TagInfo(BaseModel):
|
||||
value : str
|
||||
The value of the tag (e.g., "Myotis myotis", "Echolocation").
|
||||
key : str, default="class"
|
||||
The key (alias) of the term associated with this tag, as
|
||||
registered in the TermRegistry. Defaults to "class", implying
|
||||
it represents a classification target label by default.
|
||||
The key (alias) of the term associated with this tag. Defaults to
|
||||
"class", implying it represents a classification target label by
|
||||
default.
|
||||
"""
|
||||
|
||||
value: str
|
||||
key: str = GENERIC_CLASS_KEY
|
||||
|
||||
|
||||
def get_tag_from_info(
|
||||
tag_info: TagInfo,
|
||||
term_registry: Optional[TermRegistry] = None,
|
||||
) -> data.Tag:
|
||||
def get_tag_from_info(tag_info: TagInfo) -> data.Tag:
|
||||
"""Creates a soundevent.data.Tag object from TagInfo data.
|
||||
|
||||
Looks up the term using the key in the provided `tag_info` from the
|
||||
specified registry and constructs a Tag object.
|
||||
Looks up the term using the key in the provided `tag_info` and constructs a
|
||||
Tag object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tag_info : TagInfo
|
||||
The TagInfo object containing the value and term key.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance to use for term lookup. Defaults to the
|
||||
global `registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -394,132 +128,11 @@ def get_tag_from_info(
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If the term key specified in `tag_info.key` is not found
|
||||
in the registry.
|
||||
If the term key specified in `tag_info.key` is not found.
|
||||
"""
|
||||
term_registry = term_registry or default_term_registry
|
||||
term = get_term_from_key(tag_info.key, term_registry=term_registry)
|
||||
term = terms.get_term(tag_info.key)
|
||||
|
||||
if not term:
|
||||
raise KeyError(f"Key {tag_info.key} not found")
|
||||
|
||||
return data.Tag(term=term, value=tag_info.value)
|
||||
|
||||
|
||||
class TermInfo(BaseModel):
|
||||
"""Represents the definition of a Term within a configuration file.
|
||||
|
||||
This model allows users to define custom terms directly in configuration
|
||||
files (e.g., YAML) which can then be loaded into the TermRegistry.
|
||||
It mirrors the parameters of `TermRegistry.add_custom_term`.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
key : str
|
||||
The unique key (alias) that will be used to register and
|
||||
reference this term.
|
||||
label : str, optional
|
||||
The optional display label for the term. Defaults to `key`
|
||||
if not provided during registration.
|
||||
name : str, optional
|
||||
The optional formal name for the term. Defaults to `key`
|
||||
if not provided during registration.
|
||||
uri : str, optional
|
||||
The optional URI identifying the term (e.g., from a standard
|
||||
vocabulary).
|
||||
definition : str, optional
|
||||
The optional textual definition of the term. Defaults to
|
||||
"Unknown" if not provided during registration.
|
||||
"""
|
||||
|
||||
key: str
|
||||
label: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
uri: Optional[str] = None
|
||||
definition: Optional[str] = None
|
||||
|
||||
|
||||
class TermConfig(BaseModel):
|
||||
"""Pydantic schema for loading a list of term definitions from config.
|
||||
|
||||
This model typically corresponds to a section in a configuration file
|
||||
(e.g., YAML) containing a list of term definitions to be registered.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
terms : list[TermInfo]
|
||||
A list of TermInfo objects, each defining a term to be
|
||||
registered. Defaults to an empty list.
|
||||
|
||||
Examples
|
||||
--------
|
||||
Example YAML structure:
|
||||
|
||||
```yaml
|
||||
terms:
|
||||
- key: species
|
||||
uri: dwc:scientificName
|
||||
label: Scientific Name
|
||||
- key: my_custom_term
|
||||
name: My Custom Term
|
||||
definition: Describes a specific project attribute.
|
||||
# ... more TermInfo definitions
|
||||
```
|
||||
"""
|
||||
|
||||
terms: List[TermInfo] = Field(default_factory=list)
|
||||
|
||||
|
||||
def load_terms_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> Dict[str, data.Term]:
|
||||
"""Loads term definitions from a configuration file and registers them.
|
||||
|
||||
Parses a configuration file (e.g., YAML) using the TermConfig schema,
|
||||
extracts the list of TermInfo definitions, and adds each one as a
|
||||
custom term to the specified TermRegistry instance.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
The path to the configuration file.
|
||||
field : str, optional
|
||||
Optional key indicating a specific section within the config
|
||||
file where the 'terms' list is located. If None, expects the
|
||||
list directly at the top level or within a structure matching
|
||||
TermConfig schema.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance to add the loaded terms to. Defaults to
|
||||
the global `registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, soundevent.data.Term]
|
||||
A dictionary mapping the keys of the newly added terms to their
|
||||
corresponding Term objects.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
pydantic.ValidationError
|
||||
If the config file structure does not match the TermConfig schema.
|
||||
KeyError
|
||||
If a term key loaded from the config conflicts with a key
|
||||
already present in the registry.
|
||||
"""
|
||||
data = load_config(path, schema=TermConfig, field=field)
|
||||
return {
|
||||
info.key: term_registry.add_custom_term(
|
||||
info.key,
|
||||
name=info.name,
|
||||
uri=info.uri,
|
||||
label=info.label,
|
||||
definition=info.definition,
|
||||
)
|
||||
for info in data.terms
|
||||
}
|
||||
|
||||
|
||||
def register_term(
|
||||
key: str, term: data.Term, registry: TermRegistry = default_term_registry
|
||||
) -> None:
|
||||
registry.add_term(key, term)
|
||||
|
||||
@ -12,15 +12,10 @@ from typing import (
|
||||
)
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent import data, terms
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.targets.terms import (
|
||||
TagInfo,
|
||||
TermRegistry,
|
||||
get_tag_from_info,
|
||||
get_term_from_key,
|
||||
)
|
||||
from batdetect2.targets.terms import TagInfo, get_tag_from_info
|
||||
|
||||
__all__ = [
|
||||
"DerivationRegistry",
|
||||
@ -466,7 +461,6 @@ TranformationRule = Annotated[
|
||||
def build_transform_from_rule(
|
||||
rule: TranformationRule,
|
||||
derivation_registry: Optional[DerivationRegistry] = None,
|
||||
term_registry: Optional[TermRegistry] = None,
|
||||
) -> SoundEventTransformation:
|
||||
"""Build a specific SoundEventTransformation function from a rule config.
|
||||
|
||||
@ -497,29 +491,21 @@ def build_transform_from_rule(
|
||||
If dynamic import of a derivation function fails.
|
||||
"""
|
||||
if rule.rule_type == "replace":
|
||||
source = get_tag_from_info(
|
||||
rule.original,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
target = get_tag_from_info(
|
||||
rule.replacement,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
source = get_tag_from_info(rule.original)
|
||||
target = get_tag_from_info(rule.replacement)
|
||||
return partial(replace_tag_transform, source=source, target=target)
|
||||
|
||||
if rule.rule_type == "derive_tag":
|
||||
source_term = get_term_from_key(
|
||||
rule.source_term_key,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
source_term = terms.get_term(rule.source_term_key)
|
||||
target_term = (
|
||||
get_term_from_key(
|
||||
rule.target_term_key,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
terms.get_term(rule.target_term_key)
|
||||
if rule.target_term_key
|
||||
else source_term
|
||||
)
|
||||
|
||||
if source_term is None or target_term is None:
|
||||
raise KeyError("Terms not found")
|
||||
|
||||
derivation = get_derivation(
|
||||
key=rule.derivation_function,
|
||||
import_derivation=rule.import_derivation,
|
||||
@ -534,18 +520,16 @@ def build_transform_from_rule(
|
||||
)
|
||||
|
||||
if rule.rule_type == "map_value":
|
||||
source_term = get_term_from_key(
|
||||
rule.source_term_key,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
source_term = terms.get_term(rule.source_term_key)
|
||||
target_term = (
|
||||
get_term_from_key(
|
||||
rule.target_term_key,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
terms.get_term(rule.target_term_key)
|
||||
if rule.target_term_key
|
||||
else source_term
|
||||
)
|
||||
|
||||
if source_term is None or target_term is None:
|
||||
raise KeyError("Terms not found")
|
||||
|
||||
return partial(
|
||||
map_value_transform,
|
||||
source_term=source_term,
|
||||
@ -555,6 +539,7 @@ def build_transform_from_rule(
|
||||
|
||||
# Handle unknown rule type
|
||||
valid_options = ["replace", "derive_tag", "map_value"]
|
||||
|
||||
# Should be caught by Pydantic validation, but good practice
|
||||
raise ValueError(
|
||||
f"Invalid transform rule type '{getattr(rule, 'rule_type', 'N/A')}'. "
|
||||
@ -565,7 +550,6 @@ def build_transform_from_rule(
|
||||
def build_transformation_from_config(
|
||||
config: TransformConfig,
|
||||
derivation_registry: Optional[DerivationRegistry] = None,
|
||||
term_registry: Optional[TermRegistry] = None,
|
||||
) -> SoundEventTransformation:
|
||||
"""Build a composite transformation function from a TransformConfig.
|
||||
|
||||
@ -591,7 +575,6 @@ def build_transformation_from_config(
|
||||
build_transform_from_rule(
|
||||
rule,
|
||||
derivation_registry=derivation_registry,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
for rule in config.rules
|
||||
]
|
||||
@ -640,7 +623,6 @@ def load_transformation_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
derivation_registry: Optional[DerivationRegistry] = None,
|
||||
term_registry: Optional[TermRegistry] = None,
|
||||
) -> SoundEventTransformation:
|
||||
"""Load transformation config from a file and build the final function.
|
||||
|
||||
@ -678,7 +660,6 @@ def load_transformation_from_config(
|
||||
return build_transformation_from_config(
|
||||
config,
|
||||
derivation_registry=derivation_registry,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ class TrainingModule(L.LightningModule):
|
||||
self.save_hyperparameters(logger=False)
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||
return self.model(spec)
|
||||
return self.model.detector(spec)
|
||||
|
||||
def training_step(self, batch: TrainExample):
|
||||
outputs = self.model.detector(batch.spec)
|
||||
|
||||
@ -99,7 +99,7 @@ def train(
|
||||
module = build_training_module(
|
||||
model,
|
||||
config,
|
||||
batches_per_epoch=len(train_dataloader),
|
||||
t_max=config.train.t_max * len(train_dataloader),
|
||||
)
|
||||
|
||||
logger.info("Starting main training loop...")
|
||||
@ -113,15 +113,16 @@ def train(
|
||||
|
||||
def build_training_module(
|
||||
model: Model,
|
||||
config: FullTrainingConfig,
|
||||
batches_per_epoch: int,
|
||||
config: Optional[FullTrainingConfig] = None,
|
||||
t_max: int = 200,
|
||||
) -> TrainingModule:
|
||||
config = config or FullTrainingConfig()
|
||||
loss = build_loss(config=config.train.loss)
|
||||
return TrainingModule(
|
||||
model=model,
|
||||
loss=loss,
|
||||
learning_rate=config.train.learning_rate,
|
||||
t_max=config.train.t_max * batches_per_epoch,
|
||||
t_max=t_max,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -15,7 +15,6 @@ from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.preprocess.audio import build_audio_loader
|
||||
from batdetect2.targets import (
|
||||
TargetConfig,
|
||||
TermRegistry,
|
||||
build_targets,
|
||||
call_type,
|
||||
)
|
||||
@ -355,18 +354,6 @@ def create_annotation_project():
|
||||
return factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_term_registry() -> TermRegistry:
|
||||
"""Fixture for a sample TermRegistry."""
|
||||
registry = TermRegistry()
|
||||
registry.add_custom_term("class")
|
||||
registry.add_custom_term("order")
|
||||
registry.add_custom_term("species")
|
||||
registry.add_custom_term("call_type")
|
||||
registry.add_custom_term("quality")
|
||||
return registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_preprocessor() -> PreprocessorProtocol:
|
||||
return build_preprocessor()
|
||||
@ -399,7 +386,6 @@ def pippip_tag() -> TagInfo:
|
||||
|
||||
@pytest.fixture
|
||||
def sample_target_config(
|
||||
sample_term_registry: TermRegistry,
|
||||
bat_tag: TagInfo,
|
||||
noise_tag: TagInfo,
|
||||
myomyo_tag: TagInfo,
|
||||
@ -422,12 +408,8 @@ def sample_target_config(
|
||||
@pytest.fixture
|
||||
def sample_targets(
|
||||
sample_target_config: TargetConfig,
|
||||
sample_term_registry: TermRegistry,
|
||||
) -> TargetProtocol:
|
||||
return build_targets(
|
||||
sample_target_config,
|
||||
term_registry=sample_term_registry,
|
||||
)
|
||||
return build_targets(sample_target_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -443,10 +425,8 @@ def sample_labeller(
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_clipper(
|
||||
sample_preprocessor: PreprocessorProtocol,
|
||||
) -> ClipperProtocol:
|
||||
return build_clipper(preprocessor=sample_preprocessor)
|
||||
def sample_clipper() -> ClipperProtocol:
|
||||
return build_clipper()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -1,16 +1,14 @@
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
from soundevent import data
|
||||
from soundevent import data, terms
|
||||
|
||||
from batdetect2.targets import build_targets, load_target_config
|
||||
from batdetect2.targets.terms import get_term_from_key
|
||||
|
||||
|
||||
def test_can_override_default_roi_mapper_per_class(
|
||||
create_temp_yaml: Callable[..., Path],
|
||||
recording: data.Recording,
|
||||
sample_term_registry,
|
||||
):
|
||||
yaml_content = """
|
||||
roi:
|
||||
@ -36,11 +34,13 @@ def test_can_override_default_roi_mapper_per_class(
|
||||
config_path = create_temp_yaml(yaml_content)
|
||||
|
||||
config = load_target_config(config_path)
|
||||
targets = build_targets(config, term_registry=sample_term_registry)
|
||||
targets = build_targets(config)
|
||||
|
||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||
|
||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||
species = terms.get_term("species")
|
||||
assert species is not None
|
||||
|
||||
se1 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||
@ -62,7 +62,6 @@ def test_can_override_default_roi_mapper_per_class(
|
||||
# TODO: rename this test function
|
||||
def test_roi_is_recovered_roundtrip_even_with_overriders(
|
||||
create_temp_yaml,
|
||||
sample_term_registry,
|
||||
recording,
|
||||
):
|
||||
yaml_content = """
|
||||
@ -89,11 +88,12 @@ def test_roi_is_recovered_roundtrip_even_with_overriders(
|
||||
config_path = create_temp_yaml(yaml_content)
|
||||
|
||||
config = load_target_config(config_path)
|
||||
targets = build_targets(config, term_registry=sample_term_registry)
|
||||
targets = build_targets(config)
|
||||
|
||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||
|
||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||
species = terms.get_term("species")
|
||||
assert species is not None
|
||||
se1 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||
|
||||
@ -1,98 +1,7 @@
|
||||
import pytest
|
||||
import yaml
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.targets import terms
|
||||
from batdetect2.targets.terms import (
|
||||
TagInfo,
|
||||
TermRegistry,
|
||||
load_terms_from_config,
|
||||
)
|
||||
|
||||
|
||||
def test_term_registry_initialization():
|
||||
registry = TermRegistry()
|
||||
assert registry._terms == {}
|
||||
|
||||
initial_terms = {
|
||||
"test_term": data.Term(name="test", label="Test", definition="test")
|
||||
}
|
||||
registry = TermRegistry(terms=initial_terms)
|
||||
assert registry._terms == initial_terms
|
||||
|
||||
|
||||
def test_term_registry_add_term():
|
||||
registry = TermRegistry()
|
||||
term = data.Term(name="test", label="Test", definition="test")
|
||||
registry.add_term("test_key", term)
|
||||
assert registry._terms["test_key"] == term
|
||||
|
||||
|
||||
def test_term_registry_get_term():
|
||||
registry = TermRegistry()
|
||||
term = data.Term(name="test", label="Test", definition="test")
|
||||
registry.add_term("test_key", term)
|
||||
retrieved_term = registry.get_term("test_key")
|
||||
assert retrieved_term == term
|
||||
|
||||
|
||||
def test_term_registry_add_custom_term():
|
||||
registry = TermRegistry()
|
||||
term = registry.add_custom_term(
|
||||
"custom_key", name="custom", label="Custom", definition="A custom term"
|
||||
)
|
||||
assert registry._terms["custom_key"] == term
|
||||
assert term.name == "custom"
|
||||
assert term.label == "Custom"
|
||||
assert term.definition == "A custom term"
|
||||
|
||||
|
||||
def test_term_registry_add_duplicate_term():
|
||||
registry = TermRegistry()
|
||||
term = data.Term(name="test", label="Test", definition="test")
|
||||
registry.add_term("test_key", term)
|
||||
with pytest.raises(KeyError):
|
||||
registry.add_term("test_key", term)
|
||||
|
||||
|
||||
def test_term_registry_get_term_not_found():
|
||||
registry = TermRegistry()
|
||||
with pytest.raises(KeyError):
|
||||
registry.get_term("non_existent_key")
|
||||
|
||||
|
||||
def test_term_registry_get_keys():
|
||||
registry = TermRegistry()
|
||||
term1 = data.Term(name="test1", label="Test1", definition="test")
|
||||
term2 = data.Term(name="test2", label="Test2", definition="test")
|
||||
registry.add_term("key1", term1)
|
||||
registry.add_term("key2", term2)
|
||||
keys = registry.get_keys()
|
||||
assert set(keys) == {"key1", "key2"}
|
||||
|
||||
|
||||
def test_get_term_from_key():
|
||||
term = terms.get_term_from_key("event")
|
||||
assert term == terms.call_type
|
||||
|
||||
custom_registry = TermRegistry()
|
||||
custom_term = data.Term(name="custom", label="Custom", definition="test")
|
||||
custom_registry.add_term("custom_key", custom_term)
|
||||
term = terms.get_term_from_key("custom_key", term_registry=custom_registry)
|
||||
assert term == custom_term
|
||||
|
||||
|
||||
def test_get_term_keys():
|
||||
keys = terms.get_term_keys()
|
||||
assert "event" in keys
|
||||
assert "individual" in keys
|
||||
assert terms.GENERIC_CLASS_KEY in keys
|
||||
|
||||
custom_registry = TermRegistry()
|
||||
custom_term = data.Term(name="custom", label="Custom", definition="test")
|
||||
custom_registry.add_term("custom_key", custom_term)
|
||||
keys = terms.get_term_keys(term_registry=custom_registry)
|
||||
assert "custom_key" in keys
|
||||
from batdetect2.targets.terms import TagInfo
|
||||
|
||||
|
||||
def test_tag_info_and_get_tag_from_info():
|
||||
@ -106,74 +15,3 @@ def test_get_tag_from_info_key_not_found():
|
||||
tag_info = TagInfo(value="test", key="non_existent_key")
|
||||
with pytest.raises(KeyError):
|
||||
terms.get_tag_from_info(tag_info)
|
||||
|
||||
|
||||
def test_load_terms_from_config(tmp_path):
|
||||
term_registry = TermRegistry()
|
||||
config_data = {
|
||||
"terms": [
|
||||
{
|
||||
"key": "species",
|
||||
"name": "dwc:scientificName",
|
||||
"label": "Scientific Name",
|
||||
},
|
||||
{
|
||||
"key": "my_custom_term",
|
||||
"name": "soundevent:custom_term",
|
||||
"definition": "Describes a specific project attribute",
|
||||
},
|
||||
]
|
||||
}
|
||||
config_file = tmp_path / "config.yaml"
|
||||
with open(config_file, "w") as f:
|
||||
yaml.dump(config_data, f)
|
||||
|
||||
loaded_terms = load_terms_from_config(
|
||||
config_file,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
assert "species" in loaded_terms
|
||||
assert "my_custom_term" in loaded_terms
|
||||
assert loaded_terms["species"].name == "dwc:scientificName"
|
||||
assert loaded_terms["my_custom_term"].name == "soundevent:custom_term"
|
||||
|
||||
|
||||
def test_load_terms_from_config_file_not_found():
|
||||
with pytest.raises(FileNotFoundError):
|
||||
load_terms_from_config("non_existent_file.yaml")
|
||||
|
||||
|
||||
def test_load_terms_from_config_validation_error(tmp_path):
|
||||
config_data = {
|
||||
"terms": [
|
||||
{
|
||||
"key": "species",
|
||||
"uri": "dwc:scientificName",
|
||||
"label": 123,
|
||||
}, # Invalid label type
|
||||
]
|
||||
}
|
||||
config_file = tmp_path / "config.yaml"
|
||||
with open(config_file, "w") as f:
|
||||
yaml.dump(config_data, f)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
load_terms_from_config(config_file)
|
||||
|
||||
|
||||
def test_load_terms_from_config_key_already_exists(tmp_path):
|
||||
config_data = {
|
||||
"terms": [
|
||||
{
|
||||
"key": "event",
|
||||
"uri": "dwc:scientificName",
|
||||
"label": "Scientific Name",
|
||||
}, # Duplicate key
|
||||
]
|
||||
}
|
||||
config_file = tmp_path / "config.yaml"
|
||||
with open(config_file, "w") as f:
|
||||
yaml.dump(config_data, f)
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
load_terms_from_config(config_file)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from soundevent import data
|
||||
from soundevent import data, terms
|
||||
|
||||
from batdetect2.targets import (
|
||||
DeriveTagRule,
|
||||
@ -11,31 +11,36 @@ from batdetect2.targets import (
|
||||
TransformConfig,
|
||||
build_transformation_from_config,
|
||||
)
|
||||
from batdetect2.targets.terms import TermRegistry
|
||||
from batdetect2.targets.transform import (
|
||||
DerivationRegistry,
|
||||
build_transform_from_rule,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def term_registry():
|
||||
return TermRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def derivation_registry():
|
||||
return DerivationRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def term1(term_registry: TermRegistry) -> data.Term:
|
||||
return term_registry.add_custom_term(key="term1")
|
||||
def term1() -> data.Term:
|
||||
term = data.Term(label="Term 1", definition="unknown", name="test:term1")
|
||||
terms.add_term(term, key="term1", force=True)
|
||||
return term
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def term2(term_registry: TermRegistry) -> data.Term:
|
||||
return term_registry.add_custom_term(key="term2")
|
||||
def term2() -> data.Term:
|
||||
term = data.Term(label="Term 2", definition="unknown", name="test:term2")
|
||||
terms.add_term(term, key="term2", force=True)
|
||||
return term
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def term3() -> data.Term:
|
||||
term = data.Term(label="Term 3", definition="unknown", name="test:term3")
|
||||
terms.add_term(term, key="term3", force=True)
|
||||
return term
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -47,46 +52,45 @@ def annotation(
|
||||
sound_event=sound_event, tags=[data.Tag(term=term1, value="value1")]
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def annotation2(
|
||||
sound_event: data.SoundEvent,
|
||||
term2: data.Term,
|
||||
) -> data.SoundEventAnnotation:
|
||||
return data.SoundEventAnnotation(
|
||||
sound_event=sound_event, tags=[data.Tag(term=term2, value="value2")]
|
||||
)
|
||||
|
||||
def test_map_value_rule(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
):
|
||||
|
||||
def test_map_value_rule(annotation: data.SoundEventAnnotation):
|
||||
rule = MapValueRule(
|
||||
rule_type="map_value",
|
||||
source_term_key="term1",
|
||||
value_mapping={"value1": "value2"},
|
||||
)
|
||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||
transform_fn = build_transform_from_rule(rule)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
assert transformed_annotation.tags[0].value == "value2"
|
||||
|
||||
|
||||
def test_map_value_rule_no_match(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
):
|
||||
def test_map_value_rule_no_match(annotation: data.SoundEventAnnotation):
|
||||
rule = MapValueRule(
|
||||
rule_type="map_value",
|
||||
source_term_key="term1",
|
||||
value_mapping={"other_value": "value2"},
|
||||
)
|
||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||
transform_fn = build_transform_from_rule(rule)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
assert transformed_annotation.tags[0].value == "value1"
|
||||
|
||||
|
||||
def test_replace_rule(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term2: data.Term,
|
||||
term_registry: TermRegistry,
|
||||
):
|
||||
def test_replace_rule(annotation: data.SoundEventAnnotation, term2: data.Term):
|
||||
rule = ReplaceRule(
|
||||
rule_type="replace",
|
||||
original=TagInfo(key="term1", value="value1"),
|
||||
replacement=TagInfo(key="term2", value="value2"),
|
||||
)
|
||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||
transform_fn = build_transform_from_rule(rule)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
assert transformed_annotation.tags[0].term == term2
|
||||
assert transformed_annotation.tags[0].value == "value2"
|
||||
@ -94,7 +98,7 @@ def test_replace_rule(
|
||||
|
||||
def test_replace_rule_no_match(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
term1: data.Term,
|
||||
term2: data.Term,
|
||||
):
|
||||
rule = ReplaceRule(
|
||||
@ -102,16 +106,19 @@ def test_replace_rule_no_match(
|
||||
original=TagInfo(key="term1", value="wrong_value"),
|
||||
replacement=TagInfo(key="term2", value="value2"),
|
||||
)
|
||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||
transform_fn = build_transform_from_rule(rule)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
assert transformed_annotation.tags[0].key == "term1"
|
||||
assert transformed_annotation.tags[0].term == term1
|
||||
assert transformed_annotation.tags[0].term != term2
|
||||
assert transformed_annotation.tags[0].value == "value1"
|
||||
|
||||
|
||||
def test_build_transformation_from_config(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
annotation2: data.SoundEventAnnotation,
|
||||
term1: data.Term,
|
||||
term2: data.Term,
|
||||
term3: data.Term,
|
||||
):
|
||||
config = TransformConfig(
|
||||
rules=[
|
||||
@ -127,20 +134,20 @@ def test_build_transformation_from_config(
|
||||
),
|
||||
]
|
||||
)
|
||||
term_registry.add_custom_term("term2")
|
||||
term_registry.add_custom_term("term3")
|
||||
transform = build_transformation_from_config(
|
||||
config,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
transform = build_transformation_from_config(config)
|
||||
|
||||
transformed_annotation = transform(annotation)
|
||||
assert transformed_annotation.tags[0].key == "term1"
|
||||
assert transformed_annotation.tags[0].term == term1
|
||||
assert transformed_annotation.tags[0].term != term2
|
||||
assert transformed_annotation.tags[0].value == "value2"
|
||||
|
||||
transformed_annotation = transform(annotation2)
|
||||
assert transformed_annotation.tags[0].term == term3
|
||||
assert transformed_annotation.tags[0].value == "value3"
|
||||
|
||||
|
||||
def test_derive_tag_rule(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
derivation_registry: DerivationRegistry,
|
||||
term1: data.Term,
|
||||
):
|
||||
@ -156,7 +163,6 @@ def test_derive_tag_rule(
|
||||
)
|
||||
transform_fn = build_transform_from_rule(
|
||||
rule,
|
||||
term_registry=term_registry,
|
||||
derivation_registry=derivation_registry,
|
||||
)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
@ -170,7 +176,6 @@ def test_derive_tag_rule(
|
||||
|
||||
def test_derive_tag_rule_keep_source_false(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
derivation_registry: DerivationRegistry,
|
||||
term1: data.Term,
|
||||
):
|
||||
@ -187,7 +192,6 @@ def test_derive_tag_rule_keep_source_false(
|
||||
)
|
||||
transform_fn = build_transform_from_rule(
|
||||
rule,
|
||||
term_registry=term_registry,
|
||||
derivation_registry=derivation_registry,
|
||||
)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
@ -199,7 +203,6 @@ def test_derive_tag_rule_keep_source_false(
|
||||
|
||||
def test_derive_tag_rule_target_term(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
derivation_registry: DerivationRegistry,
|
||||
term1: data.Term,
|
||||
term2: data.Term,
|
||||
@ -217,7 +220,6 @@ def test_derive_tag_rule_target_term(
|
||||
)
|
||||
transform_fn = build_transform_from_rule(
|
||||
rule,
|
||||
term_registry=term_registry,
|
||||
derivation_registry=derivation_registry,
|
||||
)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
@ -231,7 +233,6 @@ def test_derive_tag_rule_target_term(
|
||||
|
||||
def test_derive_tag_rule_import_derivation(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
term1: data.Term,
|
||||
tmp_path: Path,
|
||||
):
|
||||
@ -256,7 +257,7 @@ def my_imported_derivation(x: str) -> str:
|
||||
derivation_function="temp_derivation.my_imported_derivation",
|
||||
import_derivation=True,
|
||||
)
|
||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||
transform_fn = build_transform_from_rule(rule)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
|
||||
assert len(transformed_annotation.tags) == 2
|
||||
@ -269,14 +270,14 @@ def my_imported_derivation(x: str) -> str:
|
||||
sys.path.remove(str(tmp_path))
|
||||
|
||||
|
||||
def test_derive_tag_rule_invalid_derivation(term_registry: TermRegistry):
|
||||
def test_derive_tag_rule_invalid_derivation():
|
||||
rule = DeriveTagRule(
|
||||
rule_type="derive_tag",
|
||||
source_term_key="term1",
|
||||
derivation_function="nonexistent_derivation",
|
||||
)
|
||||
with pytest.raises(KeyError):
|
||||
build_transform_from_rule(rule, term_registry=term_registry)
|
||||
build_transform_from_rule(rule)
|
||||
|
||||
|
||||
def test_build_transform_from_rule_invalid_rule_type():
|
||||
@ -291,7 +292,6 @@ def test_build_transform_from_rule_invalid_rule_type():
|
||||
|
||||
def test_map_value_rule_target_term(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
term2: data.Term,
|
||||
):
|
||||
rule = MapValueRule(
|
||||
@ -300,7 +300,7 @@ def test_map_value_rule_target_term(
|
||||
value_mapping={"value1": "value2"},
|
||||
target_term_key="term2",
|
||||
)
|
||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||
transform_fn = build_transform_from_rule(rule)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
assert transformed_annotation.tags[0].term == term2
|
||||
assert transformed_annotation.tags[0].value == "value2"
|
||||
@ -308,7 +308,6 @@ def test_map_value_rule_target_term(
|
||||
|
||||
def test_map_value_rule_target_term_none(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
term1: data.Term,
|
||||
):
|
||||
rule = MapValueRule(
|
||||
@ -317,7 +316,7 @@ def test_map_value_rule_target_term_none(
|
||||
value_mapping={"value1": "value2"},
|
||||
target_term_key=None,
|
||||
)
|
||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||
transform_fn = build_transform_from_rule(rule)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
assert transformed_annotation.tags[0].term == term1
|
||||
assert transformed_annotation.tags[0].value == "value2"
|
||||
@ -325,7 +324,6 @@ def test_map_value_rule_target_term_none(
|
||||
|
||||
def test_derive_tag_rule_target_term_none(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
derivation_registry: DerivationRegistry,
|
||||
term1: data.Term,
|
||||
):
|
||||
@ -342,7 +340,6 @@ def test_derive_tag_rule_target_term_none(
|
||||
)
|
||||
transform_fn = build_transform_from_rule(
|
||||
rule,
|
||||
term_registry=term_registry,
|
||||
derivation_registry=derivation_registry,
|
||||
)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
|
||||
@ -1,170 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.train.augmentations import (
|
||||
add_echo,
|
||||
mix_audio,
|
||||
)
|
||||
from batdetect2.train.clips import select_subclip
|
||||
from batdetect2.train.preprocess import generate_train_example
|
||||
from batdetect2.typing import AudioLoader, ClipLabeller, PreprocessorProtocol
|
||||
|
||||
|
||||
def test_mix_examples(
|
||||
sample_preprocessor: PreprocessorProtocol,
|
||||
sample_audio_loader: AudioLoader,
|
||||
sample_labeller: ClipLabeller,
|
||||
create_recording: Callable[..., data.Recording],
|
||||
):
|
||||
recording1 = create_recording()
|
||||
recording2 = create_recording()
|
||||
|
||||
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
||||
clip2 = data.Clip(recording=recording2, start_time=0.3, end_time=0.8)
|
||||
|
||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
||||
clip_annotation_2 = data.ClipAnnotation(clip=clip2)
|
||||
|
||||
example1 = generate_train_example(
|
||||
clip_annotation_1,
|
||||
audio_loader=sample_audio_loader,
|
||||
preprocessor=sample_preprocessor,
|
||||
labeller=sample_labeller,
|
||||
)
|
||||
example2 = generate_train_example(
|
||||
clip_annotation_2,
|
||||
audio_loader=sample_audio_loader,
|
||||
preprocessor=sample_preprocessor,
|
||||
labeller=sample_labeller,
|
||||
)
|
||||
|
||||
mixed = mix_audio(
|
||||
example1,
|
||||
example2,
|
||||
weight=0.3,
|
||||
preprocessor=sample_preprocessor,
|
||||
)
|
||||
|
||||
assert mixed.spectrogram.shape == example1.spectrogram.shape
|
||||
assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
|
||||
assert mixed.size_heatmap.shape == example1.size_heatmap.shape
|
||||
assert mixed.class_heatmap.shape == example1.class_heatmap.shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
|
||||
@pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
|
||||
def test_mix_examples_of_different_durations(
|
||||
sample_preprocessor: PreprocessorProtocol,
|
||||
sample_audio_loader: AudioLoader,
|
||||
sample_labeller: ClipLabeller,
|
||||
create_recording: Callable[..., data.Recording],
|
||||
duration1: float,
|
||||
duration2: float,
|
||||
):
|
||||
recording1 = create_recording()
|
||||
recording2 = create_recording()
|
||||
|
||||
clip1 = data.Clip(recording=recording1, start_time=0, end_time=duration1)
|
||||
clip2 = data.Clip(recording=recording2, start_time=0, end_time=duration2)
|
||||
|
||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
||||
clip_annotation_2 = data.ClipAnnotation(clip=clip2)
|
||||
|
||||
example1 = generate_train_example(
|
||||
clip_annotation_1,
|
||||
audio_loader=sample_audio_loader,
|
||||
preprocessor=sample_preprocessor,
|
||||
labeller=sample_labeller,
|
||||
)
|
||||
example2 = generate_train_example(
|
||||
clip_annotation_2,
|
||||
audio_loader=sample_audio_loader,
|
||||
preprocessor=sample_preprocessor,
|
||||
labeller=sample_labeller,
|
||||
)
|
||||
|
||||
mixed = mix_audio(
|
||||
example1,
|
||||
example2,
|
||||
weight=0.3,
|
||||
preprocessor=sample_preprocessor,
|
||||
)
|
||||
|
||||
assert mixed.spectrogram.shape == example1.spectrogram.shape
|
||||
assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
|
||||
assert mixed.size_heatmap.shape == example1.size_heatmap.shape
|
||||
assert mixed.class_heatmap.shape == example1.class_heatmap.shape
|
||||
|
||||
|
||||
def test_add_echo(
|
||||
sample_preprocessor: PreprocessorProtocol,
|
||||
sample_audio_loader: AudioLoader,
|
||||
sample_labeller: ClipLabeller,
|
||||
create_recording: Callable[..., data.Recording],
|
||||
):
|
||||
recording1 = create_recording()
|
||||
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
||||
|
||||
original = generate_train_example(
|
||||
clip_annotation_1,
|
||||
audio_loader=sample_audio_loader,
|
||||
preprocessor=sample_preprocessor,
|
||||
labeller=sample_labeller,
|
||||
)
|
||||
with_echo = add_echo(
|
||||
original,
|
||||
preprocessor=sample_preprocessor,
|
||||
delay=0.1,
|
||||
weight=0.3,
|
||||
)
|
||||
|
||||
assert with_echo.spectrogram.shape == original.spectrogram.shape
|
||||
torch.testing.assert_close(
|
||||
with_echo.size_heatmap,
|
||||
original.size_heatmap,
|
||||
atol=0,
|
||||
rtol=0,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
with_echo.class_heatmap,
|
||||
original.class_heatmap,
|
||||
atol=0,
|
||||
rtol=0,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
with_echo.detection_heatmap,
|
||||
original.detection_heatmap,
|
||||
atol=0,
|
||||
rtol=0,
|
||||
)
|
||||
|
||||
|
||||
def test_selected_random_subclip_has_the_correct_width(
|
||||
sample_preprocessor: PreprocessorProtocol,
|
||||
sample_audio_loader: AudioLoader,
|
||||
sample_labeller: ClipLabeller,
|
||||
create_recording: Callable[..., data.Recording],
|
||||
):
|
||||
recording1 = create_recording()
|
||||
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
||||
|
||||
original = generate_train_example(
|
||||
clip_annotation_1,
|
||||
audio_loader=sample_audio_loader,
|
||||
preprocessor=sample_preprocessor,
|
||||
labeller=sample_labeller,
|
||||
)
|
||||
|
||||
subclip = select_subclip(
|
||||
original,
|
||||
input_samplerate=256_000,
|
||||
output_samplerate=1000,
|
||||
start=0,
|
||||
duration=0.512,
|
||||
)
|
||||
assert subclip.spectrogram.shape[1] == 512
|
||||
@ -1,27 +0,0 @@
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.train import generate_train_example
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
ClipLabeller,
|
||||
ClipperProtocol,
|
||||
PreprocessorProtocol,
|
||||
)
|
||||
|
||||
|
||||
def test_default_clip_size_is_correct(
|
||||
sample_clipper: ClipperProtocol,
|
||||
sample_labeller: ClipLabeller,
|
||||
sample_audio_loader: AudioLoader,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
sample_preprocessor: PreprocessorProtocol,
|
||||
):
|
||||
example = generate_train_example(
|
||||
clip_annotation=clip_annotation,
|
||||
audio_loader=sample_audio_loader,
|
||||
preprocessor=sample_preprocessor,
|
||||
labeller=sample_labeller,
|
||||
)
|
||||
|
||||
clip, _, _ = sample_clipper(example)
|
||||
assert clip.spectrogram.shape == (1, 128, 256)
|
||||
@ -56,7 +56,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
||||
|
||||
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
||||
clip_annotation,
|
||||
torch.rand([100, 100]),
|
||||
torch.rand([1, 100, 100]),
|
||||
min_freq=0,
|
||||
max_freq=100,
|
||||
targets=targets,
|
||||
@ -67,4 +67,4 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
||||
assert size_heatmap[1, 10, 10] == 20
|
||||
assert class_heatmap[pippip_index, 10, 10] == 1.0
|
||||
assert class_heatmap[myomyo_index, 10, 10] == 0.0
|
||||
assert detection_heatmap[10, 10] == 1.0
|
||||
assert detection_heatmap[0, 10, 10] == 1.0
|
||||
|
||||
@ -4,14 +4,16 @@ import lightning as L
|
||||
import torch
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.models import build_model
|
||||
from batdetect2.train import FullTrainingConfig, TrainingModule
|
||||
from batdetect2.train.train import build_training_module
|
||||
from batdetect2.typing.preprocess import AudioLoader
|
||||
|
||||
|
||||
def build_default_module():
|
||||
model = build_model()
|
||||
config = FullTrainingConfig()
|
||||
return build_training_module(config)
|
||||
return build_training_module(model, config=config)
|
||||
|
||||
|
||||
def test_can_initialize_default_module():
|
||||
@ -32,14 +34,14 @@ def test_can_save_checkpoint(
|
||||
|
||||
recovered = TrainingModule.load_from_checkpoint(path)
|
||||
|
||||
wav = torch.tensor(sample_audio_loader.load_clip(clip))
|
||||
wav = torch.tensor(sample_audio_loader.load_clip(clip)).unsqueeze(0)
|
||||
|
||||
spec1 = module.model.preprocessor(wav)
|
||||
spec2 = recovered.model.preprocessor(wav)
|
||||
|
||||
torch.testing.assert_close(spec1, spec2, rtol=0, atol=0)
|
||||
|
||||
output1 = module(spec1.unsqueeze(0).unsqueeze(0))
|
||||
output2 = recovered(spec2.unsqueeze(0).unsqueeze(0))
|
||||
output1 = module(spec1.unsqueeze(0))
|
||||
output2 = recovered(spec2.unsqueeze(0))
|
||||
|
||||
torch.testing.assert_close(output1, output2, rtol=0, atol=0)
|
||||
|
||||
@ -1,230 +0,0 @@
|
||||
import pytest
|
||||
from soundevent import data
|
||||
from soundevent.terms import get_term
|
||||
|
||||
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
|
||||
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
||||
from batdetect2.targets import build_targets, load_target_config
|
||||
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
||||
from batdetect2.train.preprocess import generate_train_example
|
||||
from batdetect2.typing import ModelOutput
|
||||
from batdetect2.typing.preprocess import AudioLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def build_from_config(
|
||||
create_temp_yaml,
|
||||
):
|
||||
def build(yaml_content):
|
||||
config_path = create_temp_yaml(yaml_content)
|
||||
|
||||
targets_config = load_target_config(config_path, field="targets")
|
||||
preprocessing_config = load_preprocessing_config(
|
||||
config_path,
|
||||
field="preprocessing",
|
||||
)
|
||||
labels_config = load_label_config(config_path, field="labels")
|
||||
postprocessing_config = load_postprocess_config(
|
||||
config_path,
|
||||
field="postprocessing",
|
||||
)
|
||||
|
||||
targets = build_targets(targets_config)
|
||||
preprocessor = build_preprocessor(preprocessing_config)
|
||||
labeller = build_clip_labeler(
|
||||
targets=targets,
|
||||
config=labels_config,
|
||||
min_freq=preprocessor.min_freq,
|
||||
max_freq=preprocessor.max_freq,
|
||||
)
|
||||
postprocessor = build_postprocessor(
|
||||
preprocessor=preprocessor,
|
||||
config=postprocessing_config,
|
||||
)
|
||||
|
||||
return targets, preprocessor, labeller, postprocessor
|
||||
|
||||
return build
|
||||
|
||||
|
||||
def test_encoding_decoding_roundtrip_recovers_object(
|
||||
sample_audio_loader: AudioLoader,
|
||||
build_from_config,
|
||||
recording,
|
||||
):
|
||||
yaml_content = """
|
||||
labels:
|
||||
targets:
|
||||
roi:
|
||||
name: anchor_bbox
|
||||
anchor: bottom-left
|
||||
classes:
|
||||
classes:
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
generic_class:
|
||||
- key: order
|
||||
value: Chiroptera
|
||||
preprocessing:
|
||||
"""
|
||||
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
||||
|
||||
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
||||
se1 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||
tags=[
|
||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||
],
|
||||
)
|
||||
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
||||
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
||||
|
||||
encoded = generate_train_example(
|
||||
clip_annotation,
|
||||
sample_audio_loader,
|
||||
preprocessor,
|
||||
labeller,
|
||||
)
|
||||
predictions = postprocessor.get_predictions(
|
||||
ModelOutput(
|
||||
detection_probs=encoded.detection_heatmap.unsqueeze(0).unsqueeze(
|
||||
0
|
||||
),
|
||||
size_preds=encoded.size_heatmap.unsqueeze(0),
|
||||
class_probs=encoded.class_heatmap.unsqueeze(0),
|
||||
features=encoded.spectrogram.unsqueeze(0).unsqueeze(0),
|
||||
),
|
||||
[clip],
|
||||
)[0]
|
||||
|
||||
assert isinstance(predictions, data.ClipPrediction)
|
||||
assert len(predictions.sound_events) == 1
|
||||
|
||||
recovered = predictions.sound_events[0]
|
||||
assert recovered.sound_event.geometry is not None
|
||||
assert isinstance(recovered.sound_event.geometry, data.BoundingBox)
|
||||
start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
|
||||
recovered.sound_event.geometry.coordinates
|
||||
)
|
||||
start_time_or, low_freq_or, end_time_or, high_freq_or = (
|
||||
geometry.coordinates
|
||||
)
|
||||
|
||||
assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
|
||||
assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
|
||||
assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
|
||||
assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)
|
||||
|
||||
assert len(recovered.tags) == 2
|
||||
|
||||
predicted_species_tag = next(
|
||||
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
|
||||
None,
|
||||
)
|
||||
assert predicted_species_tag is not None
|
||||
assert predicted_species_tag.score == 1
|
||||
assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"
|
||||
|
||||
predicted_order_tag = next(
|
||||
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
|
||||
None,
|
||||
)
|
||||
assert predicted_order_tag is not None
|
||||
assert predicted_order_tag.score == 1
|
||||
assert predicted_order_tag.tag.value == "Chiroptera"
|
||||
|
||||
|
||||
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
|
||||
sample_audio_loader: AudioLoader,
|
||||
build_from_config,
|
||||
recording,
|
||||
):
|
||||
yaml_content = """
|
||||
labels:
|
||||
targets:
|
||||
roi:
|
||||
name: anchor_bbox
|
||||
anchor: bottom-left
|
||||
classes:
|
||||
classes:
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
- name: myomyo
|
||||
tags:
|
||||
- key: species
|
||||
value: Myotis myotis
|
||||
roi:
|
||||
name: anchor_bbox
|
||||
anchor: top-left
|
||||
generic_class:
|
||||
- key: order
|
||||
value: Chiroptera
|
||||
preprocessing:
|
||||
"""
|
||||
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
||||
|
||||
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
||||
se1 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
)
|
||||
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
||||
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
||||
|
||||
encoded = generate_train_example(
|
||||
clip_annotation,
|
||||
sample_audio_loader,
|
||||
preprocessor,
|
||||
labeller,
|
||||
)
|
||||
predictions = postprocessor.get_predictions(
|
||||
ModelOutput(
|
||||
detection_probs=encoded.detection_heatmap.unsqueeze(0).unsqueeze(
|
||||
0
|
||||
),
|
||||
size_preds=encoded.size_heatmap.unsqueeze(0),
|
||||
class_probs=encoded.class_heatmap.unsqueeze(0),
|
||||
features=encoded.spectrogram.unsqueeze(0).unsqueeze(0),
|
||||
),
|
||||
[clip],
|
||||
)[0]
|
||||
|
||||
assert isinstance(predictions, data.ClipPrediction)
|
||||
assert len(predictions.sound_events) == 1
|
||||
|
||||
recovered = predictions.sound_events[0]
|
||||
assert recovered.sound_event.geometry is not None
|
||||
assert isinstance(recovered.sound_event.geometry, data.BoundingBox)
|
||||
start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
|
||||
recovered.sound_event.geometry.coordinates
|
||||
)
|
||||
start_time_or, low_freq_or, end_time_or, high_freq_or = (
|
||||
geometry.coordinates
|
||||
)
|
||||
|
||||
assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
|
||||
assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
|
||||
assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
|
||||
assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)
|
||||
|
||||
assert len(recovered.tags) == 2
|
||||
|
||||
predicted_species_tag = next(
|
||||
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
|
||||
None,
|
||||
)
|
||||
assert predicted_species_tag is not None
|
||||
assert predicted_species_tag.score == 1
|
||||
assert predicted_species_tag.tag.value == "Myotis myotis"
|
||||
|
||||
predicted_order_tag = next(
|
||||
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
|
||||
None,
|
||||
)
|
||||
assert predicted_order_tag is not None
|
||||
assert predicted_order_tag.score == 1
|
||||
assert predicted_order_tag.tag.value == "Chiroptera"
|
||||
Loading…
Reference in New Issue
Block a user