Remove stale tests

This commit is contained in:
mbsantiago 2025-09-07 11:03:46 +01:00
parent 709b6355c2
commit cf6d0d1ccc
18 changed files with 138 additions and 1215 deletions

View File

@ -8,8 +8,6 @@ from typing import Callable, List, Optional, Union
from pydantic import BaseModel, Field
from soundevent import data
from batdetect2.targets import get_term_from_key
PathLike = Union[Path, str, os.PathLike]
__all__ = []
@ -92,15 +90,15 @@ def annotation_to_sound_event(
sound_event=sound_event,
tags=[
data.Tag(
term=get_term_from_key(label_key),
key=label_key, # type: ignore
value=annotation.label,
),
data.Tag(
term=get_term_from_key(event_key),
key=event_key, # type: ignore
value=annotation.event,
),
data.Tag(
term=get_term_from_key(individual_key),
key=individual_key, # type: ignore
value=str(annotation.individual),
),
],
@ -125,7 +123,7 @@ def file_annotation_to_clip(
time_expansion=file_annotation.time_exp,
tags=[
data.Tag(
term=get_term_from_key(label_key),
key=label_key, # type: ignore
value=file_annotation.label,
)
],
@ -157,7 +155,8 @@ def file_annotation_to_clip_annotation(
notes=notes,
tags=[
data.Tag(
term=get_term_from_key(label_key), value=file_annotation.label
key=label_key, # type: ignore
value=file_annotation.label,
)
],
sound_events=[

View File

@ -68,7 +68,10 @@ from batdetect2.postprocess import PostprocessConfig, build_postprocessor
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.typing.models import DetectionModel
from batdetect2.typing.postprocess import DetectionsArray, PostprocessorProtocol
from batdetect2.typing.postprocess import (
DetectionsTensor,
PostprocessorProtocol,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
@ -122,7 +125,7 @@ class Model(LightningModule):
self.targets = targets
self.save_hyperparameters()
def forward(self, wav: torch.Tensor) -> List[DetectionsArray]:
def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
spec = self.preprocessor(wav)
outputs = self.detector(spec)
return self.postprocessor(outputs)

View File

@ -57,14 +57,9 @@ from batdetect2.targets.rois import (
)
from batdetect2.targets.terms import (
TagInfo,
TermInfo,
TermRegistry,
call_type,
default_term_registry,
get_tag_from_info,
get_term_from_key,
individual,
register_term,
)
from batdetect2.targets.transform import (
DerivationRegistry,
@ -100,7 +95,6 @@ __all__ = [
"TargetClass",
"TargetConfig",
"Targets",
"TermInfo",
"TransformConfig",
"build_generic_class_tags",
"build_roi_mapper",
@ -112,7 +106,6 @@ __all__ = [
"get_class_names_from_config",
"get_derivation",
"get_tag_from_info",
"get_term_from_key",
"individual",
"load_classes_config",
"load_decoder_from_config",
@ -123,7 +116,6 @@ __all__ = [
"load_transformation_config",
"load_transformation_from_config",
"register_derivation",
"register_term",
]
@ -534,7 +526,6 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
def build_targets(
config: Optional[TargetConfig] = None,
term_registry: TermRegistry = default_term_registry,
derivation_registry: DerivationRegistry = default_derivation_registry,
) -> Targets:
"""Build a Targets object from a loaded TargetConfig.
@ -550,9 +541,6 @@ def build_targets(
----------
config : TargetConfig
The loaded and validated unified target configuration object.
term_registry : TermRegistry, optional
The TermRegistry instance to use for resolving term keys. Defaults
to the global `batdetect2.targets.terms.term_registry`.
derivation_registry : DerivationRegistry, optional
The DerivationRegistry instance to use for resolving derivation
function names. Defaults to the global
@ -578,25 +566,15 @@ def build_targets(
)
filter_fn = (
build_sound_event_filter(
config.filtering,
term_registry=term_registry,
)
build_sound_event_filter(config.filtering)
if config.filtering
else None
)
encode_fn = build_sound_event_encoder(
config.classes,
term_registry=term_registry,
)
decode_fn = build_sound_event_decoder(
config.classes,
term_registry=term_registry,
)
encode_fn = build_sound_event_encoder(config.classes)
decode_fn = build_sound_event_decoder(config.classes)
transform_fn = (
build_transformation_from_config(
config.transforms,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
if config.transforms
@ -604,10 +582,7 @@ def build_targets(
)
roi_mapper = build_roi_mapper(config.roi)
class_names = get_class_names_from_config(config.classes)
generic_class_tags = build_generic_class_tags(
config.classes,
term_registry=term_registry,
)
generic_class_tags = build_generic_class_tags(config.classes)
roi_overrides = {
class_config.name: build_roi_mapper(class_config.roi)
for class_config in config.classes.classes
@ -629,7 +604,6 @@ def build_targets(
def load_targets(
config_path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = default_term_registry,
derivation_registry: DerivationRegistry = default_derivation_registry,
) -> Targets:
"""Load a Targets object directly from a configuration file.
@ -645,8 +619,6 @@ def load_targets(
field : str, optional
Dot-separated path to a nested section within the file containing
the target configuration. If None, the entire file content is used.
term_registry : TermRegistry, optional
The TermRegistry instance to use. Defaults to the global default.
derivation_registry : DerivationRegistry, optional
The DerivationRegistry instance to use. Defaults to the global
default.
@ -670,11 +642,7 @@ def load_targets(
config_path,
field=field,
)
return build_targets(
config,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
return build_targets(config, derivation_registry=derivation_registry)
def iterate_encoded_sound_events(

View File

@ -10,8 +10,6 @@ from batdetect2.targets.rois import ROIMapperConfig
from batdetect2.targets.terms import (
GENERIC_CLASS_KEY,
TagInfo,
TermRegistry,
default_term_registry,
get_tag_from_info,
)
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
@ -295,10 +293,7 @@ def _encode_with_multiple_classifiers(
return None
def build_sound_event_encoder(
config: ClassesConfig,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventEncoder:
def build_sound_event_encoder(config: ClassesConfig) -> SoundEventEncoder:
"""Build a sound event encoder function from the classes configuration.
The returned encoder function iterates through the class definitions in the
@ -333,8 +328,7 @@ def build_sound_event_encoder(
partial(
is_target_class,
tags={
get_tag_from_info(tag_info, term_registry=term_registry)
for tag_info in class_info.tags
get_tag_from_info(tag_info) for tag_info in class_info.tags
},
match_all=class_info.match_type == "all",
),
@ -391,7 +385,6 @@ def _decode_class(
def build_sound_event_decoder(
config: ClassesConfig,
term_registry: TermRegistry = default_term_registry,
raise_on_unmapped: bool = False,
) -> SoundEventDecoder:
"""Build a sound event decoder function from the classes configuration.
@ -433,8 +426,7 @@ def build_sound_event_decoder(
else class_info.tags
)
mapping[class_info.name] = [
get_tag_from_info(tag_info, term_registry=term_registry)
for tag_info in tags_to_use
get_tag_from_info(tag_info) for tag_info in tags_to_use
]
return partial(
@ -444,10 +436,7 @@ def build_sound_event_decoder(
)
def build_generic_class_tags(
config: ClassesConfig,
term_registry: TermRegistry = default_term_registry,
) -> List[data.Tag]:
def build_generic_class_tags(config: ClassesConfig) -> List[data.Tag]:
"""Extract and build the list of tags for the generic class from config.
Converts the list of `TagInfo` objects defined in `config.generic_class`
@ -472,10 +461,7 @@ def build_generic_class_tags(
If a term key specified in `config.generic_class` is not found in the
provided `term_registry`.
"""
return [
get_tag_from_info(tag_info, term_registry=term_registry)
for tag_info in config.generic_class
]
return [get_tag_from_info(tag_info) for tag_info in config.generic_class]
def load_classes_config(
@ -509,9 +495,7 @@ def load_classes_config(
def load_encoder_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = default_term_registry,
path: data.PathLike, field: Optional[str] = None
) -> SoundEventEncoder:
"""Load a class encoder function directly from a configuration file.
@ -546,13 +530,12 @@ def load_encoder_from_config(
provided `term_registry` during the build process.
"""
config = load_classes_config(path, field=field)
return build_sound_event_encoder(config, term_registry=term_registry)
return build_sound_event_encoder(config)
def load_decoder_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = default_term_registry,
raise_on_unmapped: bool = False,
) -> SoundEventDecoder:
"""Load a class decoder function directly from a configuration file.
@ -594,6 +577,5 @@ def load_decoder_from_config(
config = load_classes_config(path, field=field)
return build_sound_event_decoder(
config,
term_registry=term_registry,
raise_on_unmapped=raise_on_unmapped,
)

View File

@ -8,8 +8,6 @@ from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
default_term_registry,
get_tag_from_info,
)
from batdetect2.typing.targets import SoundEventFilter
@ -146,10 +144,7 @@ def equal_tags(
return tags == sound_event_tags
def build_filter_from_rule(
rule: FilterRule,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventFilter:
def build_filter_from_rule(rule: FilterRule) -> SoundEventFilter:
"""Creates a callable filter function from a single FilterRule.
Parameters
@ -168,10 +163,7 @@ def build_filter_from_rule(
ValueError
If the rule contains an invalid `match_type`.
"""
tag_set = {
get_tag_from_info(tag_info, term_registry=term_registry)
for tag_info in rule.tags
}
tag_set = {get_tag_from_info(tag_info) for tag_info in rule.tags}
if rule.match_type == "any":
return partial(has_any_tag, tags=tag_set)
@ -235,7 +227,6 @@ class FilterConfig(BaseConfig):
def build_sound_event_filter(
config: FilterConfig,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventFilter:
"""Builds a merged filter function from a FilterConfig object.
@ -252,10 +243,7 @@ def build_sound_event_filter(
SoundEventFilter
A single callable filter function that applies all defined rules.
"""
filters = [
build_filter_from_rule(rule, term_registry=term_registry)
for rule in config.rules
]
filters = [build_filter_from_rule(rule) for rule in config.rules]
return partial(_passes_all_filters, filters=filters)
@ -281,9 +269,7 @@ def load_filter_config(
def load_filter_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = default_term_registry,
path: data.PathLike, field: Optional[str] = None,
) -> SoundEventFilter:
"""Loads filter configuration from a file and builds the filter function.
@ -304,4 +290,4 @@ def load_filter_from_config(
The final merged filter function ready to be used.
"""
config = load_filter_config(path=path, field=field)
return build_sound_event_filter(config, term_registry=term_registry)
return build_sound_event_filter(config)

View File

@ -1,33 +1,22 @@
"""Manages the vocabulary (Terms and Tags) for defining training targets.
"""Manages the vocabulary for defining training targets.
This module provides the necessary tools to declare, register, and manage the
set of `soundevent.data.Term` objects used throughout the `batdetect2.targets`
sub-package. It establishes a consistent vocabulary for filtering,
transforming, and classifying sound events based on their annotations (Tags).
The core component is the `TermRegistry`, which maps unique string keys
(aliases) to specific `Term` definitions. This allows users to refer to complex
terms using simple, consistent keys in configuration files and code.
Terms can be pre-defined, loaded from the `soundevent.terms` library, defined
programmatically, or loaded from external configuration files (e.g., YAML).
Terms can be pre-defined, loaded from the `soundevent.terms` library or defined
programmatically.
"""
from collections.abc import Mapping
from inspect import getmembers
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel
from soundevent import data, terms
from batdetect2.configs import load_config
__all__ = [
"call_type",
"individual",
"data_source",
"get_tag_from_info",
"TermInfo",
"TagInfo",
]
@ -98,255 +87,6 @@ terms.register_term_set(
)
class TermRegistry(Mapping[str, data.Term]):
    """Registry that maps unique string keys to Term definitions.

    The registry is a read-only ``Mapping`` from key to
    ``soundevent.data.Term`` and additionally exposes explicit methods to
    add, create, list and remove terms.
    """

    def __init__(self, terms: Optional[Dict[str, data.Term]] = None):
        """Initialize the registry.

        Parameters
        ----------
        terms : dict[str, soundevent.data.Term], optional
            Initial key-to-Term mappings. The dictionary is copied, so
            later changes to the registry never modify the caller's
            object. Defaults to an empty registry.
        """
        # Copy instead of aliasing: `terms or {}` kept a reference to the
        # caller's dict, which add_term/remove_key would then mutate.
        self._terms: Dict[str, data.Term] = (
            dict(terms) if terms is not None else {}
        )

    def __getitem__(self, key: str) -> data.Term:
        return self._terms[key]

    def __len__(self) -> int:
        return len(self._terms)

    def __iter__(self):
        return iter(self._terms)

    def add_term(self, key: str, term: data.Term) -> None:
        """Add a Term object to the registry under the given key.

        Parameters
        ----------
        key : str
            The unique string key to associate with the term.
        term : soundevent.data.Term
            The Term object to register.

        Raises
        ------
        KeyError
            If a term with the provided key already exists in the
            registry.
        """
        if key in self._terms:
            raise KeyError("A term with the provided key already exists.")

        self._terms[key] = term

    def get_term(self, key: str) -> data.Term:
        """Retrieve a registered term by its unique key.

        Parameters
        ----------
        key : str
            The unique string key of the term to retrieve.

        Returns
        -------
        soundevent.data.Term
            The Term registered under ``key``.

        Raises
        ------
        KeyError
            If no term is registered under ``key``. The message lists the
            keys that are currently available.
        """
        try:
            return self._terms[key]
        except KeyError as err:
            raise KeyError(
                "No term found for key "
                f"'{key}'. Ensure it is registered or loaded. "
                f"Available keys: {', '.join(self.get_keys())}"
            ) from err

    def add_custom_term(
        self,
        key: str,
        name: Optional[str] = None,
        uri: Optional[str] = None,
        label: Optional[str] = None,
        definition: Optional[str] = None,
    ) -> data.Term:
        """Create a new Term from attributes and add it to the registry.

        Missing (or empty) ``name`` and ``label`` fall back to ``key``; a
        missing (or empty) ``definition`` falls back to "Unknown".

        Parameters
        ----------
        key : str
            The unique string key for the new term.
        name : str, optional
            The name for the new term (defaults to `key`).
        uri : str, optional
            The URI for the new term.
        label : str, optional
            The display label for the new term (defaults to `key`).
        definition : str, optional
            The definition for the new term (defaults to "Unknown").

        Returns
        -------
        soundevent.data.Term
            The newly created and registered Term object.

        Raises
        ------
        KeyError
            If a term with the provided key already exists.
        """
        term = data.Term(
            name=name or key,
            label=label or key,
            uri=uri,
            definition=definition or "Unknown",
        )
        self.add_term(key, term)
        return term

    def get_keys(self) -> List[str]:
        """Return a list of all keys currently registered.

        Returns
        -------
        list[str]
            The keys of all registered terms, in insertion order.
        """
        return list(self._terms.keys())

    def get_terms(self) -> List[data.Term]:
        """Return a list of all registered terms.

        Returns
        -------
        list[soundevent.data.Term]
            All registered Term objects, in insertion order.
        """
        return list(self._terms.values())

    def remove_key(self, key: str) -> None:
        """Remove the term registered under ``key``.

        Parameters
        ----------
        key : str
            The key of the term to remove.

        Raises
        ------
        KeyError
            If ``key`` is not present in the registry.
        """
        del self._terms[key]
default_term_registry = TermRegistry(
    terms={
        # Every Term object exposed as a member of `soundevent.terms`,
        # keyed by its attribute name.
        **dict(
            getmembers(terms, lambda member: isinstance(member, data.Term))
        ),
        # Aliases defined by this module; these override any member of the
        # same name collected above.
        "event": call_type,
        "species": terms.scientific_name,
        "individual": individual,
        "data_source": data_source,
        GENERIC_CLASS_KEY: generic_class,
    }
)
"""The default, globally accessible TermRegistry instance.

It is seeded with every `Term` found as a member of `soundevent.terms`, plus
the aliases declared in this module ("event", "species", "individual",
"data_source" and the generic class key).

Functions in this module fall back to this registry when no other instance
is passed explicitly.
"""
def get_term_from_key(
    key: str,
    term_registry: Optional[TermRegistry] = None,
) -> data.Term:
    """Retrieve a term by key.

    The key is first looked up with `soundevent.terms.get_term`. Only when
    that lookup yields nothing is the `term_registry` consulted, so a term
    known to `soundevent` takes precedence over a registry entry with the
    same key.

    Parameters
    ----------
    key : str
        The unique key of the term to retrieve.
    term_registry : TermRegistry, optional
        The TermRegistry instance to fall back to. When None, the global
        `default_term_registry` is used. An explicitly passed registry is
        always honoured, even if it is empty.

    Returns
    -------
    soundevent.data.Term
        The corresponding soundevent.data.Term object.

    Raises
    ------
    KeyError
        If the key is found neither by `soundevent.terms.get_term` nor in
        the fallback registry.
    """
    term = terms.get_term(key)
    if term is not None:
        return term

    # Compare against None explicitly: TermRegistry is a Mapping, so an
    # empty registry is falsy and `term_registry or default` would silently
    # swap a caller-supplied empty registry for the global default.
    if term_registry is None:
        term_registry = default_term_registry

    return term_registry.get_term(key)
def get_term_keys(
    term_registry: TermRegistry = default_term_registry,
) -> List[str]:
    """List the keys registered in a term registry.

    Parameters
    ----------
    term_registry : TermRegistry, optional
        The registry to query. Defaults to `default_term_registry`.

    Returns
    -------
    list[str]
        The keys of every term registered in `term_registry`.
    """
    keys = term_registry.get_keys()
    return keys
def get_terms(
    term_registry: TermRegistry = default_term_registry,
) -> List[data.Term]:
    """List the Term objects registered in a term registry.

    Parameters
    ----------
    term_registry : TermRegistry, optional
        The registry to query. Defaults to `default_term_registry`.

    Returns
    -------
    list[soundevent.data.Term]
        Every Term object registered in `term_registry`.
    """
    registered = term_registry.get_terms()
    return registered
class TagInfo(BaseModel):
"""Represents information needed to define a specific Tag.
@ -360,31 +100,25 @@ class TagInfo(BaseModel):
value : str
The value of the tag (e.g., "Myotis myotis", "Echolocation").
key : str, default="class"
The key (alias) of the term associated with this tag, as
registered in the TermRegistry. Defaults to "class", implying
it represents a classification target label by default.
The key (alias) of the term associated with this tag. Defaults to
"class", implying it represents a classification target label by
default.
"""
value: str
key: str = GENERIC_CLASS_KEY
def get_tag_from_info(
tag_info: TagInfo,
term_registry: Optional[TermRegistry] = None,
) -> data.Tag:
def get_tag_from_info(tag_info: TagInfo) -> data.Tag:
"""Creates a soundevent.data.Tag object from TagInfo data.
Looks up the term using the key in the provided `tag_info` from the
specified registry and constructs a Tag object.
Looks up the term using the key in the provided `tag_info` and constructs a
Tag object.
Parameters
----------
tag_info : TagInfo
The TagInfo object containing the value and term key.
term_registry : TermRegistry, optional
The TermRegistry instance to use for term lookup. Defaults to the
global `registry`.
Returns
-------
@ -394,132 +128,11 @@ def get_tag_from_info(
Raises
------
KeyError
If the term key specified in `tag_info.key` is not found
in the registry.
If the term key specified in `tag_info.key` is not found.
"""
term_registry = term_registry or default_term_registry
term = get_term_from_key(tag_info.key, term_registry=term_registry)
term = terms.get_term(tag_info.key)
if not term:
raise KeyError(f"Key {tag_info.key} not found")
return data.Tag(term=term, value=tag_info.value)
class TermInfo(BaseModel):
    """Declarative description of a single Term in a configuration file.

    Each instance carries the same arguments accepted by
    `TermRegistry.add_custom_term`, so a term can be declared in a config
    file (e.g., YAML) and registered later.

    Attributes
    ----------
    key : str
        Unique key (alias) under which the term is registered and
        referenced.
    label : str, optional
        Display label. `TermRegistry.add_custom_term` substitutes `key`
        when it is missing.
    name : str, optional
        Formal name. `TermRegistry.add_custom_term` substitutes `key` when
        it is missing.
    uri : str, optional
        URI identifying the term (e.g., from a standard vocabulary).
    definition : str, optional
        Textual definition. `TermRegistry.add_custom_term` substitutes
        "Unknown" when it is missing.
    """

    key: str
    label: Optional[str] = None
    name: Optional[str] = None
    uri: Optional[str] = None
    definition: Optional[str] = None
class TermConfig(BaseModel):
    """Schema for a configuration section holding term definitions.

    Attributes
    ----------
    terms : list[TermInfo]
        The term definitions to register. Defaults to an empty list.

    Examples
    --------
    A matching YAML document looks like:

    ```yaml
    terms:
      - key: species
        uri: dwc:scientificName
        label: Scientific Name
      - key: my_custom_term
        name: My Custom Term
        definition: Describes a specific project attribute.
      # ... more TermInfo definitions
    ```
    """

    terms: List[TermInfo] = Field(default_factory=list)
def load_terms_from_config(
    path: data.PathLike,
    field: Optional[str] = None,
    term_registry: TermRegistry = default_term_registry,
) -> Dict[str, data.Term]:
    """Load term definitions from a configuration file and register them.

    The file is parsed with the `TermConfig` schema and each `TermInfo`
    entry is added to `term_registry` as a custom term. All keys are
    validated before anything is registered, so a conflicting key leaves
    the registry untouched.

    Parameters
    ----------
    path : data.PathLike
        The path to the configuration file.
    field : str, optional
        Section of the config file that holds the term definitions; passed
        through to `load_config`. If None, the whole file is validated
        against the TermConfig schema.
    term_registry : TermRegistry, optional
        The TermRegistry instance to add the loaded terms to. Defaults to
        `default_term_registry`.

    Returns
    -------
    dict[str, soundevent.data.Term]
        A dictionary mapping the keys of the newly added terms to their
        corresponding Term objects.

    Raises
    ------
    KeyError
        If a term key from the config is already present in the registry,
        or appears more than once in the config.

    Notes
    -----
    Any error raised by `load_config` (e.g. for a missing file or a
    structure that does not match the schema) propagates unchanged.
    """
    # Named `config` rather than `data` so the imported `soundevent.data`
    # module is not shadowed inside this function.
    config = load_config(path, schema=TermConfig, field=field)

    # Validate every key first; registering while iterating would leave the
    # registry half-updated if a later key conflicted.
    seen = set()
    for info in config.terms:
        if info.key in seen or info.key in term_registry:
            raise KeyError(
                f"A term with key '{info.key}' already exists or is "
                "defined more than once in the configuration."
            )
        seen.add(info.key)

    loaded: Dict[str, data.Term] = {}
    for info in config.terms:
        loaded[info.key] = term_registry.add_custom_term(
            info.key,
            name=info.name,
            uri=info.uri,
            label=info.label,
            definition=info.definition,
        )
    return loaded
def register_term(
    key: str, term: data.Term, registry: TermRegistry = default_term_registry
) -> None:
    """Register `term` under `key` in the given registry.

    Thin wrapper around `TermRegistry.add_term`; the target registry
    defaults to `default_term_registry`.

    Raises
    ------
    KeyError
        If `key` is already registered in `registry`.
    """
    registry.add_term(key, term)

View File

@ -12,15 +12,10 @@ from typing import (
)
from pydantic import Field
from soundevent import data
from soundevent import data, terms
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
get_tag_from_info,
get_term_from_key,
)
from batdetect2.targets.terms import TagInfo, get_tag_from_info
__all__ = [
"DerivationRegistry",
@ -466,7 +461,6 @@ TranformationRule = Annotated[
def build_transform_from_rule(
rule: TranformationRule,
derivation_registry: Optional[DerivationRegistry] = None,
term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation:
"""Build a specific SoundEventTransformation function from a rule config.
@ -497,29 +491,21 @@ def build_transform_from_rule(
If dynamic import of a derivation function fails.
"""
if rule.rule_type == "replace":
source = get_tag_from_info(
rule.original,
term_registry=term_registry,
)
target = get_tag_from_info(
rule.replacement,
term_registry=term_registry,
)
source = get_tag_from_info(rule.original)
target = get_tag_from_info(rule.replacement)
return partial(replace_tag_transform, source=source, target=target)
if rule.rule_type == "derive_tag":
source_term = get_term_from_key(
rule.source_term_key,
term_registry=term_registry,
)
source_term = terms.get_term(rule.source_term_key)
target_term = (
get_term_from_key(
rule.target_term_key,
term_registry=term_registry,
)
terms.get_term(rule.target_term_key)
if rule.target_term_key
else source_term
)
if source_term is None or target_term is None:
raise KeyError("Terms not found")
derivation = get_derivation(
key=rule.derivation_function,
import_derivation=rule.import_derivation,
@ -534,18 +520,16 @@ def build_transform_from_rule(
)
if rule.rule_type == "map_value":
source_term = get_term_from_key(
rule.source_term_key,
term_registry=term_registry,
)
source_term = terms.get_term(rule.source_term_key)
target_term = (
get_term_from_key(
rule.target_term_key,
term_registry=term_registry,
)
terms.get_term(rule.target_term_key)
if rule.target_term_key
else source_term
)
if source_term is None or target_term is None:
raise KeyError("Terms not found")
return partial(
map_value_transform,
source_term=source_term,
@ -555,6 +539,7 @@ def build_transform_from_rule(
# Handle unknown rule type
valid_options = ["replace", "derive_tag", "map_value"]
# Should be caught by Pydantic validation, but good practice
raise ValueError(
f"Invalid transform rule type '{getattr(rule, 'rule_type', 'N/A')}'. "
@ -565,7 +550,6 @@ def build_transform_from_rule(
def build_transformation_from_config(
config: TransformConfig,
derivation_registry: Optional[DerivationRegistry] = None,
term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation:
"""Build a composite transformation function from a TransformConfig.
@ -591,7 +575,6 @@ def build_transformation_from_config(
build_transform_from_rule(
rule,
derivation_registry=derivation_registry,
term_registry=term_registry,
)
for rule in config.rules
]
@ -640,7 +623,6 @@ def load_transformation_from_config(
path: data.PathLike,
field: Optional[str] = None,
derivation_registry: Optional[DerivationRegistry] = None,
term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation:
"""Load transformation config from a file and build the final function.
@ -678,7 +660,6 @@ def load_transformation_from_config(
return build_transformation_from_config(
config,
derivation_registry=derivation_registry,
term_registry=term_registry,
)

View File

@ -31,7 +31,7 @@ class TrainingModule(L.LightningModule):
self.save_hyperparameters(logger=False)
def forward(self, spec: torch.Tensor) -> ModelOutput:
return self.model(spec)
return self.model.detector(spec)
def training_step(self, batch: TrainExample):
outputs = self.model.detector(batch.spec)

View File

@ -99,7 +99,7 @@ def train(
module = build_training_module(
model,
config,
batches_per_epoch=len(train_dataloader),
t_max=config.train.t_max * len(train_dataloader),
)
logger.info("Starting main training loop...")
@ -113,15 +113,16 @@ def train(
def build_training_module(
model: Model,
config: FullTrainingConfig,
batches_per_epoch: int,
config: Optional[FullTrainingConfig] = None,
t_max: int = 200,
) -> TrainingModule:
config = config or FullTrainingConfig()
loss = build_loss(config=config.train.loss)
return TrainingModule(
model=model,
loss=loss,
learning_rate=config.train.learning_rate,
t_max=config.train.t_max * batches_per_epoch,
t_max=t_max,
)

View File

@ -15,7 +15,6 @@ from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.targets import (
TargetConfig,
TermRegistry,
build_targets,
call_type,
)
@ -355,18 +354,6 @@ def create_annotation_project():
return factory
@pytest.fixture
def sample_term_registry() -> TermRegistry:
    """Provide a fresh TermRegistry holding a handful of custom terms."""
    registry = TermRegistry()
    # Registration order matches the key order below.
    for key in ("class", "order", "species", "call_type", "quality"):
        registry.add_custom_term(key)
    return registry
@pytest.fixture
def sample_preprocessor() -> PreprocessorProtocol:
    """Provide the preprocessor returned by `build_preprocessor()`."""
    preprocessor = build_preprocessor()
    return preprocessor
@ -399,7 +386,6 @@ def pippip_tag() -> TagInfo:
@pytest.fixture
def sample_target_config(
sample_term_registry: TermRegistry,
bat_tag: TagInfo,
noise_tag: TagInfo,
myomyo_tag: TagInfo,
@ -422,12 +408,8 @@ def sample_target_config(
@pytest.fixture
def sample_targets(
sample_target_config: TargetConfig,
sample_term_registry: TermRegistry,
) -> TargetProtocol:
return build_targets(
sample_target_config,
term_registry=sample_term_registry,
)
return build_targets(sample_target_config)
@pytest.fixture
@ -443,10 +425,8 @@ def sample_labeller(
@pytest.fixture
def sample_clipper(
sample_preprocessor: PreprocessorProtocol,
) -> ClipperProtocol:
return build_clipper(preprocessor=sample_preprocessor)
def sample_clipper() -> ClipperProtocol:
return build_clipper()
@pytest.fixture

View File

@ -1,16 +1,14 @@
from collections.abc import Callable
from pathlib import Path
from soundevent import data
from soundevent import data, terms
from batdetect2.targets import build_targets, load_target_config
from batdetect2.targets.terms import get_term_from_key
def test_can_override_default_roi_mapper_per_class(
create_temp_yaml: Callable[..., Path],
recording: data.Recording,
sample_term_registry,
):
yaml_content = """
roi:
@ -36,11 +34,13 @@ def test_can_override_default_roi_mapper_per_class(
config_path = create_temp_yaml(yaml_content)
config = load_target_config(config_path)
targets = build_targets(config, term_registry=sample_term_registry)
targets = build_targets(config)
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
species = get_term_from_key("species", term_registry=sample_term_registry)
species = terms.get_term("species")
assert species is not None
se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
@ -62,7 +62,6 @@ def test_can_override_default_roi_mapper_per_class(
# TODO: rename this test function
def test_roi_is_recovered_roundtrip_even_with_overriders(
create_temp_yaml,
sample_term_registry,
recording,
):
yaml_content = """
@ -89,11 +88,12 @@ def test_roi_is_recovered_roundtrip_even_with_overriders(
config_path = create_temp_yaml(yaml_content)
config = load_target_config(config_path)
targets = build_targets(config, term_registry=sample_term_registry)
targets = build_targets(config)
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
species = get_term_from_key("species", term_registry=sample_term_registry)
species = terms.get_term("species")
assert species is not None
se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],

View File

@ -1,98 +1,7 @@
import pytest
import yaml
from soundevent import data
from batdetect2.targets import terms
from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
load_terms_from_config,
)
def test_term_registry_initialization():
    """A registry is empty by default and holds any initial terms given."""
    empty_registry = TermRegistry()
    assert empty_registry._terms == {}

    seed_term = data.Term(name="test", label="Test", definition="test")
    initial_terms = {"test_term": seed_term}
    seeded_registry = TermRegistry(terms=initial_terms)
    assert seeded_registry._terms == initial_terms
def test_term_registry_add_term():
    """`add_term` stores the term under the requested key."""
    expected = data.Term(name="test", label="Test", definition="test")
    registry = TermRegistry()

    registry.add_term("test_key", expected)

    assert registry._terms["test_key"] == expected
def test_term_registry_get_term():
    """Check that ``get_term`` returns the term stored under a key."""
    registry = TermRegistry()
    expected = data.Term(name="test", label="Test", definition="test")
    registry.add_term("test_key", expected)

    assert registry.get_term("test_key") == expected
def test_term_registry_add_custom_term():
    """Check that ``add_custom_term`` builds, stores and returns a term."""
    registry = TermRegistry()

    created = registry.add_custom_term(
        "custom_key",
        name="custom",
        label="Custom",
        definition="A custom term",
    )

    # The returned term is the one kept under the key, with the given fields.
    assert registry._terms["custom_key"] == created
    assert (created.name, created.label, created.definition) == (
        "custom",
        "Custom",
        "A custom term",
    )
def test_term_registry_add_duplicate_term():
    """Check that re-adding a term under a used key raises ``KeyError``."""
    registry = TermRegistry()
    sample = data.Term(name="test", label="Test", definition="test")
    used_key = "test_key"
    registry.add_term(used_key, sample)

    with pytest.raises(KeyError):
        registry.add_term(used_key, sample)
def test_term_registry_get_term_not_found():
    """Check that looking up an unknown key raises ``KeyError``."""
    empty_registry = TermRegistry()
    missing_key = "non_existent_key"

    with pytest.raises(KeyError):
        empty_registry.get_term(missing_key)
def test_term_registry_get_keys():
    """Check that ``get_keys`` reports every key that was added."""
    registry = TermRegistry()
    entries = {
        "key1": data.Term(name="test1", label="Test1", definition="test"),
        "key2": data.Term(name="test2", label="Test2", definition="test"),
    }
    for key, term in entries.items():
        registry.add_term(key, term)

    # Compared as a set, so the order of the reported keys is not checked.
    assert set(registry.get_keys()) == {"key1", "key2"}
def test_get_term_from_key():
    """Check key lookup with and without an explicit registry."""
    # Without a registry argument, "event" resolves to ``terms.call_type``.
    assert terms.get_term_from_key("event") == terms.call_type

    registry = TermRegistry()
    expected = data.Term(name="custom", label="Custom", definition="test")
    registry.add_term("custom_key", expected)

    found = terms.get_term_from_key("custom_key", term_registry=registry)
    assert found == expected
def test_get_term_keys():
    """Check the keys reported with and without an explicit registry."""
    default_keys = terms.get_term_keys()
    for expected in ("event", "individual", terms.GENERIC_CLASS_KEY):
        assert expected in default_keys

    registry = TermRegistry()
    registry.add_term(
        "custom_key",
        data.Term(name="custom", label="Custom", definition="test"),
    )
    assert "custom_key" in terms.get_term_keys(term_registry=registry)
from batdetect2.targets.terms import TagInfo
def test_tag_info_and_get_tag_from_info():
@ -106,74 +15,3 @@ def test_get_tag_from_info_key_not_found():
tag_info = TagInfo(value="test", key="non_existent_key")
with pytest.raises(KeyError):
terms.get_tag_from_info(tag_info)
def test_load_terms_from_config(tmp_path):
    """Check that terms listed in a YAML config file are loaded by key."""
    entries = [
        {
            "key": "species",
            "name": "dwc:scientificName",
            "label": "Scientific Name",
        },
        {
            "key": "my_custom_term",
            "name": "soundevent:custom_term",
            "definition": "Describes a specific project attribute",
        },
    ]
    registry = TermRegistry()
    config_file = tmp_path / "config.yaml"
    with config_file.open("w") as stream:
        yaml.dump({"terms": entries}, stream)

    loaded_terms = load_terms_from_config(
        config_file,
        term_registry=registry,
    )

    expected_names = {
        "species": "dwc:scientificName",
        "my_custom_term": "soundevent:custom_term",
    }
    for key, name in expected_names.items():
        assert key in loaded_terms
        assert loaded_terms[key].name == name
def test_load_terms_from_config_file_not_found():
    """Check that a missing config path raises ``FileNotFoundError``."""
    missing_path = "non_existent_file.yaml"

    with pytest.raises(FileNotFoundError):
        load_terms_from_config(missing_path)
def test_load_terms_from_config_validation_error(tmp_path):
    """Check that loading an invalid term entry raises ``ValueError``."""
    invalid_entry = {
        "key": "species",
        "uri": "dwc:scientificName",
        "label": 123,  # Invalid label type
    }
    config_file = tmp_path / "config.yaml"
    with config_file.open("w") as stream:
        yaml.dump({"terms": [invalid_entry]}, stream)

    with pytest.raises(ValueError):
        load_terms_from_config(config_file)
def test_load_terms_from_config_key_already_exists(tmp_path):
    """Check that a config entry reusing the key "event" raises ``KeyError``.

    No registry is passed to the loader; presumably "event" is already
    present in the registry used by default -- confirm against the loader.
    """
    clashing_entry = {
        "key": "event",  # Duplicate key
        "uri": "dwc:scientificName",
        "label": "Scientific Name",
    }
    config_file = tmp_path / "config.yaml"
    with config_file.open("w") as stream:
        yaml.dump({"terms": [clashing_entry]}, stream)

    with pytest.raises(KeyError):
        load_terms_from_config(config_file)

View File

@ -1,7 +1,7 @@
from pathlib import Path
import pytest
from soundevent import data
from soundevent import data, terms
from batdetect2.targets import (
DeriveTagRule,
@ -11,31 +11,36 @@ from batdetect2.targets import (
TransformConfig,
build_transformation_from_config,
)
from batdetect2.targets.terms import TermRegistry
from batdetect2.targets.transform import (
DerivationRegistry,
build_transform_from_rule,
)
@pytest.fixture
def term_registry():
    """Provide a newly constructed ``TermRegistry`` to the requesting test."""
    registry = TermRegistry()
    return registry
@pytest.fixture
def derivation_registry():
    """Provide a newly constructed ``DerivationRegistry`` to the test."""
    registry = DerivationRegistry()
    return registry
@pytest.fixture
def term1(term_registry: TermRegistry) -> data.Term:
return term_registry.add_custom_term(key="term1")
def term1() -> data.Term:
term = data.Term(label="Term 1", definition="unknown", name="test:term1")
terms.add_term(term, key="term1", force=True)
return term
@pytest.fixture
def term2(term_registry: TermRegistry) -> data.Term:
return term_registry.add_custom_term(key="term2")
def term2() -> data.Term:
term = data.Term(label="Term 2", definition="unknown", name="test:term2")
terms.add_term(term, key="term2", force=True)
return term
@pytest.fixture
def term3() -> data.Term:
    """Create the "Term 3" test term and register it under key "term3".

    ``force=True`` is passed to ``terms.add_term``, presumably so that
    registering the same key again in a later test does not fail.
    """
    registered = data.Term(
        label="Term 3",
        definition="unknown",
        name="test:term3",
    )
    terms.add_term(registered, key="term3", force=True)
    return registered
@pytest.fixture
@ -47,46 +52,45 @@ def annotation(
sound_event=sound_event, tags=[data.Tag(term=term1, value="value1")]
)
@pytest.fixture
def annotation2(
    sound_event: data.SoundEvent,
    term2: data.Term,
) -> data.SoundEventAnnotation:
    """Provide an annotation of ``sound_event`` tagged ``term2``="value2"."""
    tag = data.Tag(term=term2, value="value2")
    return data.SoundEventAnnotation(sound_event=sound_event, tags=[tag])
def test_map_value_rule(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
):
def test_map_value_rule(annotation: data.SoundEventAnnotation):
rule = MapValueRule(
rule_type="map_value",
source_term_key="term1",
value_mapping={"value1": "value2"},
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transform_fn = build_transform_from_rule(rule)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].value == "value2"
def test_map_value_rule_no_match(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
):
def test_map_value_rule_no_match(annotation: data.SoundEventAnnotation):
rule = MapValueRule(
rule_type="map_value",
source_term_key="term1",
value_mapping={"other_value": "value2"},
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transform_fn = build_transform_from_rule(rule)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].value == "value1"
def test_replace_rule(
annotation: data.SoundEventAnnotation,
term2: data.Term,
term_registry: TermRegistry,
):
def test_replace_rule(annotation: data.SoundEventAnnotation, term2: data.Term):
rule = ReplaceRule(
rule_type="replace",
original=TagInfo(key="term1", value="value1"),
replacement=TagInfo(key="term2", value="value2"),
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transform_fn = build_transform_from_rule(rule)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].term == term2
assert transformed_annotation.tags[0].value == "value2"
@ -94,7 +98,7 @@ def test_replace_rule(
def test_replace_rule_no_match(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
term1: data.Term,
term2: data.Term,
):
rule = ReplaceRule(
@ -102,16 +106,19 @@ def test_replace_rule_no_match(
original=TagInfo(key="term1", value="wrong_value"),
replacement=TagInfo(key="term2", value="value2"),
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transform_fn = build_transform_from_rule(rule)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].key == "term1"
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].term != term2
assert transformed_annotation.tags[0].value == "value1"
def test_build_transformation_from_config(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
annotation2: data.SoundEventAnnotation,
term1: data.Term,
term2: data.Term,
term3: data.Term,
):
config = TransformConfig(
rules=[
@ -127,20 +134,20 @@ def test_build_transformation_from_config(
),
]
)
term_registry.add_custom_term("term2")
term_registry.add_custom_term("term3")
transform = build_transformation_from_config(
config,
term_registry=term_registry,
)
transform = build_transformation_from_config(config)
transformed_annotation = transform(annotation)
assert transformed_annotation.tags[0].key == "term1"
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].term != term2
assert transformed_annotation.tags[0].value == "value2"
transformed_annotation = transform(annotation2)
assert transformed_annotation.tags[0].term == term3
assert transformed_annotation.tags[0].value == "value3"
def test_derive_tag_rule(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
derivation_registry: DerivationRegistry,
term1: data.Term,
):
@ -156,7 +163,6 @@ def test_derive_tag_rule(
)
transform_fn = build_transform_from_rule(
rule,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
transformed_annotation = transform_fn(annotation)
@ -170,7 +176,6 @@ def test_derive_tag_rule(
def test_derive_tag_rule_keep_source_false(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
derivation_registry: DerivationRegistry,
term1: data.Term,
):
@ -187,7 +192,6 @@ def test_derive_tag_rule_keep_source_false(
)
transform_fn = build_transform_from_rule(
rule,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
transformed_annotation = transform_fn(annotation)
@ -199,7 +203,6 @@ def test_derive_tag_rule_keep_source_false(
def test_derive_tag_rule_target_term(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
derivation_registry: DerivationRegistry,
term1: data.Term,
term2: data.Term,
@ -217,7 +220,6 @@ def test_derive_tag_rule_target_term(
)
transform_fn = build_transform_from_rule(
rule,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
transformed_annotation = transform_fn(annotation)
@ -231,7 +233,6 @@ def test_derive_tag_rule_target_term(
def test_derive_tag_rule_import_derivation(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
term1: data.Term,
tmp_path: Path,
):
@ -256,7 +257,7 @@ def my_imported_derivation(x: str) -> str:
derivation_function="temp_derivation.my_imported_derivation",
import_derivation=True,
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transform_fn = build_transform_from_rule(rule)
transformed_annotation = transform_fn(annotation)
assert len(transformed_annotation.tags) == 2
@ -269,14 +270,14 @@ def my_imported_derivation(x: str) -> str:
sys.path.remove(str(tmp_path))
def test_derive_tag_rule_invalid_derivation(term_registry: TermRegistry):
def test_derive_tag_rule_invalid_derivation():
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="nonexistent_derivation",
)
with pytest.raises(KeyError):
build_transform_from_rule(rule, term_registry=term_registry)
build_transform_from_rule(rule)
def test_build_transform_from_rule_invalid_rule_type():
@ -291,7 +292,6 @@ def test_build_transform_from_rule_invalid_rule_type():
def test_map_value_rule_target_term(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
term2: data.Term,
):
rule = MapValueRule(
@ -300,7 +300,7 @@ def test_map_value_rule_target_term(
value_mapping={"value1": "value2"},
target_term_key="term2",
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transform_fn = build_transform_from_rule(rule)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].term == term2
assert transformed_annotation.tags[0].value == "value2"
@ -308,7 +308,6 @@ def test_map_value_rule_target_term(
def test_map_value_rule_target_term_none(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
term1: data.Term,
):
rule = MapValueRule(
@ -317,7 +316,7 @@ def test_map_value_rule_target_term_none(
value_mapping={"value1": "value2"},
target_term_key=None,
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transform_fn = build_transform_from_rule(rule)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value2"
@ -325,7 +324,6 @@ def test_map_value_rule_target_term_none(
def test_derive_tag_rule_target_term_none(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
derivation_registry: DerivationRegistry,
term1: data.Term,
):
@ -342,7 +340,6 @@ def test_derive_tag_rule_target_term_none(
)
transform_fn = build_transform_from_rule(
rule,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
transformed_annotation = transform_fn(annotation)

View File

@ -1,170 +0,0 @@
from collections.abc import Callable
import pytest
import torch
from soundevent import data
from batdetect2.train.augmentations import (
add_echo,
mix_audio,
)
from batdetect2.train.clips import select_subclip
from batdetect2.train.preprocess import generate_train_example
from batdetect2.typing import AudioLoader, ClipLabeller, PreprocessorProtocol
def test_mix_examples(
    sample_preprocessor: PreprocessorProtocol,
    sample_audio_loader: AudioLoader,
    sample_labeller: ClipLabeller,
    create_recording: Callable[..., data.Recording],
):
    """Check that mixing two examples keeps the first example's shapes."""
    recordings = [create_recording(), create_recording()]
    windows = [(0.2, 0.7), (0.3, 0.8)]
    annotations = [
        data.ClipAnnotation(
            clip=data.Clip(recording=rec, start_time=start, end_time=end)
        )
        for rec, (start, end) in zip(recordings, windows)
    ]

    base, other = [
        generate_train_example(
            annotation,
            audio_loader=sample_audio_loader,
            preprocessor=sample_preprocessor,
            labeller=sample_labeller,
        )
        for annotation in annotations
    ]

    mixed = mix_audio(
        base,
        other,
        weight=0.3,
        preprocessor=sample_preprocessor,
    )

    for name in (
        "spectrogram",
        "detection_heatmap",
        "size_heatmap",
        "class_heatmap",
    ):
        assert getattr(mixed, name).shape == getattr(base, name).shape
@pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
@pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
def test_mix_examples_of_different_durations(
    sample_preprocessor: PreprocessorProtocol,
    sample_audio_loader: AudioLoader,
    sample_labeller: ClipLabeller,
    create_recording: Callable[..., data.Recording],
    duration1: float,
    duration2: float,
):
    """Check mixed shapes follow the first example for every duration pair."""
    recordings = [create_recording(), create_recording()]
    annotations = [
        data.ClipAnnotation(
            clip=data.Clip(recording=rec, start_time=0, end_time=duration)
        )
        for rec, duration in zip(recordings, (duration1, duration2))
    ]

    base, other = [
        generate_train_example(
            annotation,
            audio_loader=sample_audio_loader,
            preprocessor=sample_preprocessor,
            labeller=sample_labeller,
        )
        for annotation in annotations
    ]

    mixed = mix_audio(
        base,
        other,
        weight=0.3,
        preprocessor=sample_preprocessor,
    )

    for name in (
        "spectrogram",
        "detection_heatmap",
        "size_heatmap",
        "class_heatmap",
    ):
        assert getattr(mixed, name).shape == getattr(base, name).shape
def test_add_echo(
    sample_preprocessor: PreprocessorProtocol,
    sample_audio_loader: AudioLoader,
    sample_labeller: ClipLabeller,
    create_recording: Callable[..., data.Recording],
):
    """Check that adding an echo leaves shapes and target heatmaps intact."""
    clip = data.Clip(
        recording=create_recording(),
        start_time=0.2,
        end_time=0.7,
    )
    original = generate_train_example(
        data.ClipAnnotation(clip=clip),
        audio_loader=sample_audio_loader,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )

    with_echo = add_echo(
        original,
        preprocessor=sample_preprocessor,
        delay=0.1,
        weight=0.3,
    )

    assert with_echo.spectrogram.shape == original.spectrogram.shape

    # Zero tolerances: the three heatmaps must be exactly equal.
    for name in ("size_heatmap", "class_heatmap", "detection_heatmap"):
        torch.testing.assert_close(
            getattr(with_echo, name),
            getattr(original, name),
            atol=0,
            rtol=0,
        )
def test_selected_random_subclip_has_the_correct_width(
    sample_preprocessor: PreprocessorProtocol,
    sample_audio_loader: AudioLoader,
    sample_labeller: ClipLabeller,
    create_recording: Callable[..., data.Recording],
):
    """Check that a 0.512 duration subclip has spectrogram width 512."""
    clip = data.Clip(
        recording=create_recording(),
        start_time=0.2,
        end_time=0.7,
    )
    example = generate_train_example(
        data.ClipAnnotation(clip=clip),
        audio_loader=sample_audio_loader,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )

    selection = {
        "input_samplerate": 256_000,
        "output_samplerate": 1000,
        "start": 0,
        "duration": 0.512,
    }
    subclip = select_subclip(example, **selection)

    assert subclip.spectrogram.shape[1] == 512

View File

@ -1,27 +0,0 @@
from soundevent import data
from batdetect2.train import generate_train_example
from batdetect2.typing import (
AudioLoader,
ClipLabeller,
ClipperProtocol,
PreprocessorProtocol,
)
def test_default_clip_size_is_correct(
    sample_clipper: ClipperProtocol,
    sample_labeller: ClipLabeller,
    sample_audio_loader: AudioLoader,
    clip_annotation: data.ClipAnnotation,
    sample_preprocessor: PreprocessorProtocol,
):
    """Check the spectrogram shape of a clip cut by ``sample_clipper``."""
    train_example = generate_train_example(
        clip_annotation=clip_annotation,
        audio_loader=sample_audio_loader,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )

    # The clipper returns three values; only the first is inspected here.
    clipped, _, _ = sample_clipper(train_example)

    expected_shape = (1, 128, 256)
    assert clipped.spectrogram.shape == expected_shape

View File

@ -56,7 +56,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
clip_annotation,
torch.rand([100, 100]),
torch.rand([1, 100, 100]),
min_freq=0,
max_freq=100,
targets=targets,
@ -67,4 +67,4 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
assert size_heatmap[1, 10, 10] == 20
assert class_heatmap[pippip_index, 10, 10] == 1.0
assert class_heatmap[myomyo_index, 10, 10] == 0.0
assert detection_heatmap[10, 10] == 1.0
assert detection_heatmap[0, 10, 10] == 1.0

View File

@ -4,14 +4,16 @@ import lightning as L
import torch
from soundevent import data
from batdetect2.models import build_model
from batdetect2.train import FullTrainingConfig, TrainingModule
from batdetect2.train.train import build_training_module
from batdetect2.typing.preprocess import AudioLoader
def build_default_module():
model = build_model()
config = FullTrainingConfig()
return build_training_module(config)
return build_training_module(model, config=config)
def test_can_initialize_default_module():
@ -32,14 +34,14 @@ def test_can_save_checkpoint(
recovered = TrainingModule.load_from_checkpoint(path)
wav = torch.tensor(sample_audio_loader.load_clip(clip))
wav = torch.tensor(sample_audio_loader.load_clip(clip)).unsqueeze(0)
spec1 = module.model.preprocessor(wav)
spec2 = recovered.model.preprocessor(wav)
torch.testing.assert_close(spec1, spec2, rtol=0, atol=0)
output1 = module(spec1.unsqueeze(0).unsqueeze(0))
output2 = recovered(spec2.unsqueeze(0).unsqueeze(0))
output1 = module(spec1.unsqueeze(0))
output2 = recovered(spec2.unsqueeze(0))
torch.testing.assert_close(output1, output2, rtol=0, atol=0)

View File

@ -1,230 +0,0 @@
import pytest
from soundevent import data
from soundevent.terms import get_term
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config
from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.preprocess import generate_train_example
from batdetect2.typing import ModelOutput
from batdetect2.typing.preprocess import AudioLoader
@pytest.fixture
def build_from_config(
    create_temp_yaml,
):
    """Provide a factory that builds the pipeline pieces from YAML text.

    The returned callable passes the YAML text to ``create_temp_yaml`` to
    obtain a config path, loads the ``targets``, ``preprocessing``,
    ``labels`` and ``postprocessing`` fields from it, and returns the tuple
    ``(targets, preprocessor, labeller, postprocessor)``.
    """

    def build(yaml_content):
        path = create_temp_yaml(yaml_content)

        # Load every config section before building any component.
        target_cfg = load_target_config(path, field="targets")
        preprocess_cfg = load_preprocessing_config(
            path,
            field="preprocessing",
        )
        label_cfg = load_label_config(path, field="labels")
        postprocess_cfg = load_postprocess_config(
            path,
            field="postprocessing",
        )

        targets = build_targets(target_cfg)
        preprocessor = build_preprocessor(preprocess_cfg)
        # The labeller takes its frequency range from the preprocessor.
        labeller = build_clip_labeler(
            targets=targets,
            config=label_cfg,
            min_freq=preprocessor.min_freq,
            max_freq=preprocessor.max_freq,
        )
        postprocessor = build_postprocessor(
            preprocessor=preprocessor,
            config=postprocess_cfg,
        )
        return targets, preprocessor, labeller, postprocessor

    return build
# Round-trip check: encode one annotated sound event into training heatmaps,
# hand those heatmaps to the postprocessor as if they were model output, and
# assert that the decoded prediction matches the original geometry and tags.
def test_encoding_decoding_roundtrip_recovers_object(
sample_audio_loader: AudioLoader,
build_from_config,
recording,
):
# NOTE(review): the YAML literal below shows no nesting indentation in this
# view; confirm the intended structure against the original file.
yaml_content = """
labels:
targets:
roi:
name: anchor_bbox
anchor: bottom-left
classes:
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
generic_class:
- key: order
value: Chiroptera
preprocessing:
"""
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
# One bounding-box annotation carrying the species tag listed for `pippip`.
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
],
)
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
encoded = generate_train_example(
clip_annotation,
sample_audio_loader,
preprocessor,
labeller,
)
# Wrap the encoded heatmaps in a ModelOutput, adding leading axes with
# unsqueeze(0), and decode them for this single clip.
predictions = postprocessor.get_predictions(
ModelOutput(
detection_probs=encoded.detection_heatmap.unsqueeze(0).unsqueeze(
0
),
size_preds=encoded.size_heatmap.unsqueeze(0),
class_probs=encoded.class_heatmap.unsqueeze(0),
features=encoded.spectrogram.unsqueeze(0).unsqueeze(0),
),
[clip],
)[0]
assert isinstance(predictions, data.ClipPrediction)
assert len(predictions.sound_events) == 1
recovered = predictions.sound_events[0]
assert recovered.sound_event.geometry is not None
assert isinstance(recovered.sound_event.geometry, data.BoundingBox)
start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
recovered.sound_event.geometry.coordinates
)
start_time_or, low_freq_or, end_time_or, high_freq_or = (
geometry.coordinates
)
# Recovered coordinates must match the original within absolute tolerances
# of 0.01 for the time bounds and 1_000 for the frequency bounds.
assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)
# Two predicted tags are expected, species and order, each with score 1.
# The order tag presumably comes from the `generic_class` entry -- confirm.
assert len(recovered.tags) == 2
predicted_species_tag = next(
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
None,
)
assert predicted_species_tag is not None
assert predicted_species_tag.score == 1
assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"
predicted_order_tag = next(
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
None,
)
assert predicted_order_tag is not None
assert predicted_order_tag.score == 1
assert predicted_order_tag.tag.value == "Chiroptera"
# Round-trip check for an annotation tagged "Myotis myotis": encode it into
# training heatmaps, decode those heatmaps with the postprocessor, and assert
# that the decoded prediction matches the original geometry and tags.
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
sample_audio_loader: AudioLoader,
build_from_config,
recording,
):
# NOTE(review): the YAML literal below shows no nesting indentation in this
# view; the second `roi` entry presumably belongs to the `myomyo` class --
# confirm the intended structure against the original file.
yaml_content = """
labels:
targets:
roi:
name: anchor_bbox
anchor: bottom-left
classes:
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
- name: myomyo
tags:
- key: species
value: Myotis myotis
roi:
name: anchor_bbox
anchor: top-left
generic_class:
- key: order
value: Chiroptera
preprocessing:
"""
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
# One bounding-box annotation carrying the species tag listed for `myomyo`.
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
)
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
encoded = generate_train_example(
clip_annotation,
sample_audio_loader,
preprocessor,
labeller,
)
# Wrap the encoded heatmaps in a ModelOutput, adding leading axes with
# unsqueeze(0), and decode them for this single clip.
predictions = postprocessor.get_predictions(
ModelOutput(
detection_probs=encoded.detection_heatmap.unsqueeze(0).unsqueeze(
0
),
size_preds=encoded.size_heatmap.unsqueeze(0),
class_probs=encoded.class_heatmap.unsqueeze(0),
features=encoded.spectrogram.unsqueeze(0).unsqueeze(0),
),
[clip],
)[0]
assert isinstance(predictions, data.ClipPrediction)
assert len(predictions.sound_events) == 1
recovered = predictions.sound_events[0]
assert recovered.sound_event.geometry is not None
assert isinstance(recovered.sound_event.geometry, data.BoundingBox)
start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
recovered.sound_event.geometry.coordinates
)
start_time_or, low_freq_or, end_time_or, high_freq_or = (
geometry.coordinates
)
# Recovered coordinates must match the original within absolute tolerances
# of 0.01 for the time bounds and 1_000 for the frequency bounds.
assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)
# Two predicted tags are expected, species and order, each with score 1.
# The order tag presumably comes from the `generic_class` entry -- confirm.
assert len(recovered.tags) == 2
predicted_species_tag = next(
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
None,
)
assert predicted_species_tag is not None
assert predicted_species_tag.score == 1
assert predicted_species_tag.tag.value == "Myotis myotis"
predicted_order_tag = next(
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
None,
)
assert predicted_order_tag is not None
assert predicted_order_tag.score == 1
assert predicted_order_tag.tag.value == "Chiroptera"