mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Remove stale tests
This commit is contained in:
parent
709b6355c2
commit
cf6d0d1ccc
@ -8,8 +8,6 @@ from typing import Callable, List, Optional, Union
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.targets import get_term_from_key
|
|
||||||
|
|
||||||
PathLike = Union[Path, str, os.PathLike]
|
PathLike = Union[Path, str, os.PathLike]
|
||||||
|
|
||||||
__all__ = []
|
__all__ = []
|
||||||
@ -92,15 +90,15 @@ def annotation_to_sound_event(
|
|||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(
|
data.Tag(
|
||||||
term=get_term_from_key(label_key),
|
key=label_key, # type: ignore
|
||||||
value=annotation.label,
|
value=annotation.label,
|
||||||
),
|
),
|
||||||
data.Tag(
|
data.Tag(
|
||||||
term=get_term_from_key(event_key),
|
key=event_key, # type: ignore
|
||||||
value=annotation.event,
|
value=annotation.event,
|
||||||
),
|
),
|
||||||
data.Tag(
|
data.Tag(
|
||||||
term=get_term_from_key(individual_key),
|
key=individual_key, # type: ignore
|
||||||
value=str(annotation.individual),
|
value=str(annotation.individual),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -125,7 +123,7 @@ def file_annotation_to_clip(
|
|||||||
time_expansion=file_annotation.time_exp,
|
time_expansion=file_annotation.time_exp,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(
|
data.Tag(
|
||||||
term=get_term_from_key(label_key),
|
key=label_key, # type: ignore
|
||||||
value=file_annotation.label,
|
value=file_annotation.label,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
@ -157,7 +155,8 @@ def file_annotation_to_clip_annotation(
|
|||||||
notes=notes,
|
notes=notes,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(
|
data.Tag(
|
||||||
term=get_term_from_key(label_key), value=file_annotation.label
|
key=label_key, # type: ignore
|
||||||
|
value=file_annotation.label,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
sound_events=[
|
sound_events=[
|
||||||
|
|||||||
@ -68,7 +68,10 @@ from batdetect2.postprocess import PostprocessConfig, build_postprocessor
|
|||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.targets import TargetConfig, build_targets
|
from batdetect2.targets import TargetConfig, build_targets
|
||||||
from batdetect2.typing.models import DetectionModel
|
from batdetect2.typing.models import DetectionModel
|
||||||
from batdetect2.typing.postprocess import DetectionsArray, PostprocessorProtocol
|
from batdetect2.typing.postprocess import (
|
||||||
|
DetectionsTensor,
|
||||||
|
PostprocessorProtocol,
|
||||||
|
)
|
||||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
@ -122,7 +125,7 @@ class Model(LightningModule):
|
|||||||
self.targets = targets
|
self.targets = targets
|
||||||
self.save_hyperparameters()
|
self.save_hyperparameters()
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> List[DetectionsArray]:
|
def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
|
||||||
spec = self.preprocessor(wav)
|
spec = self.preprocessor(wav)
|
||||||
outputs = self.detector(spec)
|
outputs = self.detector(spec)
|
||||||
return self.postprocessor(outputs)
|
return self.postprocessor(outputs)
|
||||||
|
|||||||
@ -57,14 +57,9 @@ from batdetect2.targets.rois import (
|
|||||||
)
|
)
|
||||||
from batdetect2.targets.terms import (
|
from batdetect2.targets.terms import (
|
||||||
TagInfo,
|
TagInfo,
|
||||||
TermInfo,
|
|
||||||
TermRegistry,
|
|
||||||
call_type,
|
call_type,
|
||||||
default_term_registry,
|
|
||||||
get_tag_from_info,
|
get_tag_from_info,
|
||||||
get_term_from_key,
|
|
||||||
individual,
|
individual,
|
||||||
register_term,
|
|
||||||
)
|
)
|
||||||
from batdetect2.targets.transform import (
|
from batdetect2.targets.transform import (
|
||||||
DerivationRegistry,
|
DerivationRegistry,
|
||||||
@ -100,7 +95,6 @@ __all__ = [
|
|||||||
"TargetClass",
|
"TargetClass",
|
||||||
"TargetConfig",
|
"TargetConfig",
|
||||||
"Targets",
|
"Targets",
|
||||||
"TermInfo",
|
|
||||||
"TransformConfig",
|
"TransformConfig",
|
||||||
"build_generic_class_tags",
|
"build_generic_class_tags",
|
||||||
"build_roi_mapper",
|
"build_roi_mapper",
|
||||||
@ -112,7 +106,6 @@ __all__ = [
|
|||||||
"get_class_names_from_config",
|
"get_class_names_from_config",
|
||||||
"get_derivation",
|
"get_derivation",
|
||||||
"get_tag_from_info",
|
"get_tag_from_info",
|
||||||
"get_term_from_key",
|
|
||||||
"individual",
|
"individual",
|
||||||
"load_classes_config",
|
"load_classes_config",
|
||||||
"load_decoder_from_config",
|
"load_decoder_from_config",
|
||||||
@ -123,7 +116,6 @@ __all__ = [
|
|||||||
"load_transformation_config",
|
"load_transformation_config",
|
||||||
"load_transformation_from_config",
|
"load_transformation_from_config",
|
||||||
"register_derivation",
|
"register_derivation",
|
||||||
"register_term",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -534,7 +526,6 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
|||||||
|
|
||||||
def build_targets(
|
def build_targets(
|
||||||
config: Optional[TargetConfig] = None,
|
config: Optional[TargetConfig] = None,
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
derivation_registry: DerivationRegistry = default_derivation_registry,
|
derivation_registry: DerivationRegistry = default_derivation_registry,
|
||||||
) -> Targets:
|
) -> Targets:
|
||||||
"""Build a Targets object from a loaded TargetConfig.
|
"""Build a Targets object from a loaded TargetConfig.
|
||||||
@ -550,9 +541,6 @@ def build_targets(
|
|||||||
----------
|
----------
|
||||||
config : TargetConfig
|
config : TargetConfig
|
||||||
The loaded and validated unified target configuration object.
|
The loaded and validated unified target configuration object.
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance to use for resolving term keys. Defaults
|
|
||||||
to the global `batdetect2.targets.terms.term_registry`.
|
|
||||||
derivation_registry : DerivationRegistry, optional
|
derivation_registry : DerivationRegistry, optional
|
||||||
The DerivationRegistry instance to use for resolving derivation
|
The DerivationRegistry instance to use for resolving derivation
|
||||||
function names. Defaults to the global
|
function names. Defaults to the global
|
||||||
@ -578,25 +566,15 @@ def build_targets(
|
|||||||
)
|
)
|
||||||
|
|
||||||
filter_fn = (
|
filter_fn = (
|
||||||
build_sound_event_filter(
|
build_sound_event_filter(config.filtering)
|
||||||
config.filtering,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
if config.filtering
|
if config.filtering
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
encode_fn = build_sound_event_encoder(
|
encode_fn = build_sound_event_encoder(config.classes)
|
||||||
config.classes,
|
decode_fn = build_sound_event_decoder(config.classes)
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
decode_fn = build_sound_event_decoder(
|
|
||||||
config.classes,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
transform_fn = (
|
transform_fn = (
|
||||||
build_transformation_from_config(
|
build_transformation_from_config(
|
||||||
config.transforms,
|
config.transforms,
|
||||||
term_registry=term_registry,
|
|
||||||
derivation_registry=derivation_registry,
|
derivation_registry=derivation_registry,
|
||||||
)
|
)
|
||||||
if config.transforms
|
if config.transforms
|
||||||
@ -604,10 +582,7 @@ def build_targets(
|
|||||||
)
|
)
|
||||||
roi_mapper = build_roi_mapper(config.roi)
|
roi_mapper = build_roi_mapper(config.roi)
|
||||||
class_names = get_class_names_from_config(config.classes)
|
class_names = get_class_names_from_config(config.classes)
|
||||||
generic_class_tags = build_generic_class_tags(
|
generic_class_tags = build_generic_class_tags(config.classes)
|
||||||
config.classes,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
roi_overrides = {
|
roi_overrides = {
|
||||||
class_config.name: build_roi_mapper(class_config.roi)
|
class_config.name: build_roi_mapper(class_config.roi)
|
||||||
for class_config in config.classes.classes
|
for class_config in config.classes.classes
|
||||||
@ -629,7 +604,6 @@ def build_targets(
|
|||||||
def load_targets(
|
def load_targets(
|
||||||
config_path: data.PathLike,
|
config_path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
derivation_registry: DerivationRegistry = default_derivation_registry,
|
derivation_registry: DerivationRegistry = default_derivation_registry,
|
||||||
) -> Targets:
|
) -> Targets:
|
||||||
"""Load a Targets object directly from a configuration file.
|
"""Load a Targets object directly from a configuration file.
|
||||||
@ -645,8 +619,6 @@ def load_targets(
|
|||||||
field : str, optional
|
field : str, optional
|
||||||
Dot-separated path to a nested section within the file containing
|
Dot-separated path to a nested section within the file containing
|
||||||
the target configuration. If None, the entire file content is used.
|
the target configuration. If None, the entire file content is used.
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance to use. Defaults to the global default.
|
|
||||||
derivation_registry : DerivationRegistry, optional
|
derivation_registry : DerivationRegistry, optional
|
||||||
The DerivationRegistry instance to use. Defaults to the global
|
The DerivationRegistry instance to use. Defaults to the global
|
||||||
default.
|
default.
|
||||||
@ -670,11 +642,7 @@ def load_targets(
|
|||||||
config_path,
|
config_path,
|
||||||
field=field,
|
field=field,
|
||||||
)
|
)
|
||||||
return build_targets(
|
return build_targets(config, derivation_registry=derivation_registry)
|
||||||
config,
|
|
||||||
term_registry=term_registry,
|
|
||||||
derivation_registry=derivation_registry,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def iterate_encoded_sound_events(
|
def iterate_encoded_sound_events(
|
||||||
|
|||||||
@ -10,8 +10,6 @@ from batdetect2.targets.rois import ROIMapperConfig
|
|||||||
from batdetect2.targets.terms import (
|
from batdetect2.targets.terms import (
|
||||||
GENERIC_CLASS_KEY,
|
GENERIC_CLASS_KEY,
|
||||||
TagInfo,
|
TagInfo,
|
||||||
TermRegistry,
|
|
||||||
default_term_registry,
|
|
||||||
get_tag_from_info,
|
get_tag_from_info,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
|
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
|
||||||
@ -295,10 +293,7 @@ def _encode_with_multiple_classifiers(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def build_sound_event_encoder(
|
def build_sound_event_encoder(config: ClassesConfig) -> SoundEventEncoder:
|
||||||
config: ClassesConfig,
|
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
) -> SoundEventEncoder:
|
|
||||||
"""Build a sound event encoder function from the classes configuration.
|
"""Build a sound event encoder function from the classes configuration.
|
||||||
|
|
||||||
The returned encoder function iterates through the class definitions in the
|
The returned encoder function iterates through the class definitions in the
|
||||||
@ -333,8 +328,7 @@ def build_sound_event_encoder(
|
|||||||
partial(
|
partial(
|
||||||
is_target_class,
|
is_target_class,
|
||||||
tags={
|
tags={
|
||||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
get_tag_from_info(tag_info) for tag_info in class_info.tags
|
||||||
for tag_info in class_info.tags
|
|
||||||
},
|
},
|
||||||
match_all=class_info.match_type == "all",
|
match_all=class_info.match_type == "all",
|
||||||
),
|
),
|
||||||
@ -391,7 +385,6 @@ def _decode_class(
|
|||||||
|
|
||||||
def build_sound_event_decoder(
|
def build_sound_event_decoder(
|
||||||
config: ClassesConfig,
|
config: ClassesConfig,
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
raise_on_unmapped: bool = False,
|
raise_on_unmapped: bool = False,
|
||||||
) -> SoundEventDecoder:
|
) -> SoundEventDecoder:
|
||||||
"""Build a sound event decoder function from the classes configuration.
|
"""Build a sound event decoder function from the classes configuration.
|
||||||
@ -433,8 +426,7 @@ def build_sound_event_decoder(
|
|||||||
else class_info.tags
|
else class_info.tags
|
||||||
)
|
)
|
||||||
mapping[class_info.name] = [
|
mapping[class_info.name] = [
|
||||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
get_tag_from_info(tag_info) for tag_info in tags_to_use
|
||||||
for tag_info in tags_to_use
|
|
||||||
]
|
]
|
||||||
|
|
||||||
return partial(
|
return partial(
|
||||||
@ -444,10 +436,7 @@ def build_sound_event_decoder(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_generic_class_tags(
|
def build_generic_class_tags(config: ClassesConfig) -> List[data.Tag]:
|
||||||
config: ClassesConfig,
|
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
) -> List[data.Tag]:
|
|
||||||
"""Extract and build the list of tags for the generic class from config.
|
"""Extract and build the list of tags for the generic class from config.
|
||||||
|
|
||||||
Converts the list of `TagInfo` objects defined in `config.generic_class`
|
Converts the list of `TagInfo` objects defined in `config.generic_class`
|
||||||
@ -472,10 +461,7 @@ def build_generic_class_tags(
|
|||||||
If a term key specified in `config.generic_class` is not found in the
|
If a term key specified in `config.generic_class` is not found in the
|
||||||
provided `term_registry`.
|
provided `term_registry`.
|
||||||
"""
|
"""
|
||||||
return [
|
return [get_tag_from_info(tag_info) for tag_info in config.generic_class]
|
||||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
|
||||||
for tag_info in config.generic_class
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def load_classes_config(
|
def load_classes_config(
|
||||||
@ -509,9 +495,7 @@ def load_classes_config(
|
|||||||
|
|
||||||
|
|
||||||
def load_encoder_from_config(
|
def load_encoder_from_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike, field: Optional[str] = None
|
||||||
field: Optional[str] = None,
|
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
) -> SoundEventEncoder:
|
) -> SoundEventEncoder:
|
||||||
"""Load a class encoder function directly from a configuration file.
|
"""Load a class encoder function directly from a configuration file.
|
||||||
|
|
||||||
@ -546,13 +530,12 @@ def load_encoder_from_config(
|
|||||||
provided `term_registry` during the build process.
|
provided `term_registry` during the build process.
|
||||||
"""
|
"""
|
||||||
config = load_classes_config(path, field=field)
|
config = load_classes_config(path, field=field)
|
||||||
return build_sound_event_encoder(config, term_registry=term_registry)
|
return build_sound_event_encoder(config)
|
||||||
|
|
||||||
|
|
||||||
def load_decoder_from_config(
|
def load_decoder_from_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
raise_on_unmapped: bool = False,
|
raise_on_unmapped: bool = False,
|
||||||
) -> SoundEventDecoder:
|
) -> SoundEventDecoder:
|
||||||
"""Load a class decoder function directly from a configuration file.
|
"""Load a class decoder function directly from a configuration file.
|
||||||
@ -594,6 +577,5 @@ def load_decoder_from_config(
|
|||||||
config = load_classes_config(path, field=field)
|
config = load_classes_config(path, field=field)
|
||||||
return build_sound_event_decoder(
|
return build_sound_event_decoder(
|
||||||
config,
|
config,
|
||||||
term_registry=term_registry,
|
|
||||||
raise_on_unmapped=raise_on_unmapped,
|
raise_on_unmapped=raise_on_unmapped,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -8,8 +8,6 @@ from soundevent import data
|
|||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.targets.terms import (
|
from batdetect2.targets.terms import (
|
||||||
TagInfo,
|
TagInfo,
|
||||||
TermRegistry,
|
|
||||||
default_term_registry,
|
|
||||||
get_tag_from_info,
|
get_tag_from_info,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.targets import SoundEventFilter
|
from batdetect2.typing.targets import SoundEventFilter
|
||||||
@ -146,10 +144,7 @@ def equal_tags(
|
|||||||
return tags == sound_event_tags
|
return tags == sound_event_tags
|
||||||
|
|
||||||
|
|
||||||
def build_filter_from_rule(
|
def build_filter_from_rule(rule: FilterRule) -> SoundEventFilter:
|
||||||
rule: FilterRule,
|
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
) -> SoundEventFilter:
|
|
||||||
"""Creates a callable filter function from a single FilterRule.
|
"""Creates a callable filter function from a single FilterRule.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -168,10 +163,7 @@ def build_filter_from_rule(
|
|||||||
ValueError
|
ValueError
|
||||||
If the rule contains an invalid `match_type`.
|
If the rule contains an invalid `match_type`.
|
||||||
"""
|
"""
|
||||||
tag_set = {
|
tag_set = {get_tag_from_info(tag_info) for tag_info in rule.tags}
|
||||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
|
||||||
for tag_info in rule.tags
|
|
||||||
}
|
|
||||||
|
|
||||||
if rule.match_type == "any":
|
if rule.match_type == "any":
|
||||||
return partial(has_any_tag, tags=tag_set)
|
return partial(has_any_tag, tags=tag_set)
|
||||||
@ -235,7 +227,6 @@ class FilterConfig(BaseConfig):
|
|||||||
|
|
||||||
def build_sound_event_filter(
|
def build_sound_event_filter(
|
||||||
config: FilterConfig,
|
config: FilterConfig,
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
) -> SoundEventFilter:
|
) -> SoundEventFilter:
|
||||||
"""Builds a merged filter function from a FilterConfig object.
|
"""Builds a merged filter function from a FilterConfig object.
|
||||||
|
|
||||||
@ -252,10 +243,7 @@ def build_sound_event_filter(
|
|||||||
SoundEventFilter
|
SoundEventFilter
|
||||||
A single callable filter function that applies all defined rules.
|
A single callable filter function that applies all defined rules.
|
||||||
"""
|
"""
|
||||||
filters = [
|
filters = [build_filter_from_rule(rule) for rule in config.rules]
|
||||||
build_filter_from_rule(rule, term_registry=term_registry)
|
|
||||||
for rule in config.rules
|
|
||||||
]
|
|
||||||
return partial(_passes_all_filters, filters=filters)
|
return partial(_passes_all_filters, filters=filters)
|
||||||
|
|
||||||
|
|
||||||
@ -281,9 +269,7 @@ def load_filter_config(
|
|||||||
|
|
||||||
|
|
||||||
def load_filter_from_config(
|
def load_filter_from_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike, field: Optional[str] = None,
|
||||||
field: Optional[str] = None,
|
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
) -> SoundEventFilter:
|
) -> SoundEventFilter:
|
||||||
"""Loads filter configuration from a file and builds the filter function.
|
"""Loads filter configuration from a file and builds the filter function.
|
||||||
|
|
||||||
@ -304,4 +290,4 @@ def load_filter_from_config(
|
|||||||
The final merged filter function ready to be used.
|
The final merged filter function ready to be used.
|
||||||
"""
|
"""
|
||||||
config = load_filter_config(path=path, field=field)
|
config = load_filter_config(path=path, field=field)
|
||||||
return build_sound_event_filter(config, term_registry=term_registry)
|
return build_sound_event_filter(config)
|
||||||
|
|||||||
@ -1,33 +1,22 @@
|
|||||||
"""Manages the vocabulary (Terms and Tags) for defining training targets.
|
"""Manages the vocabulary for defining training targets.
|
||||||
|
|
||||||
This module provides the necessary tools to declare, register, and manage the
|
This module provides the necessary tools to declare, register, and manage the
|
||||||
set of `soundevent.data.Term` objects used throughout the `batdetect2.targets`
|
set of `soundevent.data.Term` objects used throughout the `batdetect2.targets`
|
||||||
sub-package. It establishes a consistent vocabulary for filtering,
|
sub-package. It establishes a consistent vocabulary for filtering,
|
||||||
transforming, and classifying sound events based on their annotations (Tags).
|
transforming, and classifying sound events based on their annotations (Tags).
|
||||||
|
|
||||||
The core component is the `TermRegistry`, which maps unique string keys
|
Terms can be pre-defined, loaded from the `soundevent.terms` library or defined
|
||||||
(aliases) to specific `Term` definitions. This allows users to refer to complex
|
programmatically.
|
||||||
terms using simple, consistent keys in configuration files and code.
|
|
||||||
|
|
||||||
Terms can be pre-defined, loaded from the `soundevent.terms` library, defined
|
|
||||||
programmatically, or loaded from external configuration files (e.g., YAML).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from collections.abc import Mapping
|
from pydantic import BaseModel
|
||||||
from inspect import getmembers
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from soundevent import data, terms
|
from soundevent import data, terms
|
||||||
|
|
||||||
from batdetect2.configs import load_config
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"call_type",
|
"call_type",
|
||||||
"individual",
|
"individual",
|
||||||
"data_source",
|
"data_source",
|
||||||
"get_tag_from_info",
|
"get_tag_from_info",
|
||||||
"TermInfo",
|
|
||||||
"TagInfo",
|
"TagInfo",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -98,255 +87,6 @@ terms.register_term_set(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TermRegistry(Mapping[str, data.Term]):
|
|
||||||
"""Manages a registry mapping unique keys to Term definitions.
|
|
||||||
|
|
||||||
This class acts as the central repository for the vocabulary of terms
|
|
||||||
used within the target definition process. It allows registering terms
|
|
||||||
with simple string keys and retrieving them consistently.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, terms: Optional[Dict[str, data.Term]] = None):
|
|
||||||
"""Initializes the TermRegistry.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
terms : dict[str, soundevent.data.Term], optional
|
|
||||||
An optional dictionary of initial key-to-Term mappings
|
|
||||||
to populate the registry with. Defaults to an empty registry.
|
|
||||||
"""
|
|
||||||
self._terms: Dict[str, data.Term] = terms or {}
|
|
||||||
|
|
||||||
def __getitem__(self, key: str) -> data.Term:
|
|
||||||
return self._terms[key]
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self._terms)
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
return iter(self._terms)
|
|
||||||
|
|
||||||
def add_term(self, key: str, term: data.Term) -> None:
|
|
||||||
"""Adds a Term object to the registry with the specified key.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
key : str
|
|
||||||
The unique string key to associate with the term.
|
|
||||||
term : soundevent.data.Term
|
|
||||||
The soundevent.data.Term object to register.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If a term with the provided key already exists in the
|
|
||||||
registry.
|
|
||||||
"""
|
|
||||||
if key in self._terms:
|
|
||||||
raise KeyError("A term with the provided key already exists.")
|
|
||||||
|
|
||||||
self._terms[key] = term
|
|
||||||
|
|
||||||
def get_term(self, key: str) -> data.Term:
|
|
||||||
"""Retrieves a registered term by its unique key.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
key : str
|
|
||||||
The unique string key of the term to retrieve.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
soundevent.data.Term
|
|
||||||
The corresponding soundevent.data.Term object.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If no term with the specified key is found, with a
|
|
||||||
helpful message suggesting listing available keys.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return self._terms[key]
|
|
||||||
except KeyError as err:
|
|
||||||
raise KeyError(
|
|
||||||
"No term found for key "
|
|
||||||
f"'{key}'. Ensure it is registered or loaded. "
|
|
||||||
f"Available keys: {', '.join(self.get_keys())}"
|
|
||||||
) from err
|
|
||||||
|
|
||||||
def add_custom_term(
|
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
name: Optional[str] = None,
|
|
||||||
uri: Optional[str] = None,
|
|
||||||
label: Optional[str] = None,
|
|
||||||
definition: Optional[str] = None,
|
|
||||||
) -> data.Term:
|
|
||||||
"""Creates a new Term from attributes and adds it to the registry.
|
|
||||||
|
|
||||||
This is useful for defining terms directly in code or when loading
|
|
||||||
from configuration files where only attributes are provided.
|
|
||||||
|
|
||||||
If optional fields (`name`, `label`, `definition`) are not provided,
|
|
||||||
reasonable defaults are used (`key` for name/label, "Unknown" for
|
|
||||||
definition).
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
key : str
|
|
||||||
The unique string key for the new term.
|
|
||||||
name : str, optional
|
|
||||||
The name for the new term (defaults to `key`).
|
|
||||||
uri : str, optional
|
|
||||||
The URI for the new term (optional).
|
|
||||||
label : str, optional
|
|
||||||
The display label for the new term (defaults to `key`).
|
|
||||||
definition : str, optional
|
|
||||||
The definition for the new term (defaults to "Unknown").
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
soundevent.data.Term
|
|
||||||
The newly created and registered soundevent.data.Term object.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If a term with the provided key already exists.
|
|
||||||
"""
|
|
||||||
term = data.Term(
|
|
||||||
name=name or key,
|
|
||||||
label=label or key,
|
|
||||||
uri=uri,
|
|
||||||
definition=definition or "Unknown",
|
|
||||||
)
|
|
||||||
self.add_term(key, term)
|
|
||||||
return term
|
|
||||||
|
|
||||||
def get_keys(self) -> List[str]:
|
|
||||||
"""Returns a list of all keys currently registered.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
list[str]
|
|
||||||
A list of strings representing the keys of all registered terms.
|
|
||||||
"""
|
|
||||||
return list(self._terms.keys())
|
|
||||||
|
|
||||||
def get_terms(self) -> List[data.Term]:
|
|
||||||
"""Returns a list of all registered terms.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
list[soundevent.data.Term]
|
|
||||||
A list containing all registered Term objects.
|
|
||||||
"""
|
|
||||||
return list(self._terms.values())
|
|
||||||
|
|
||||||
def remove_key(self, key: str) -> None:
|
|
||||||
del self._terms[key]
|
|
||||||
|
|
||||||
|
|
||||||
default_term_registry = TermRegistry(
|
|
||||||
terms=dict(
|
|
||||||
[
|
|
||||||
*getmembers(terms, lambda x: isinstance(x, data.Term)),
|
|
||||||
("event", call_type),
|
|
||||||
("species", terms.scientific_name),
|
|
||||||
("individual", individual),
|
|
||||||
("data_source", data_source),
|
|
||||||
(GENERIC_CLASS_KEY, generic_class),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
"""The default, globally accessible TermRegistry instance.
|
|
||||||
|
|
||||||
It is pre-populated with standard terms from `soundevent.terms` and common
|
|
||||||
terms defined in this module (`call_type`, `individual`, `generic_class`).
|
|
||||||
Functions in this module use this registry by default unless another instance
|
|
||||||
is explicitly passed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_term_from_key(
|
|
||||||
key: str,
|
|
||||||
term_registry: Optional[TermRegistry] = None,
|
|
||||||
) -> data.Term:
|
|
||||||
"""Convenience function to retrieve a term by key from a registry.
|
|
||||||
|
|
||||||
Uses the global default registry unless a specific `term_registry`
|
|
||||||
instance is provided.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
key : str
|
|
||||||
The unique key of the term to retrieve.
|
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance to search in. Defaults to the global
|
|
||||||
`registry`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
soundevent.data.Term
|
|
||||||
The corresponding soundevent.data.Term object.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If the key is not found in the specified registry.
|
|
||||||
"""
|
|
||||||
term = terms.get_term(key)
|
|
||||||
|
|
||||||
if term:
|
|
||||||
return term
|
|
||||||
|
|
||||||
term_registry = term_registry or default_term_registry
|
|
||||||
return term_registry.get_term(key)
|
|
||||||
|
|
||||||
|
|
||||||
def get_term_keys(
|
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
) -> List[str]:
|
|
||||||
"""Convenience function to get all registered keys from a registry.
|
|
||||||
|
|
||||||
Uses the global default registry unless a specific `term_registry`
|
|
||||||
instance is provided.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance to query. Defaults to the global `registry`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
list[str]
|
|
||||||
A list of strings representing the keys of all registered terms.
|
|
||||||
"""
|
|
||||||
return term_registry.get_keys()
|
|
||||||
|
|
||||||
|
|
||||||
def get_terms(
|
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
) -> List[data.Term]:
|
|
||||||
"""Convenience function to get all registered terms from a registry.
|
|
||||||
|
|
||||||
Uses the global default registry unless a specific `term_registry`
|
|
||||||
instance is provided.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance to query. Defaults to the global `registry`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
list[soundevent.data.Term]
|
|
||||||
A list containing all registered Term objects.
|
|
||||||
"""
|
|
||||||
return term_registry.get_terms()
|
|
||||||
|
|
||||||
|
|
||||||
class TagInfo(BaseModel):
|
class TagInfo(BaseModel):
|
||||||
"""Represents information needed to define a specific Tag.
|
"""Represents information needed to define a specific Tag.
|
||||||
|
|
||||||
@ -360,31 +100,25 @@ class TagInfo(BaseModel):
|
|||||||
value : str
|
value : str
|
||||||
The value of the tag (e.g., "Myotis myotis", "Echolocation").
|
The value of the tag (e.g., "Myotis myotis", "Echolocation").
|
||||||
key : str, default="class"
|
key : str, default="class"
|
||||||
The key (alias) of the term associated with this tag, as
|
The key (alias) of the term associated with this tag. Defaults to
|
||||||
registered in the TermRegistry. Defaults to "class", implying
|
"class", implying it represents a classification target label by
|
||||||
it represents a classification target label by default.
|
default.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
value: str
|
value: str
|
||||||
key: str = GENERIC_CLASS_KEY
|
key: str = GENERIC_CLASS_KEY
|
||||||
|
|
||||||
|
|
||||||
def get_tag_from_info(
|
def get_tag_from_info(tag_info: TagInfo) -> data.Tag:
|
||||||
tag_info: TagInfo,
|
|
||||||
term_registry: Optional[TermRegistry] = None,
|
|
||||||
) -> data.Tag:
|
|
||||||
"""Creates a soundevent.data.Tag object from TagInfo data.
|
"""Creates a soundevent.data.Tag object from TagInfo data.
|
||||||
|
|
||||||
Looks up the term using the key in the provided `tag_info` from the
|
Looks up the term using the key in the provided `tag_info` and constructs a
|
||||||
specified registry and constructs a Tag object.
|
Tag object.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
tag_info : TagInfo
|
tag_info : TagInfo
|
||||||
The TagInfo object containing the value and term key.
|
The TagInfo object containing the value and term key.
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance to use for term lookup. Defaults to the
|
|
||||||
global `registry`.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -394,132 +128,11 @@ def get_tag_from_info(
|
|||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
KeyError
|
KeyError
|
||||||
If the term key specified in `tag_info.key` is not found
|
If the term key specified in `tag_info.key` is not found.
|
||||||
in the registry.
|
|
||||||
"""
|
"""
|
||||||
term_registry = term_registry or default_term_registry
|
term = terms.get_term(tag_info.key)
|
||||||
term = get_term_from_key(tag_info.key, term_registry=term_registry)
|
|
||||||
|
if not term:
|
||||||
|
raise KeyError(f"Key {tag_info.key} not found")
|
||||||
|
|
||||||
return data.Tag(term=term, value=tag_info.value)
|
return data.Tag(term=term, value=tag_info.value)
|
||||||
|
|
||||||
|
|
||||||
class TermInfo(BaseModel):
|
|
||||||
"""Represents the definition of a Term within a configuration file.
|
|
||||||
|
|
||||||
This model allows users to define custom terms directly in configuration
|
|
||||||
files (e.g., YAML) which can then be loaded into the TermRegistry.
|
|
||||||
It mirrors the parameters of `TermRegistry.add_custom_term`.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
key : str
|
|
||||||
The unique key (alias) that will be used to register and
|
|
||||||
reference this term.
|
|
||||||
label : str, optional
|
|
||||||
The optional display label for the term. Defaults to `key`
|
|
||||||
if not provided during registration.
|
|
||||||
name : str, optional
|
|
||||||
The optional formal name for the term. Defaults to `key`
|
|
||||||
if not provided during registration.
|
|
||||||
uri : str, optional
|
|
||||||
The optional URI identifying the term (e.g., from a standard
|
|
||||||
vocabulary).
|
|
||||||
definition : str, optional
|
|
||||||
The optional textual definition of the term. Defaults to
|
|
||||||
"Unknown" if not provided during registration.
|
|
||||||
"""
|
|
||||||
|
|
||||||
key: str
|
|
||||||
label: Optional[str] = None
|
|
||||||
name: Optional[str] = None
|
|
||||||
uri: Optional[str] = None
|
|
||||||
definition: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class TermConfig(BaseModel):
|
|
||||||
"""Pydantic schema for loading a list of term definitions from config.
|
|
||||||
|
|
||||||
This model typically corresponds to a section in a configuration file
|
|
||||||
(e.g., YAML) containing a list of term definitions to be registered.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
terms : list[TermInfo]
|
|
||||||
A list of TermInfo objects, each defining a term to be
|
|
||||||
registered. Defaults to an empty list.
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
Example YAML structure:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
terms:
|
|
||||||
- key: species
|
|
||||||
uri: dwc:scientificName
|
|
||||||
label: Scientific Name
|
|
||||||
- key: my_custom_term
|
|
||||||
name: My Custom Term
|
|
||||||
definition: Describes a specific project attribute.
|
|
||||||
# ... more TermInfo definitions
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
terms: List[TermInfo] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
def load_terms_from_config(
|
|
||||||
path: data.PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
) -> Dict[str, data.Term]:
|
|
||||||
"""Loads term definitions from a configuration file and registers them.
|
|
||||||
|
|
||||||
Parses a configuration file (e.g., YAML) using the TermConfig schema,
|
|
||||||
extracts the list of TermInfo definitions, and adds each one as a
|
|
||||||
custom term to the specified TermRegistry instance.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
path : data.PathLike
|
|
||||||
The path to the configuration file.
|
|
||||||
field : str, optional
|
|
||||||
Optional key indicating a specific section within the config
|
|
||||||
file where the 'terms' list is located. If None, expects the
|
|
||||||
list directly at the top level or within a structure matching
|
|
||||||
TermConfig schema.
|
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance to add the loaded terms to. Defaults to
|
|
||||||
the global `registry`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
dict[str, soundevent.data.Term]
|
|
||||||
A dictionary mapping the keys of the newly added terms to their
|
|
||||||
corresponding Term objects.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
FileNotFoundError
|
|
||||||
If the config file path does not exist.
|
|
||||||
pydantic.ValidationError
|
|
||||||
If the config file structure does not match the TermConfig schema.
|
|
||||||
KeyError
|
|
||||||
If a term key loaded from the config conflicts with a key
|
|
||||||
already present in the registry.
|
|
||||||
"""
|
|
||||||
data = load_config(path, schema=TermConfig, field=field)
|
|
||||||
return {
|
|
||||||
info.key: term_registry.add_custom_term(
|
|
||||||
info.key,
|
|
||||||
name=info.name,
|
|
||||||
uri=info.uri,
|
|
||||||
label=info.label,
|
|
||||||
definition=info.definition,
|
|
||||||
)
|
|
||||||
for info in data.terms
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def register_term(
|
|
||||||
key: str, term: data.Term, registry: TermRegistry = default_term_registry
|
|
||||||
) -> None:
|
|
||||||
registry.add_term(key, term)
|
|
||||||
|
|||||||
@ -12,15 +12,10 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data, terms
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.targets.terms import (
|
from batdetect2.targets.terms import TagInfo, get_tag_from_info
|
||||||
TagInfo,
|
|
||||||
TermRegistry,
|
|
||||||
get_tag_from_info,
|
|
||||||
get_term_from_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DerivationRegistry",
|
"DerivationRegistry",
|
||||||
@ -466,7 +461,6 @@ TranformationRule = Annotated[
|
|||||||
def build_transform_from_rule(
|
def build_transform_from_rule(
|
||||||
rule: TranformationRule,
|
rule: TranformationRule,
|
||||||
derivation_registry: Optional[DerivationRegistry] = None,
|
derivation_registry: Optional[DerivationRegistry] = None,
|
||||||
term_registry: Optional[TermRegistry] = None,
|
|
||||||
) -> SoundEventTransformation:
|
) -> SoundEventTransformation:
|
||||||
"""Build a specific SoundEventTransformation function from a rule config.
|
"""Build a specific SoundEventTransformation function from a rule config.
|
||||||
|
|
||||||
@ -497,29 +491,21 @@ def build_transform_from_rule(
|
|||||||
If dynamic import of a derivation function fails.
|
If dynamic import of a derivation function fails.
|
||||||
"""
|
"""
|
||||||
if rule.rule_type == "replace":
|
if rule.rule_type == "replace":
|
||||||
source = get_tag_from_info(
|
source = get_tag_from_info(rule.original)
|
||||||
rule.original,
|
target = get_tag_from_info(rule.replacement)
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
target = get_tag_from_info(
|
|
||||||
rule.replacement,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
return partial(replace_tag_transform, source=source, target=target)
|
return partial(replace_tag_transform, source=source, target=target)
|
||||||
|
|
||||||
if rule.rule_type == "derive_tag":
|
if rule.rule_type == "derive_tag":
|
||||||
source_term = get_term_from_key(
|
source_term = terms.get_term(rule.source_term_key)
|
||||||
rule.source_term_key,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
target_term = (
|
target_term = (
|
||||||
get_term_from_key(
|
terms.get_term(rule.target_term_key)
|
||||||
rule.target_term_key,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
if rule.target_term_key
|
if rule.target_term_key
|
||||||
else source_term
|
else source_term
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if source_term is None or target_term is None:
|
||||||
|
raise KeyError("Terms not found")
|
||||||
|
|
||||||
derivation = get_derivation(
|
derivation = get_derivation(
|
||||||
key=rule.derivation_function,
|
key=rule.derivation_function,
|
||||||
import_derivation=rule.import_derivation,
|
import_derivation=rule.import_derivation,
|
||||||
@ -534,18 +520,16 @@ def build_transform_from_rule(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if rule.rule_type == "map_value":
|
if rule.rule_type == "map_value":
|
||||||
source_term = get_term_from_key(
|
source_term = terms.get_term(rule.source_term_key)
|
||||||
rule.source_term_key,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
target_term = (
|
target_term = (
|
||||||
get_term_from_key(
|
terms.get_term(rule.target_term_key)
|
||||||
rule.target_term_key,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
if rule.target_term_key
|
if rule.target_term_key
|
||||||
else source_term
|
else source_term
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if source_term is None or target_term is None:
|
||||||
|
raise KeyError("Terms not found")
|
||||||
|
|
||||||
return partial(
|
return partial(
|
||||||
map_value_transform,
|
map_value_transform,
|
||||||
source_term=source_term,
|
source_term=source_term,
|
||||||
@ -555,6 +539,7 @@ def build_transform_from_rule(
|
|||||||
|
|
||||||
# Handle unknown rule type
|
# Handle unknown rule type
|
||||||
valid_options = ["replace", "derive_tag", "map_value"]
|
valid_options = ["replace", "derive_tag", "map_value"]
|
||||||
|
|
||||||
# Should be caught by Pydantic validation, but good practice
|
# Should be caught by Pydantic validation, but good practice
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid transform rule type '{getattr(rule, 'rule_type', 'N/A')}'. "
|
f"Invalid transform rule type '{getattr(rule, 'rule_type', 'N/A')}'. "
|
||||||
@ -565,7 +550,6 @@ def build_transform_from_rule(
|
|||||||
def build_transformation_from_config(
|
def build_transformation_from_config(
|
||||||
config: TransformConfig,
|
config: TransformConfig,
|
||||||
derivation_registry: Optional[DerivationRegistry] = None,
|
derivation_registry: Optional[DerivationRegistry] = None,
|
||||||
term_registry: Optional[TermRegistry] = None,
|
|
||||||
) -> SoundEventTransformation:
|
) -> SoundEventTransformation:
|
||||||
"""Build a composite transformation function from a TransformConfig.
|
"""Build a composite transformation function from a TransformConfig.
|
||||||
|
|
||||||
@ -591,7 +575,6 @@ def build_transformation_from_config(
|
|||||||
build_transform_from_rule(
|
build_transform_from_rule(
|
||||||
rule,
|
rule,
|
||||||
derivation_registry=derivation_registry,
|
derivation_registry=derivation_registry,
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
)
|
||||||
for rule in config.rules
|
for rule in config.rules
|
||||||
]
|
]
|
||||||
@ -640,7 +623,6 @@ def load_transformation_from_config(
|
|||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
derivation_registry: Optional[DerivationRegistry] = None,
|
derivation_registry: Optional[DerivationRegistry] = None,
|
||||||
term_registry: Optional[TermRegistry] = None,
|
|
||||||
) -> SoundEventTransformation:
|
) -> SoundEventTransformation:
|
||||||
"""Load transformation config from a file and build the final function.
|
"""Load transformation config from a file and build the final function.
|
||||||
|
|
||||||
@ -678,7 +660,6 @@ def load_transformation_from_config(
|
|||||||
return build_transformation_from_config(
|
return build_transformation_from_config(
|
||||||
config,
|
config,
|
||||||
derivation_registry=derivation_registry,
|
derivation_registry=derivation_registry,
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -31,7 +31,7 @@ class TrainingModule(L.LightningModule):
|
|||||||
self.save_hyperparameters(logger=False)
|
self.save_hyperparameters(logger=False)
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||||
return self.model(spec)
|
return self.model.detector(spec)
|
||||||
|
|
||||||
def training_step(self, batch: TrainExample):
|
def training_step(self, batch: TrainExample):
|
||||||
outputs = self.model.detector(batch.spec)
|
outputs = self.model.detector(batch.spec)
|
||||||
|
|||||||
@ -99,7 +99,7 @@ def train(
|
|||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
model,
|
model,
|
||||||
config,
|
config,
|
||||||
batches_per_epoch=len(train_dataloader),
|
t_max=config.train.t_max * len(train_dataloader),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Starting main training loop...")
|
logger.info("Starting main training loop...")
|
||||||
@ -113,15 +113,16 @@ def train(
|
|||||||
|
|
||||||
def build_training_module(
|
def build_training_module(
|
||||||
model: Model,
|
model: Model,
|
||||||
config: FullTrainingConfig,
|
config: Optional[FullTrainingConfig] = None,
|
||||||
batches_per_epoch: int,
|
t_max: int = 200,
|
||||||
) -> TrainingModule:
|
) -> TrainingModule:
|
||||||
|
config = config or FullTrainingConfig()
|
||||||
loss = build_loss(config=config.train.loss)
|
loss = build_loss(config=config.train.loss)
|
||||||
return TrainingModule(
|
return TrainingModule(
|
||||||
model=model,
|
model=model,
|
||||||
loss=loss,
|
loss=loss,
|
||||||
learning_rate=config.train.learning_rate,
|
learning_rate=config.train.learning_rate,
|
||||||
t_max=config.train.t_max * batches_per_epoch,
|
t_max=t_max,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -15,7 +15,6 @@ from batdetect2.preprocess import build_preprocessor
|
|||||||
from batdetect2.preprocess.audio import build_audio_loader
|
from batdetect2.preprocess.audio import build_audio_loader
|
||||||
from batdetect2.targets import (
|
from batdetect2.targets import (
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
TermRegistry,
|
|
||||||
build_targets,
|
build_targets,
|
||||||
call_type,
|
call_type,
|
||||||
)
|
)
|
||||||
@ -355,18 +354,6 @@ def create_annotation_project():
|
|||||||
return factory
|
return factory
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_term_registry() -> TermRegistry:
|
|
||||||
"""Fixture for a sample TermRegistry."""
|
|
||||||
registry = TermRegistry()
|
|
||||||
registry.add_custom_term("class")
|
|
||||||
registry.add_custom_term("order")
|
|
||||||
registry.add_custom_term("species")
|
|
||||||
registry.add_custom_term("call_type")
|
|
||||||
registry.add_custom_term("quality")
|
|
||||||
return registry
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_preprocessor() -> PreprocessorProtocol:
|
def sample_preprocessor() -> PreprocessorProtocol:
|
||||||
return build_preprocessor()
|
return build_preprocessor()
|
||||||
@ -399,7 +386,6 @@ def pippip_tag() -> TagInfo:
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_target_config(
|
def sample_target_config(
|
||||||
sample_term_registry: TermRegistry,
|
|
||||||
bat_tag: TagInfo,
|
bat_tag: TagInfo,
|
||||||
noise_tag: TagInfo,
|
noise_tag: TagInfo,
|
||||||
myomyo_tag: TagInfo,
|
myomyo_tag: TagInfo,
|
||||||
@ -422,12 +408,8 @@ def sample_target_config(
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_targets(
|
def sample_targets(
|
||||||
sample_target_config: TargetConfig,
|
sample_target_config: TargetConfig,
|
||||||
sample_term_registry: TermRegistry,
|
|
||||||
) -> TargetProtocol:
|
) -> TargetProtocol:
|
||||||
return build_targets(
|
return build_targets(sample_target_config)
|
||||||
sample_target_config,
|
|
||||||
term_registry=sample_term_registry,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -443,10 +425,8 @@ def sample_labeller(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_clipper(
|
def sample_clipper() -> ClipperProtocol:
|
||||||
sample_preprocessor: PreprocessorProtocol,
|
return build_clipper()
|
||||||
) -> ClipperProtocol:
|
|
||||||
return build_clipper(preprocessor=sample_preprocessor)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@ -1,16 +1,14 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from soundevent import data
|
from soundevent import data, terms
|
||||||
|
|
||||||
from batdetect2.targets import build_targets, load_target_config
|
from batdetect2.targets import build_targets, load_target_config
|
||||||
from batdetect2.targets.terms import get_term_from_key
|
|
||||||
|
|
||||||
|
|
||||||
def test_can_override_default_roi_mapper_per_class(
|
def test_can_override_default_roi_mapper_per_class(
|
||||||
create_temp_yaml: Callable[..., Path],
|
create_temp_yaml: Callable[..., Path],
|
||||||
recording: data.Recording,
|
recording: data.Recording,
|
||||||
sample_term_registry,
|
|
||||||
):
|
):
|
||||||
yaml_content = """
|
yaml_content = """
|
||||||
roi:
|
roi:
|
||||||
@ -36,11 +34,13 @@ def test_can_override_default_roi_mapper_per_class(
|
|||||||
config_path = create_temp_yaml(yaml_content)
|
config_path = create_temp_yaml(yaml_content)
|
||||||
|
|
||||||
config = load_target_config(config_path)
|
config = load_target_config(config_path)
|
||||||
targets = build_targets(config, term_registry=sample_term_registry)
|
targets = build_targets(config)
|
||||||
|
|
||||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||||
|
|
||||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
species = terms.get_term("species")
|
||||||
|
assert species is not None
|
||||||
|
|
||||||
se1 = data.SoundEventAnnotation(
|
se1 = data.SoundEventAnnotation(
|
||||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||||
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||||
@ -62,7 +62,6 @@ def test_can_override_default_roi_mapper_per_class(
|
|||||||
# TODO: rename this test function
|
# TODO: rename this test function
|
||||||
def test_roi_is_recovered_roundtrip_even_with_overriders(
|
def test_roi_is_recovered_roundtrip_even_with_overriders(
|
||||||
create_temp_yaml,
|
create_temp_yaml,
|
||||||
sample_term_registry,
|
|
||||||
recording,
|
recording,
|
||||||
):
|
):
|
||||||
yaml_content = """
|
yaml_content = """
|
||||||
@ -89,11 +88,12 @@ def test_roi_is_recovered_roundtrip_even_with_overriders(
|
|||||||
config_path = create_temp_yaml(yaml_content)
|
config_path = create_temp_yaml(yaml_content)
|
||||||
|
|
||||||
config = load_target_config(config_path)
|
config = load_target_config(config_path)
|
||||||
targets = build_targets(config, term_registry=sample_term_registry)
|
targets = build_targets(config)
|
||||||
|
|
||||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||||
|
|
||||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
species = terms.get_term("species")
|
||||||
|
assert species is not None
|
||||||
se1 = data.SoundEventAnnotation(
|
se1 = data.SoundEventAnnotation(
|
||||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||||
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||||
|
|||||||
@ -1,98 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import yaml
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.targets import terms
|
from batdetect2.targets import terms
|
||||||
from batdetect2.targets.terms import (
|
from batdetect2.targets.terms import TagInfo
|
||||||
TagInfo,
|
|
||||||
TermRegistry,
|
|
||||||
load_terms_from_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_term_registry_initialization():
|
|
||||||
registry = TermRegistry()
|
|
||||||
assert registry._terms == {}
|
|
||||||
|
|
||||||
initial_terms = {
|
|
||||||
"test_term": data.Term(name="test", label="Test", definition="test")
|
|
||||||
}
|
|
||||||
registry = TermRegistry(terms=initial_terms)
|
|
||||||
assert registry._terms == initial_terms
|
|
||||||
|
|
||||||
|
|
||||||
def test_term_registry_add_term():
|
|
||||||
registry = TermRegistry()
|
|
||||||
term = data.Term(name="test", label="Test", definition="test")
|
|
||||||
registry.add_term("test_key", term)
|
|
||||||
assert registry._terms["test_key"] == term
|
|
||||||
|
|
||||||
|
|
||||||
def test_term_registry_get_term():
|
|
||||||
registry = TermRegistry()
|
|
||||||
term = data.Term(name="test", label="Test", definition="test")
|
|
||||||
registry.add_term("test_key", term)
|
|
||||||
retrieved_term = registry.get_term("test_key")
|
|
||||||
assert retrieved_term == term
|
|
||||||
|
|
||||||
|
|
||||||
def test_term_registry_add_custom_term():
|
|
||||||
registry = TermRegistry()
|
|
||||||
term = registry.add_custom_term(
|
|
||||||
"custom_key", name="custom", label="Custom", definition="A custom term"
|
|
||||||
)
|
|
||||||
assert registry._terms["custom_key"] == term
|
|
||||||
assert term.name == "custom"
|
|
||||||
assert term.label == "Custom"
|
|
||||||
assert term.definition == "A custom term"
|
|
||||||
|
|
||||||
|
|
||||||
def test_term_registry_add_duplicate_term():
|
|
||||||
registry = TermRegistry()
|
|
||||||
term = data.Term(name="test", label="Test", definition="test")
|
|
||||||
registry.add_term("test_key", term)
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
registry.add_term("test_key", term)
|
|
||||||
|
|
||||||
|
|
||||||
def test_term_registry_get_term_not_found():
|
|
||||||
registry = TermRegistry()
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
registry.get_term("non_existent_key")
|
|
||||||
|
|
||||||
|
|
||||||
def test_term_registry_get_keys():
|
|
||||||
registry = TermRegistry()
|
|
||||||
term1 = data.Term(name="test1", label="Test1", definition="test")
|
|
||||||
term2 = data.Term(name="test2", label="Test2", definition="test")
|
|
||||||
registry.add_term("key1", term1)
|
|
||||||
registry.add_term("key2", term2)
|
|
||||||
keys = registry.get_keys()
|
|
||||||
assert set(keys) == {"key1", "key2"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_term_from_key():
|
|
||||||
term = terms.get_term_from_key("event")
|
|
||||||
assert term == terms.call_type
|
|
||||||
|
|
||||||
custom_registry = TermRegistry()
|
|
||||||
custom_term = data.Term(name="custom", label="Custom", definition="test")
|
|
||||||
custom_registry.add_term("custom_key", custom_term)
|
|
||||||
term = terms.get_term_from_key("custom_key", term_registry=custom_registry)
|
|
||||||
assert term == custom_term
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_term_keys():
|
|
||||||
keys = terms.get_term_keys()
|
|
||||||
assert "event" in keys
|
|
||||||
assert "individual" in keys
|
|
||||||
assert terms.GENERIC_CLASS_KEY in keys
|
|
||||||
|
|
||||||
custom_registry = TermRegistry()
|
|
||||||
custom_term = data.Term(name="custom", label="Custom", definition="test")
|
|
||||||
custom_registry.add_term("custom_key", custom_term)
|
|
||||||
keys = terms.get_term_keys(term_registry=custom_registry)
|
|
||||||
assert "custom_key" in keys
|
|
||||||
|
|
||||||
|
|
||||||
def test_tag_info_and_get_tag_from_info():
|
def test_tag_info_and_get_tag_from_info():
|
||||||
@ -106,74 +15,3 @@ def test_get_tag_from_info_key_not_found():
|
|||||||
tag_info = TagInfo(value="test", key="non_existent_key")
|
tag_info = TagInfo(value="test", key="non_existent_key")
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
terms.get_tag_from_info(tag_info)
|
terms.get_tag_from_info(tag_info)
|
||||||
|
|
||||||
|
|
||||||
def test_load_terms_from_config(tmp_path):
|
|
||||||
term_registry = TermRegistry()
|
|
||||||
config_data = {
|
|
||||||
"terms": [
|
|
||||||
{
|
|
||||||
"key": "species",
|
|
||||||
"name": "dwc:scientificName",
|
|
||||||
"label": "Scientific Name",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "my_custom_term",
|
|
||||||
"name": "soundevent:custom_term",
|
|
||||||
"definition": "Describes a specific project attribute",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
config_file = tmp_path / "config.yaml"
|
|
||||||
with open(config_file, "w") as f:
|
|
||||||
yaml.dump(config_data, f)
|
|
||||||
|
|
||||||
loaded_terms = load_terms_from_config(
|
|
||||||
config_file,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
assert "species" in loaded_terms
|
|
||||||
assert "my_custom_term" in loaded_terms
|
|
||||||
assert loaded_terms["species"].name == "dwc:scientificName"
|
|
||||||
assert loaded_terms["my_custom_term"].name == "soundevent:custom_term"
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_terms_from_config_file_not_found():
|
|
||||||
with pytest.raises(FileNotFoundError):
|
|
||||||
load_terms_from_config("non_existent_file.yaml")
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_terms_from_config_validation_error(tmp_path):
|
|
||||||
config_data = {
|
|
||||||
"terms": [
|
|
||||||
{
|
|
||||||
"key": "species",
|
|
||||||
"uri": "dwc:scientificName",
|
|
||||||
"label": 123,
|
|
||||||
}, # Invalid label type
|
|
||||||
]
|
|
||||||
}
|
|
||||||
config_file = tmp_path / "config.yaml"
|
|
||||||
with open(config_file, "w") as f:
|
|
||||||
yaml.dump(config_data, f)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
load_terms_from_config(config_file)
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_terms_from_config_key_already_exists(tmp_path):
|
|
||||||
config_data = {
|
|
||||||
"terms": [
|
|
||||||
{
|
|
||||||
"key": "event",
|
|
||||||
"uri": "dwc:scientificName",
|
|
||||||
"label": "Scientific Name",
|
|
||||||
}, # Duplicate key
|
|
||||||
]
|
|
||||||
}
|
|
||||||
config_file = tmp_path / "config.yaml"
|
|
||||||
with open(config_file, "w") as f:
|
|
||||||
yaml.dump(config_data, f)
|
|
||||||
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
load_terms_from_config(config_file)
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from soundevent import data
|
from soundevent import data, terms
|
||||||
|
|
||||||
from batdetect2.targets import (
|
from batdetect2.targets import (
|
||||||
DeriveTagRule,
|
DeriveTagRule,
|
||||||
@ -11,31 +11,36 @@ from batdetect2.targets import (
|
|||||||
TransformConfig,
|
TransformConfig,
|
||||||
build_transformation_from_config,
|
build_transformation_from_config,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.terms import TermRegistry
|
|
||||||
from batdetect2.targets.transform import (
|
from batdetect2.targets.transform import (
|
||||||
DerivationRegistry,
|
DerivationRegistry,
|
||||||
build_transform_from_rule,
|
build_transform_from_rule,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def term_registry():
|
|
||||||
return TermRegistry()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def derivation_registry():
|
def derivation_registry():
|
||||||
return DerivationRegistry()
|
return DerivationRegistry()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def term1(term_registry: TermRegistry) -> data.Term:
|
def term1() -> data.Term:
|
||||||
return term_registry.add_custom_term(key="term1")
|
term = data.Term(label="Term 1", definition="unknown", name="test:term1")
|
||||||
|
terms.add_term(term, key="term1", force=True)
|
||||||
|
return term
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def term2(term_registry: TermRegistry) -> data.Term:
|
def term2() -> data.Term:
|
||||||
return term_registry.add_custom_term(key="term2")
|
term = data.Term(label="Term 2", definition="unknown", name="test:term2")
|
||||||
|
terms.add_term(term, key="term2", force=True)
|
||||||
|
return term
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def term3() -> data.Term:
|
||||||
|
term = data.Term(label="Term 3", definition="unknown", name="test:term3")
|
||||||
|
terms.add_term(term, key="term3", force=True)
|
||||||
|
return term
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -47,46 +52,45 @@ def annotation(
|
|||||||
sound_event=sound_event, tags=[data.Tag(term=term1, value="value1")]
|
sound_event=sound_event, tags=[data.Tag(term=term1, value="value1")]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def annotation2(
|
||||||
|
sound_event: data.SoundEvent,
|
||||||
|
term2: data.Term,
|
||||||
|
) -> data.SoundEventAnnotation:
|
||||||
|
return data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event, tags=[data.Tag(term=term2, value="value2")]
|
||||||
|
)
|
||||||
|
|
||||||
def test_map_value_rule(
|
|
||||||
annotation: data.SoundEventAnnotation,
|
def test_map_value_rule(annotation: data.SoundEventAnnotation):
|
||||||
term_registry: TermRegistry,
|
|
||||||
):
|
|
||||||
rule = MapValueRule(
|
rule = MapValueRule(
|
||||||
rule_type="map_value",
|
rule_type="map_value",
|
||||||
source_term_key="term1",
|
source_term_key="term1",
|
||||||
value_mapping={"value1": "value2"},
|
value_mapping={"value1": "value2"},
|
||||||
)
|
)
|
||||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
transform_fn = build_transform_from_rule(rule)
|
||||||
transformed_annotation = transform_fn(annotation)
|
transformed_annotation = transform_fn(annotation)
|
||||||
assert transformed_annotation.tags[0].value == "value2"
|
assert transformed_annotation.tags[0].value == "value2"
|
||||||
|
|
||||||
|
|
||||||
def test_map_value_rule_no_match(
|
def test_map_value_rule_no_match(annotation: data.SoundEventAnnotation):
|
||||||
annotation: data.SoundEventAnnotation,
|
|
||||||
term_registry: TermRegistry,
|
|
||||||
):
|
|
||||||
rule = MapValueRule(
|
rule = MapValueRule(
|
||||||
rule_type="map_value",
|
rule_type="map_value",
|
||||||
source_term_key="term1",
|
source_term_key="term1",
|
||||||
value_mapping={"other_value": "value2"},
|
value_mapping={"other_value": "value2"},
|
||||||
)
|
)
|
||||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
transform_fn = build_transform_from_rule(rule)
|
||||||
transformed_annotation = transform_fn(annotation)
|
transformed_annotation = transform_fn(annotation)
|
||||||
assert transformed_annotation.tags[0].value == "value1"
|
assert transformed_annotation.tags[0].value == "value1"
|
||||||
|
|
||||||
|
|
||||||
def test_replace_rule(
|
def test_replace_rule(annotation: data.SoundEventAnnotation, term2: data.Term):
|
||||||
annotation: data.SoundEventAnnotation,
|
|
||||||
term2: data.Term,
|
|
||||||
term_registry: TermRegistry,
|
|
||||||
):
|
|
||||||
rule = ReplaceRule(
|
rule = ReplaceRule(
|
||||||
rule_type="replace",
|
rule_type="replace",
|
||||||
original=TagInfo(key="term1", value="value1"),
|
original=TagInfo(key="term1", value="value1"),
|
||||||
replacement=TagInfo(key="term2", value="value2"),
|
replacement=TagInfo(key="term2", value="value2"),
|
||||||
)
|
)
|
||||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
transform_fn = build_transform_from_rule(rule)
|
||||||
transformed_annotation = transform_fn(annotation)
|
transformed_annotation = transform_fn(annotation)
|
||||||
assert transformed_annotation.tags[0].term == term2
|
assert transformed_annotation.tags[0].term == term2
|
||||||
assert transformed_annotation.tags[0].value == "value2"
|
assert transformed_annotation.tags[0].value == "value2"
|
||||||
@ -94,7 +98,7 @@ def test_replace_rule(
|
|||||||
|
|
||||||
def test_replace_rule_no_match(
|
def test_replace_rule_no_match(
|
||||||
annotation: data.SoundEventAnnotation,
|
annotation: data.SoundEventAnnotation,
|
||||||
term_registry: TermRegistry,
|
term1: data.Term,
|
||||||
term2: data.Term,
|
term2: data.Term,
|
||||||
):
|
):
|
||||||
rule = ReplaceRule(
|
rule = ReplaceRule(
|
||||||
@ -102,16 +106,19 @@ def test_replace_rule_no_match(
|
|||||||
original=TagInfo(key="term1", value="wrong_value"),
|
original=TagInfo(key="term1", value="wrong_value"),
|
||||||
replacement=TagInfo(key="term2", value="value2"),
|
replacement=TagInfo(key="term2", value="value2"),
|
||||||
)
|
)
|
||||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
transform_fn = build_transform_from_rule(rule)
|
||||||
transformed_annotation = transform_fn(annotation)
|
transformed_annotation = transform_fn(annotation)
|
||||||
assert transformed_annotation.tags[0].key == "term1"
|
assert transformed_annotation.tags[0].term == term1
|
||||||
assert transformed_annotation.tags[0].term != term2
|
assert transformed_annotation.tags[0].term != term2
|
||||||
assert transformed_annotation.tags[0].value == "value1"
|
assert transformed_annotation.tags[0].value == "value1"
|
||||||
|
|
||||||
|
|
||||||
def test_build_transformation_from_config(
|
def test_build_transformation_from_config(
|
||||||
annotation: data.SoundEventAnnotation,
|
annotation: data.SoundEventAnnotation,
|
||||||
term_registry: TermRegistry,
|
annotation2: data.SoundEventAnnotation,
|
||||||
|
term1: data.Term,
|
||||||
|
term2: data.Term,
|
||||||
|
term3: data.Term,
|
||||||
):
|
):
|
||||||
config = TransformConfig(
|
config = TransformConfig(
|
||||||
rules=[
|
rules=[
|
||||||
@ -127,20 +134,20 @@ def test_build_transformation_from_config(
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
term_registry.add_custom_term("term2")
|
transform = build_transformation_from_config(config)
|
||||||
term_registry.add_custom_term("term3")
|
|
||||||
transform = build_transformation_from_config(
|
|
||||||
config,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
transformed_annotation = transform(annotation)
|
transformed_annotation = transform(annotation)
|
||||||
assert transformed_annotation.tags[0].key == "term1"
|
assert transformed_annotation.tags[0].term == term1
|
||||||
|
assert transformed_annotation.tags[0].term != term2
|
||||||
assert transformed_annotation.tags[0].value == "value2"
|
assert transformed_annotation.tags[0].value == "value2"
|
||||||
|
|
||||||
|
transformed_annotation = transform(annotation2)
|
||||||
|
assert transformed_annotation.tags[0].term == term3
|
||||||
|
assert transformed_annotation.tags[0].value == "value3"
|
||||||
|
|
||||||
|
|
||||||
def test_derive_tag_rule(
|
def test_derive_tag_rule(
|
||||||
annotation: data.SoundEventAnnotation,
|
annotation: data.SoundEventAnnotation,
|
||||||
term_registry: TermRegistry,
|
|
||||||
derivation_registry: DerivationRegistry,
|
derivation_registry: DerivationRegistry,
|
||||||
term1: data.Term,
|
term1: data.Term,
|
||||||
):
|
):
|
||||||
@ -156,7 +163,6 @@ def test_derive_tag_rule(
|
|||||||
)
|
)
|
||||||
transform_fn = build_transform_from_rule(
|
transform_fn = build_transform_from_rule(
|
||||||
rule,
|
rule,
|
||||||
term_registry=term_registry,
|
|
||||||
derivation_registry=derivation_registry,
|
derivation_registry=derivation_registry,
|
||||||
)
|
)
|
||||||
transformed_annotation = transform_fn(annotation)
|
transformed_annotation = transform_fn(annotation)
|
||||||
@ -170,7 +176,6 @@ def test_derive_tag_rule(
|
|||||||
|
|
||||||
def test_derive_tag_rule_keep_source_false(
|
def test_derive_tag_rule_keep_source_false(
|
||||||
annotation: data.SoundEventAnnotation,
|
annotation: data.SoundEventAnnotation,
|
||||||
term_registry: TermRegistry,
|
|
||||||
derivation_registry: DerivationRegistry,
|
derivation_registry: DerivationRegistry,
|
||||||
term1: data.Term,
|
term1: data.Term,
|
||||||
):
|
):
|
||||||
@ -187,7 +192,6 @@ def test_derive_tag_rule_keep_source_false(
|
|||||||
)
|
)
|
||||||
transform_fn = build_transform_from_rule(
|
transform_fn = build_transform_from_rule(
|
||||||
rule,
|
rule,
|
||||||
term_registry=term_registry,
|
|
||||||
derivation_registry=derivation_registry,
|
derivation_registry=derivation_registry,
|
||||||
)
|
)
|
||||||
transformed_annotation = transform_fn(annotation)
|
transformed_annotation = transform_fn(annotation)
|
||||||
@ -199,7 +203,6 @@ def test_derive_tag_rule_keep_source_false(
|
|||||||
|
|
||||||
def test_derive_tag_rule_target_term(
|
def test_derive_tag_rule_target_term(
|
||||||
annotation: data.SoundEventAnnotation,
|
annotation: data.SoundEventAnnotation,
|
||||||
term_registry: TermRegistry,
|
|
||||||
derivation_registry: DerivationRegistry,
|
derivation_registry: DerivationRegistry,
|
||||||
term1: data.Term,
|
term1: data.Term,
|
||||||
term2: data.Term,
|
term2: data.Term,
|
||||||
@ -217,7 +220,6 @@ def test_derive_tag_rule_target_term(
|
|||||||
)
|
)
|
||||||
transform_fn = build_transform_from_rule(
|
transform_fn = build_transform_from_rule(
|
||||||
rule,
|
rule,
|
||||||
term_registry=term_registry,
|
|
||||||
derivation_registry=derivation_registry,
|
derivation_registry=derivation_registry,
|
||||||
)
|
)
|
||||||
transformed_annotation = transform_fn(annotation)
|
transformed_annotation = transform_fn(annotation)
|
||||||
@ -231,7 +233,6 @@ def test_derive_tag_rule_target_term(
|
|||||||
|
|
||||||
def test_derive_tag_rule_import_derivation(
|
def test_derive_tag_rule_import_derivation(
|
||||||
annotation: data.SoundEventAnnotation,
|
annotation: data.SoundEventAnnotation,
|
||||||
term_registry: TermRegistry,
|
|
||||||
term1: data.Term,
|
term1: data.Term,
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
):
|
):
|
||||||
@ -256,7 +257,7 @@ def my_imported_derivation(x: str) -> str:
|
|||||||
derivation_function="temp_derivation.my_imported_derivation",
|
derivation_function="temp_derivation.my_imported_derivation",
|
||||||
import_derivation=True,
|
import_derivation=True,
|
||||||
)
|
)
|
||||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
transform_fn = build_transform_from_rule(rule)
|
||||||
transformed_annotation = transform_fn(annotation)
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
|
||||||
assert len(transformed_annotation.tags) == 2
|
assert len(transformed_annotation.tags) == 2
|
||||||
@ -269,14 +270,14 @@ def my_imported_derivation(x: str) -> str:
|
|||||||
sys.path.remove(str(tmp_path))
|
sys.path.remove(str(tmp_path))
|
||||||
|
|
||||||
|
|
||||||
def test_derive_tag_rule_invalid_derivation(term_registry: TermRegistry):
|
def test_derive_tag_rule_invalid_derivation():
|
||||||
rule = DeriveTagRule(
|
rule = DeriveTagRule(
|
||||||
rule_type="derive_tag",
|
rule_type="derive_tag",
|
||||||
source_term_key="term1",
|
source_term_key="term1",
|
||||||
derivation_function="nonexistent_derivation",
|
derivation_function="nonexistent_derivation",
|
||||||
)
|
)
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
build_transform_from_rule(rule, term_registry=term_registry)
|
build_transform_from_rule(rule)
|
||||||
|
|
||||||
|
|
||||||
def test_build_transform_from_rule_invalid_rule_type():
|
def test_build_transform_from_rule_invalid_rule_type():
|
||||||
@ -291,7 +292,6 @@ def test_build_transform_from_rule_invalid_rule_type():
|
|||||||
|
|
||||||
def test_map_value_rule_target_term(
|
def test_map_value_rule_target_term(
|
||||||
annotation: data.SoundEventAnnotation,
|
annotation: data.SoundEventAnnotation,
|
||||||
term_registry: TermRegistry,
|
|
||||||
term2: data.Term,
|
term2: data.Term,
|
||||||
):
|
):
|
||||||
rule = MapValueRule(
|
rule = MapValueRule(
|
||||||
@ -300,7 +300,7 @@ def test_map_value_rule_target_term(
|
|||||||
value_mapping={"value1": "value2"},
|
value_mapping={"value1": "value2"},
|
||||||
target_term_key="term2",
|
target_term_key="term2",
|
||||||
)
|
)
|
||||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
transform_fn = build_transform_from_rule(rule)
|
||||||
transformed_annotation = transform_fn(annotation)
|
transformed_annotation = transform_fn(annotation)
|
||||||
assert transformed_annotation.tags[0].term == term2
|
assert transformed_annotation.tags[0].term == term2
|
||||||
assert transformed_annotation.tags[0].value == "value2"
|
assert transformed_annotation.tags[0].value == "value2"
|
||||||
@ -308,7 +308,6 @@ def test_map_value_rule_target_term(
|
|||||||
|
|
||||||
def test_map_value_rule_target_term_none(
|
def test_map_value_rule_target_term_none(
|
||||||
annotation: data.SoundEventAnnotation,
|
annotation: data.SoundEventAnnotation,
|
||||||
term_registry: TermRegistry,
|
|
||||||
term1: data.Term,
|
term1: data.Term,
|
||||||
):
|
):
|
||||||
rule = MapValueRule(
|
rule = MapValueRule(
|
||||||
@ -317,7 +316,7 @@ def test_map_value_rule_target_term_none(
|
|||||||
value_mapping={"value1": "value2"},
|
value_mapping={"value1": "value2"},
|
||||||
target_term_key=None,
|
target_term_key=None,
|
||||||
)
|
)
|
||||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
transform_fn = build_transform_from_rule(rule)
|
||||||
transformed_annotation = transform_fn(annotation)
|
transformed_annotation = transform_fn(annotation)
|
||||||
assert transformed_annotation.tags[0].term == term1
|
assert transformed_annotation.tags[0].term == term1
|
||||||
assert transformed_annotation.tags[0].value == "value2"
|
assert transformed_annotation.tags[0].value == "value2"
|
||||||
@ -325,7 +324,6 @@ def test_map_value_rule_target_term_none(
|
|||||||
|
|
||||||
def test_derive_tag_rule_target_term_none(
|
def test_derive_tag_rule_target_term_none(
|
||||||
annotation: data.SoundEventAnnotation,
|
annotation: data.SoundEventAnnotation,
|
||||||
term_registry: TermRegistry,
|
|
||||||
derivation_registry: DerivationRegistry,
|
derivation_registry: DerivationRegistry,
|
||||||
term1: data.Term,
|
term1: data.Term,
|
||||||
):
|
):
|
||||||
@ -342,7 +340,6 @@ def test_derive_tag_rule_target_term_none(
|
|||||||
)
|
)
|
||||||
transform_fn = build_transform_from_rule(
|
transform_fn = build_transform_from_rule(
|
||||||
rule,
|
rule,
|
||||||
term_registry=term_registry,
|
|
||||||
derivation_registry=derivation_registry,
|
derivation_registry=derivation_registry,
|
||||||
)
|
)
|
||||||
transformed_annotation = transform_fn(annotation)
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
|||||||
@ -1,170 +0,0 @@
|
|||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.train.augmentations import (
|
|
||||||
add_echo,
|
|
||||||
mix_audio,
|
|
||||||
)
|
|
||||||
from batdetect2.train.clips import select_subclip
|
|
||||||
from batdetect2.train.preprocess import generate_train_example
|
|
||||||
from batdetect2.typing import AudioLoader, ClipLabeller, PreprocessorProtocol
|
|
||||||
|
|
||||||
|
|
||||||
def test_mix_examples(
|
|
||||||
sample_preprocessor: PreprocessorProtocol,
|
|
||||||
sample_audio_loader: AudioLoader,
|
|
||||||
sample_labeller: ClipLabeller,
|
|
||||||
create_recording: Callable[..., data.Recording],
|
|
||||||
):
|
|
||||||
recording1 = create_recording()
|
|
||||||
recording2 = create_recording()
|
|
||||||
|
|
||||||
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
|
||||||
clip2 = data.Clip(recording=recording2, start_time=0.3, end_time=0.8)
|
|
||||||
|
|
||||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
|
||||||
clip_annotation_2 = data.ClipAnnotation(clip=clip2)
|
|
||||||
|
|
||||||
example1 = generate_train_example(
|
|
||||||
clip_annotation_1,
|
|
||||||
audio_loader=sample_audio_loader,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
labeller=sample_labeller,
|
|
||||||
)
|
|
||||||
example2 = generate_train_example(
|
|
||||||
clip_annotation_2,
|
|
||||||
audio_loader=sample_audio_loader,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
labeller=sample_labeller,
|
|
||||||
)
|
|
||||||
|
|
||||||
mixed = mix_audio(
|
|
||||||
example1,
|
|
||||||
example2,
|
|
||||||
weight=0.3,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert mixed.spectrogram.shape == example1.spectrogram.shape
|
|
||||||
assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
|
|
||||||
assert mixed.size_heatmap.shape == example1.size_heatmap.shape
|
|
||||||
assert mixed.class_heatmap.shape == example1.class_heatmap.shape
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
|
|
||||||
@pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
|
|
||||||
def test_mix_examples_of_different_durations(
|
|
||||||
sample_preprocessor: PreprocessorProtocol,
|
|
||||||
sample_audio_loader: AudioLoader,
|
|
||||||
sample_labeller: ClipLabeller,
|
|
||||||
create_recording: Callable[..., data.Recording],
|
|
||||||
duration1: float,
|
|
||||||
duration2: float,
|
|
||||||
):
|
|
||||||
recording1 = create_recording()
|
|
||||||
recording2 = create_recording()
|
|
||||||
|
|
||||||
clip1 = data.Clip(recording=recording1, start_time=0, end_time=duration1)
|
|
||||||
clip2 = data.Clip(recording=recording2, start_time=0, end_time=duration2)
|
|
||||||
|
|
||||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
|
||||||
clip_annotation_2 = data.ClipAnnotation(clip=clip2)
|
|
||||||
|
|
||||||
example1 = generate_train_example(
|
|
||||||
clip_annotation_1,
|
|
||||||
audio_loader=sample_audio_loader,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
labeller=sample_labeller,
|
|
||||||
)
|
|
||||||
example2 = generate_train_example(
|
|
||||||
clip_annotation_2,
|
|
||||||
audio_loader=sample_audio_loader,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
labeller=sample_labeller,
|
|
||||||
)
|
|
||||||
|
|
||||||
mixed = mix_audio(
|
|
||||||
example1,
|
|
||||||
example2,
|
|
||||||
weight=0.3,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert mixed.spectrogram.shape == example1.spectrogram.shape
|
|
||||||
assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
|
|
||||||
assert mixed.size_heatmap.shape == example1.size_heatmap.shape
|
|
||||||
assert mixed.class_heatmap.shape == example1.class_heatmap.shape
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_echo(
|
|
||||||
sample_preprocessor: PreprocessorProtocol,
|
|
||||||
sample_audio_loader: AudioLoader,
|
|
||||||
sample_labeller: ClipLabeller,
|
|
||||||
create_recording: Callable[..., data.Recording],
|
|
||||||
):
|
|
||||||
recording1 = create_recording()
|
|
||||||
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
|
||||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
|
||||||
|
|
||||||
original = generate_train_example(
|
|
||||||
clip_annotation_1,
|
|
||||||
audio_loader=sample_audio_loader,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
labeller=sample_labeller,
|
|
||||||
)
|
|
||||||
with_echo = add_echo(
|
|
||||||
original,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
delay=0.1,
|
|
||||||
weight=0.3,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert with_echo.spectrogram.shape == original.spectrogram.shape
|
|
||||||
torch.testing.assert_close(
|
|
||||||
with_echo.size_heatmap,
|
|
||||||
original.size_heatmap,
|
|
||||||
atol=0,
|
|
||||||
rtol=0,
|
|
||||||
)
|
|
||||||
torch.testing.assert_close(
|
|
||||||
with_echo.class_heatmap,
|
|
||||||
original.class_heatmap,
|
|
||||||
atol=0,
|
|
||||||
rtol=0,
|
|
||||||
)
|
|
||||||
torch.testing.assert_close(
|
|
||||||
with_echo.detection_heatmap,
|
|
||||||
original.detection_heatmap,
|
|
||||||
atol=0,
|
|
||||||
rtol=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_selected_random_subclip_has_the_correct_width(
|
|
||||||
sample_preprocessor: PreprocessorProtocol,
|
|
||||||
sample_audio_loader: AudioLoader,
|
|
||||||
sample_labeller: ClipLabeller,
|
|
||||||
create_recording: Callable[..., data.Recording],
|
|
||||||
):
|
|
||||||
recording1 = create_recording()
|
|
||||||
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
|
||||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
|
||||||
|
|
||||||
original = generate_train_example(
|
|
||||||
clip_annotation_1,
|
|
||||||
audio_loader=sample_audio_loader,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
labeller=sample_labeller,
|
|
||||||
)
|
|
||||||
|
|
||||||
subclip = select_subclip(
|
|
||||||
original,
|
|
||||||
input_samplerate=256_000,
|
|
||||||
output_samplerate=1000,
|
|
||||||
start=0,
|
|
||||||
duration=0.512,
|
|
||||||
)
|
|
||||||
assert subclip.spectrogram.shape[1] == 512
|
|
||||||
@ -1,27 +0,0 @@
|
|||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.train import generate_train_example
|
|
||||||
from batdetect2.typing import (
|
|
||||||
AudioLoader,
|
|
||||||
ClipLabeller,
|
|
||||||
ClipperProtocol,
|
|
||||||
PreprocessorProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_default_clip_size_is_correct(
|
|
||||||
sample_clipper: ClipperProtocol,
|
|
||||||
sample_labeller: ClipLabeller,
|
|
||||||
sample_audio_loader: AudioLoader,
|
|
||||||
clip_annotation: data.ClipAnnotation,
|
|
||||||
sample_preprocessor: PreprocessorProtocol,
|
|
||||||
):
|
|
||||||
example = generate_train_example(
|
|
||||||
clip_annotation=clip_annotation,
|
|
||||||
audio_loader=sample_audio_loader,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
labeller=sample_labeller,
|
|
||||||
)
|
|
||||||
|
|
||||||
clip, _, _ = sample_clipper(example)
|
|
||||||
assert clip.spectrogram.shape == (1, 128, 256)
|
|
||||||
@ -56,7 +56,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
|||||||
|
|
||||||
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
||||||
clip_annotation,
|
clip_annotation,
|
||||||
torch.rand([100, 100]),
|
torch.rand([1, 100, 100]),
|
||||||
min_freq=0,
|
min_freq=0,
|
||||||
max_freq=100,
|
max_freq=100,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
@ -67,4 +67,4 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
|||||||
assert size_heatmap[1, 10, 10] == 20
|
assert size_heatmap[1, 10, 10] == 20
|
||||||
assert class_heatmap[pippip_index, 10, 10] == 1.0
|
assert class_heatmap[pippip_index, 10, 10] == 1.0
|
||||||
assert class_heatmap[myomyo_index, 10, 10] == 0.0
|
assert class_heatmap[myomyo_index, 10, 10] == 0.0
|
||||||
assert detection_heatmap[10, 10] == 1.0
|
assert detection_heatmap[0, 10, 10] == 1.0
|
||||||
|
|||||||
@ -4,14 +4,16 @@ import lightning as L
|
|||||||
import torch
|
import torch
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.models import build_model
|
||||||
from batdetect2.train import FullTrainingConfig, TrainingModule
|
from batdetect2.train import FullTrainingConfig, TrainingModule
|
||||||
from batdetect2.train.train import build_training_module
|
from batdetect2.train.train import build_training_module
|
||||||
from batdetect2.typing.preprocess import AudioLoader
|
from batdetect2.typing.preprocess import AudioLoader
|
||||||
|
|
||||||
|
|
||||||
def build_default_module():
|
def build_default_module():
|
||||||
|
model = build_model()
|
||||||
config = FullTrainingConfig()
|
config = FullTrainingConfig()
|
||||||
return build_training_module(config)
|
return build_training_module(model, config=config)
|
||||||
|
|
||||||
|
|
||||||
def test_can_initialize_default_module():
|
def test_can_initialize_default_module():
|
||||||
@ -32,14 +34,14 @@ def test_can_save_checkpoint(
|
|||||||
|
|
||||||
recovered = TrainingModule.load_from_checkpoint(path)
|
recovered = TrainingModule.load_from_checkpoint(path)
|
||||||
|
|
||||||
wav = torch.tensor(sample_audio_loader.load_clip(clip))
|
wav = torch.tensor(sample_audio_loader.load_clip(clip)).unsqueeze(0)
|
||||||
|
|
||||||
spec1 = module.model.preprocessor(wav)
|
spec1 = module.model.preprocessor(wav)
|
||||||
spec2 = recovered.model.preprocessor(wav)
|
spec2 = recovered.model.preprocessor(wav)
|
||||||
|
|
||||||
torch.testing.assert_close(spec1, spec2, rtol=0, atol=0)
|
torch.testing.assert_close(spec1, spec2, rtol=0, atol=0)
|
||||||
|
|
||||||
output1 = module(spec1.unsqueeze(0).unsqueeze(0))
|
output1 = module(spec1.unsqueeze(0))
|
||||||
output2 = recovered(spec2.unsqueeze(0).unsqueeze(0))
|
output2 = recovered(spec2.unsqueeze(0))
|
||||||
|
|
||||||
torch.testing.assert_close(output1, output2, rtol=0, atol=0)
|
torch.testing.assert_close(output1, output2, rtol=0, atol=0)
|
||||||
|
|||||||
@ -1,230 +0,0 @@
|
|||||||
import pytest
|
|
||||||
from soundevent import data
|
|
||||||
from soundevent.terms import get_term
|
|
||||||
|
|
||||||
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
|
|
||||||
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
|
||||||
from batdetect2.targets import build_targets, load_target_config
|
|
||||||
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
|
||||||
from batdetect2.train.preprocess import generate_train_example
|
|
||||||
from batdetect2.typing import ModelOutput
|
|
||||||
from batdetect2.typing.preprocess import AudioLoader
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def build_from_config(
|
|
||||||
create_temp_yaml,
|
|
||||||
):
|
|
||||||
def build(yaml_content):
|
|
||||||
config_path = create_temp_yaml(yaml_content)
|
|
||||||
|
|
||||||
targets_config = load_target_config(config_path, field="targets")
|
|
||||||
preprocessing_config = load_preprocessing_config(
|
|
||||||
config_path,
|
|
||||||
field="preprocessing",
|
|
||||||
)
|
|
||||||
labels_config = load_label_config(config_path, field="labels")
|
|
||||||
postprocessing_config = load_postprocess_config(
|
|
||||||
config_path,
|
|
||||||
field="postprocessing",
|
|
||||||
)
|
|
||||||
|
|
||||||
targets = build_targets(targets_config)
|
|
||||||
preprocessor = build_preprocessor(preprocessing_config)
|
|
||||||
labeller = build_clip_labeler(
|
|
||||||
targets=targets,
|
|
||||||
config=labels_config,
|
|
||||||
min_freq=preprocessor.min_freq,
|
|
||||||
max_freq=preprocessor.max_freq,
|
|
||||||
)
|
|
||||||
postprocessor = build_postprocessor(
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
config=postprocessing_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
return targets, preprocessor, labeller, postprocessor
|
|
||||||
|
|
||||||
return build
|
|
||||||
|
|
||||||
|
|
||||||
def test_encoding_decoding_roundtrip_recovers_object(
|
|
||||||
sample_audio_loader: AudioLoader,
|
|
||||||
build_from_config,
|
|
||||||
recording,
|
|
||||||
):
|
|
||||||
yaml_content = """
|
|
||||||
labels:
|
|
||||||
targets:
|
|
||||||
roi:
|
|
||||||
name: anchor_bbox
|
|
||||||
anchor: bottom-left
|
|
||||||
classes:
|
|
||||||
classes:
|
|
||||||
- name: pippip
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Pipistrellus pipistrellus
|
|
||||||
generic_class:
|
|
||||||
- key: order
|
|
||||||
value: Chiroptera
|
|
||||||
preprocessing:
|
|
||||||
"""
|
|
||||||
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
|
||||||
|
|
||||||
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
|
||||||
se1 = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
|
||||||
tags=[
|
|
||||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
|
||||||
],
|
|
||||||
)
|
|
||||||
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
|
||||||
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
|
||||||
|
|
||||||
encoded = generate_train_example(
|
|
||||||
clip_annotation,
|
|
||||||
sample_audio_loader,
|
|
||||||
preprocessor,
|
|
||||||
labeller,
|
|
||||||
)
|
|
||||||
predictions = postprocessor.get_predictions(
|
|
||||||
ModelOutput(
|
|
||||||
detection_probs=encoded.detection_heatmap.unsqueeze(0).unsqueeze(
|
|
||||||
0
|
|
||||||
),
|
|
||||||
size_preds=encoded.size_heatmap.unsqueeze(0),
|
|
||||||
class_probs=encoded.class_heatmap.unsqueeze(0),
|
|
||||||
features=encoded.spectrogram.unsqueeze(0).unsqueeze(0),
|
|
||||||
),
|
|
||||||
[clip],
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
assert isinstance(predictions, data.ClipPrediction)
|
|
||||||
assert len(predictions.sound_events) == 1
|
|
||||||
|
|
||||||
recovered = predictions.sound_events[0]
|
|
||||||
assert recovered.sound_event.geometry is not None
|
|
||||||
assert isinstance(recovered.sound_event.geometry, data.BoundingBox)
|
|
||||||
start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
|
|
||||||
recovered.sound_event.geometry.coordinates
|
|
||||||
)
|
|
||||||
start_time_or, low_freq_or, end_time_or, high_freq_or = (
|
|
||||||
geometry.coordinates
|
|
||||||
)
|
|
||||||
|
|
||||||
assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
|
|
||||||
assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
|
|
||||||
assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
|
|
||||||
assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)
|
|
||||||
|
|
||||||
assert len(recovered.tags) == 2
|
|
||||||
|
|
||||||
predicted_species_tag = next(
|
|
||||||
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
assert predicted_species_tag is not None
|
|
||||||
assert predicted_species_tag.score == 1
|
|
||||||
assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"
|
|
||||||
|
|
||||||
predicted_order_tag = next(
|
|
||||||
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
assert predicted_order_tag is not None
|
|
||||||
assert predicted_order_tag.score == 1
|
|
||||||
assert predicted_order_tag.tag.value == "Chiroptera"
|
|
||||||
|
|
||||||
|
|
||||||
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
|
|
||||||
sample_audio_loader: AudioLoader,
|
|
||||||
build_from_config,
|
|
||||||
recording,
|
|
||||||
):
|
|
||||||
yaml_content = """
|
|
||||||
labels:
|
|
||||||
targets:
|
|
||||||
roi:
|
|
||||||
name: anchor_bbox
|
|
||||||
anchor: bottom-left
|
|
||||||
classes:
|
|
||||||
classes:
|
|
||||||
- name: pippip
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Pipistrellus pipistrellus
|
|
||||||
- name: myomyo
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Myotis myotis
|
|
||||||
roi:
|
|
||||||
name: anchor_bbox
|
|
||||||
anchor: top-left
|
|
||||||
generic_class:
|
|
||||||
- key: order
|
|
||||||
value: Chiroptera
|
|
||||||
preprocessing:
|
|
||||||
"""
|
|
||||||
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
|
||||||
|
|
||||||
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
|
||||||
se1 = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
|
||||||
)
|
|
||||||
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
|
||||||
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
|
||||||
|
|
||||||
encoded = generate_train_example(
|
|
||||||
clip_annotation,
|
|
||||||
sample_audio_loader,
|
|
||||||
preprocessor,
|
|
||||||
labeller,
|
|
||||||
)
|
|
||||||
predictions = postprocessor.get_predictions(
|
|
||||||
ModelOutput(
|
|
||||||
detection_probs=encoded.detection_heatmap.unsqueeze(0).unsqueeze(
|
|
||||||
0
|
|
||||||
),
|
|
||||||
size_preds=encoded.size_heatmap.unsqueeze(0),
|
|
||||||
class_probs=encoded.class_heatmap.unsqueeze(0),
|
|
||||||
features=encoded.spectrogram.unsqueeze(0).unsqueeze(0),
|
|
||||||
),
|
|
||||||
[clip],
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
assert isinstance(predictions, data.ClipPrediction)
|
|
||||||
assert len(predictions.sound_events) == 1
|
|
||||||
|
|
||||||
recovered = predictions.sound_events[0]
|
|
||||||
assert recovered.sound_event.geometry is not None
|
|
||||||
assert isinstance(recovered.sound_event.geometry, data.BoundingBox)
|
|
||||||
start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
|
|
||||||
recovered.sound_event.geometry.coordinates
|
|
||||||
)
|
|
||||||
start_time_or, low_freq_or, end_time_or, high_freq_or = (
|
|
||||||
geometry.coordinates
|
|
||||||
)
|
|
||||||
|
|
||||||
assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
|
|
||||||
assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
|
|
||||||
assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
|
|
||||||
assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)
|
|
||||||
|
|
||||||
assert len(recovered.tags) == 2
|
|
||||||
|
|
||||||
predicted_species_tag = next(
|
|
||||||
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
assert predicted_species_tag is not None
|
|
||||||
assert predicted_species_tag.score == 1
|
|
||||||
assert predicted_species_tag.tag.value == "Myotis myotis"
|
|
||||||
|
|
||||||
predicted_order_tag = next(
|
|
||||||
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
assert predicted_order_tag is not None
|
|
||||||
assert predicted_order_tag.score == 1
|
|
||||||
assert predicted_order_tag.tag.value == "Chiroptera"
|
|
||||||
Loading…
Reference in New Issue
Block a user