From cf6d0d1ccc94c5fbcb83ba157bf9d4e5ca5d9ae1 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sun, 7 Sep 2025 11:03:46 +0100 Subject: [PATCH] Remove stale tests --- src/batdetect2/data/annotations/legacy.py | 13 +- src/batdetect2/models/__init__.py | 7 +- src/batdetect2/targets/__init__.py | 42 +-- src/batdetect2/targets/classes.py | 32 +- src/batdetect2/targets/filtering.py | 24 +- src/batdetect2/targets/terms.py | 419 +--------------------- src/batdetect2/targets/transform.py | 53 +-- src/batdetect2/train/lightning.py | 2 +- src/batdetect2/train/train.py | 9 +- tests/conftest.py | 26 +- tests/test_targets/test_targets.py | 16 +- tests/test_targets/test_terms.py | 164 +-------- tests/test_targets/test_transform.py | 105 +++--- tests/test_train/test_augmentations.py | 170 --------- tests/test_train/test_clips.py | 27 -- tests/test_train/test_labels.py | 4 +- tests/test_train/test_lightning.py | 10 +- tests/test_train/test_preprocessing.py | 230 ------------ 18 files changed, 138 insertions(+), 1215 deletions(-) delete mode 100644 tests/test_train/test_augmentations.py delete mode 100644 tests/test_train/test_clips.py delete mode 100644 tests/test_train/test_preprocessing.py diff --git a/src/batdetect2/data/annotations/legacy.py b/src/batdetect2/data/annotations/legacy.py index 7c1a383..0b443bb 100644 --- a/src/batdetect2/data/annotations/legacy.py +++ b/src/batdetect2/data/annotations/legacy.py @@ -8,8 +8,6 @@ from typing import Callable, List, Optional, Union from pydantic import BaseModel, Field from soundevent import data -from batdetect2.targets import get_term_from_key - PathLike = Union[Path, str, os.PathLike] __all__ = [] @@ -92,15 +90,15 @@ def annotation_to_sound_event( sound_event=sound_event, tags=[ data.Tag( - term=get_term_from_key(label_key), + key=label_key, # type: ignore value=annotation.label, ), data.Tag( - term=get_term_from_key(event_key), + key=event_key, # type: ignore value=annotation.event, ), data.Tag( - term=get_term_from_key(individual_key), + key=individual_key, # type: ignore value=str(annotation.individual), ), ], @@ -125,7 +123,7 @@ def file_annotation_to_clip( time_expansion=file_annotation.time_exp, tags=[ data.Tag( - term=get_term_from_key(label_key), + key=label_key, # type: ignore value=file_annotation.label, ) ], @@ -157,7 +155,8 @@ def file_annotation_to_clip_annotation( notes=notes, tags=[ data.Tag( - term=get_term_from_key(label_key), value=file_annotation.label + key=label_key, # type: ignore + value=file_annotation.label, ) ], sound_events=[ diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index 1e2fe18..0cb2e9a 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -68,7 +68,10 @@ from batdetect2.postprocess import PostprocessConfig, build_postprocessor from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.targets import TargetConfig, build_targets from batdetect2.typing.models import DetectionModel -from batdetect2.typing.postprocess import DetectionsArray, PostprocessorProtocol +from batdetect2.typing.postprocess import ( + DetectionsTensor, + PostprocessorProtocol, +) from batdetect2.typing.preprocess import PreprocessorProtocol from batdetect2.typing.targets import TargetProtocol @@ -122,7 +125,7 @@ class Model(LightningModule): self.targets = targets self.save_hyperparameters() - def forward(self, wav: torch.Tensor) -> List[DetectionsArray]: + def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]: spec = self.preprocessor(wav) outputs = self.detector(spec) return self.postprocessor(outputs) diff --git a/src/batdetect2/targets/__init__.py b/src/batdetect2/targets/__init__.py index 2114bcf..a635d51 100644 --- a/src/batdetect2/targets/__init__.py +++ b/src/batdetect2/targets/__init__.py @@ -57,14 +57,9 @@ from batdetect2.targets.rois import ( ) from batdetect2.targets.terms import ( TagInfo, - TermInfo, - TermRegistry, call_type, - default_term_registry, get_tag_from_info, - get_term_from_key, individual, - register_term, ) from batdetect2.targets.transform import ( DerivationRegistry, @@ -100,7 +95,6 @@ __all__ = [ "TargetClass", "TargetConfig", "Targets", - "TermInfo", "TransformConfig", "build_generic_class_tags", "build_roi_mapper", @@ -112,7 +106,6 @@ __all__ = [ "get_class_names_from_config", "get_derivation", "get_tag_from_info", - "get_term_from_key", "individual", "load_classes_config", "load_decoder_from_config", @@ -123,7 +116,6 @@ __all__ = [ "load_transformation_config", "load_transformation_from_config", "register_derivation", - "register_term", ] @@ -534,7 +526,6 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig( def build_targets( config: Optional[TargetConfig] = None, - term_registry: TermRegistry = default_term_registry, derivation_registry: DerivationRegistry = default_derivation_registry, ) -> Targets: """Build a Targets object from a loaded TargetConfig. @@ -550,9 +541,6 @@ def build_targets( ---------- config : TargetConfig The loaded and validated unified target configuration object. - term_registry : TermRegistry, optional - The TermRegistry instance to use for resolving term keys. Defaults - to the global `batdetect2.targets.terms.term_registry`. derivation_registry : DerivationRegistry, optional The DerivationRegistry instance to use for resolving derivation function names. Defaults to the global @@ -578,25 +566,15 @@ def build_targets( ) filter_fn = ( - build_sound_event_filter( - config.filtering, - term_registry=term_registry, - ) + build_sound_event_filter(config.filtering) if config.filtering else None ) - encode_fn = build_sound_event_encoder( - config.classes, - term_registry=term_registry, - ) - decode_fn = build_sound_event_decoder( - config.classes, - term_registry=term_registry, - ) + encode_fn = build_sound_event_encoder(config.classes) + decode_fn = build_sound_event_decoder(config.classes) transform_fn = ( build_transformation_from_config( config.transforms, - term_registry=term_registry, derivation_registry=derivation_registry, ) if config.transforms @@ -604,10 +582,7 @@ def build_targets( ) roi_mapper = build_roi_mapper(config.roi) class_names = get_class_names_from_config(config.classes) - generic_class_tags = build_generic_class_tags( - config.classes, - term_registry=term_registry, - ) + generic_class_tags = build_generic_class_tags(config.classes) roi_overrides = { class_config.name: build_roi_mapper(class_config.roi) for class_config in config.classes.classes @@ -629,7 +604,6 @@ def build_targets( def load_targets( config_path: data.PathLike, field: Optional[str] = None, - term_registry: TermRegistry = default_term_registry, derivation_registry: DerivationRegistry = default_derivation_registry, ) -> Targets: """Load a Targets object directly from a configuration file. @@ -645,8 +619,6 @@ def load_targets( field : str, optional Dot-separated path to a nested section within the file containing the target configuration. If None, the entire file content is used. - term_registry : TermRegistry, optional - The TermRegistry instance to use. Defaults to the global default. derivation_registry : DerivationRegistry, optional The DerivationRegistry instance to use. Defaults to the global default. @@ -670,11 +642,7 @@ def load_targets( config_path, field=field, ) - return build_targets( - config, - term_registry=term_registry, - derivation_registry=derivation_registry, - ) + return build_targets(config, derivation_registry=derivation_registry) def iterate_encoded_sound_events( diff --git a/src/batdetect2/targets/classes.py b/src/batdetect2/targets/classes.py index 95d339c..fed7170 100644 --- a/src/batdetect2/targets/classes.py +++ b/src/batdetect2/targets/classes.py @@ -10,8 +10,6 @@ from batdetect2.targets.rois import ROIMapperConfig from batdetect2.targets.terms import ( GENERIC_CLASS_KEY, TagInfo, - TermRegistry, - default_term_registry, get_tag_from_info, ) from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder @@ -295,10 +293,7 @@ def _encode_with_multiple_classifiers( return None -def build_sound_event_encoder( - config: ClassesConfig, - term_registry: TermRegistry = default_term_registry, -) -> SoundEventEncoder: +def build_sound_event_encoder(config: ClassesConfig) -> SoundEventEncoder: """Build a sound event encoder function from the classes configuration. The returned encoder function iterates through the class definitions in the @@ -333,8 +328,7 @@ def build_sound_event_encoder( partial( is_target_class, tags={ - get_tag_from_info(tag_info, term_registry=term_registry) - for tag_info in class_info.tags + get_tag_from_info(tag_info) for tag_info in class_info.tags }, match_all=class_info.match_type == "all", ), @@ -391,7 +385,6 @@ def _decode_class( def build_sound_event_decoder( config: ClassesConfig, - term_registry: TermRegistry = default_term_registry, raise_on_unmapped: bool = False, ) -> SoundEventDecoder: """Build a sound event decoder function from the classes configuration. @@ -433,8 +426,7 @@ def build_sound_event_decoder( else class_info.tags ) mapping[class_info.name] = [ - get_tag_from_info(tag_info, term_registry=term_registry) - for tag_info in tags_to_use + get_tag_from_info(tag_info) for tag_info in tags_to_use ] return partial( @@ -444,10 +436,7 @@ def build_sound_event_decoder( ) -def build_generic_class_tags( - config: ClassesConfig, - term_registry: TermRegistry = default_term_registry, -) -> List[data.Tag]: +def build_generic_class_tags(config: ClassesConfig) -> List[data.Tag]: """Extract and build the list of tags for the generic class from config. Converts the list of `TagInfo` objects defined in `config.generic_class` @@ -472,10 +461,7 @@ def build_generic_class_tags( If a term key specified in `config.generic_class` is not found in the provided `term_registry`. """ - return [ - get_tag_from_info(tag_info, term_registry=term_registry) - for tag_info in config.generic_class - ] + return [get_tag_from_info(tag_info) for tag_info in config.generic_class] def load_classes_config( @@ -509,9 +495,7 @@ def load_classes_config( def load_encoder_from_config( - path: data.PathLike, - field: Optional[str] = None, - term_registry: TermRegistry = default_term_registry, + path: data.PathLike, field: Optional[str] = None ) -> SoundEventEncoder: """Load a class encoder function directly from a configuration file. @@ -546,13 +530,12 @@ def load_encoder_from_config( provided `term_registry` during the build process. """ config = load_classes_config(path, field=field) - return build_sound_event_encoder(config, term_registry=term_registry) + return build_sound_event_encoder(config) def load_decoder_from_config( path: data.PathLike, field: Optional[str] = None, - term_registry: TermRegistry = default_term_registry, raise_on_unmapped: bool = False, ) -> SoundEventDecoder: """Load a class decoder function directly from a configuration file. @@ -594,6 +577,5 @@ def load_decoder_from_config( config = load_classes_config(path, field=field) return build_sound_event_decoder( config, - term_registry=term_registry, raise_on_unmapped=raise_on_unmapped, ) diff --git a/src/batdetect2/targets/filtering.py b/src/batdetect2/targets/filtering.py index e532cc5..462f7e4 100644 --- a/src/batdetect2/targets/filtering.py +++ b/src/batdetect2/targets/filtering.py @@ -8,8 +8,6 @@ from soundevent import data from batdetect2.configs import BaseConfig, load_config from batdetect2.targets.terms import ( TagInfo, - TermRegistry, - default_term_registry, get_tag_from_info, ) from batdetect2.typing.targets import SoundEventFilter @@ -146,10 +144,7 @@ def equal_tags( return tags == sound_event_tags -def build_filter_from_rule( - rule: FilterRule, - term_registry: TermRegistry = default_term_registry, -) -> SoundEventFilter: +def build_filter_from_rule(rule: FilterRule) -> SoundEventFilter: """Creates a callable filter function from a single FilterRule. Parameters @@ -168,10 +163,7 @@ def build_filter_from_rule( ValueError If the rule contains an invalid `match_type`. """ - tag_set = { - get_tag_from_info(tag_info, term_registry=term_registry) - for tag_info in rule.tags - } + tag_set = {get_tag_from_info(tag_info) for tag_info in rule.tags} if rule.match_type == "any": return partial(has_any_tag, tags=tag_set) @@ -235,7 +227,6 @@ class FilterConfig(BaseConfig): def build_sound_event_filter( config: FilterConfig, - term_registry: TermRegistry = default_term_registry, ) -> SoundEventFilter: """Builds a merged filter function from a FilterConfig object. @@ -252,10 +243,7 @@ def build_sound_event_filter( SoundEventFilter A single callable filter function that applies all defined rules. """ - filters = [ - build_filter_from_rule(rule, term_registry=term_registry) - for rule in config.rules - ] + filters = [build_filter_from_rule(rule) for rule in config.rules] return partial(_passes_all_filters, filters=filters) @@ -281,9 +269,7 @@ def load_filter_config( def load_filter_from_config( - path: data.PathLike, - field: Optional[str] = None, - term_registry: TermRegistry = default_term_registry, + path: data.PathLike, field: Optional[str] = None, ) -> SoundEventFilter: """Loads filter configuration from a file and builds the filter function. @@ -304,4 +290,4 @@ def load_filter_from_config( The final merged filter function ready to be used. """ config = load_filter_config(path=path, field=field) - return build_sound_event_filter(config, term_registry=term_registry) + return build_sound_event_filter(config) diff --git a/src/batdetect2/targets/terms.py b/src/batdetect2/targets/terms.py index 6247b67..88b3576 100644 --- a/src/batdetect2/targets/terms.py +++ b/src/batdetect2/targets/terms.py @@ -1,33 +1,22 @@ -"""Manages the vocabulary (Terms and Tags) for defining training targets. +"""Manages the vocabulary for defining training targets. This module provides the necessary tools to declare, register, and manage the set of `soundevent.data.Term` objects used throughout the `batdetect2.targets` sub-package. It establishes a consistent vocabulary for filtering, transforming, and classifying sound events based on their annotations (Tags). -The core component is the `TermRegistry`, which maps unique string keys -(aliases) to specific `Term` definitions. This allows users to refer to complex -terms using simple, consistent keys in configuration files and code. - -Terms can be pre-defined, loaded from the `soundevent.terms` library, defined -programmatically, or loaded from external configuration files (e.g., YAML). +Terms can be pre-defined, loaded from the `soundevent.terms` library or defined +programmatically. """ -from collections.abc import Mapping -from inspect import getmembers -from typing import Dict, List, Optional - -from pydantic import BaseModel, Field +from pydantic import BaseModel from soundevent import data, terms -from batdetect2.configs import load_config - __all__ = [ "call_type", "individual", "data_source", "get_tag_from_info", - "TermInfo", "TagInfo", ] @@ -98,255 +87,6 @@ terms.register_term_set( ) -class TermRegistry(Mapping[str, data.Term]): - """Manages a registry mapping unique keys to Term definitions. - - This class acts as the central repository for the vocabulary of terms - used within the target definition process. It allows registering terms - with simple string keys and retrieving them consistently. - """ - - def __init__(self, terms: Optional[Dict[str, data.Term]] = None): - """Initializes the TermRegistry. - - Parameters - ---------- - terms : dict[str, soundevent.data.Term], optional - An optional dictionary of initial key-to-Term mappings - to populate the registry with. Defaults to an empty registry. - """ - self._terms: Dict[str, data.Term] = terms or {} - - def __getitem__(self, key: str) -> data.Term: - return self._terms[key] - - def __len__(self) -> int: - return len(self._terms) - - def __iter__(self): - return iter(self._terms) - - def add_term(self, key: str, term: data.Term) -> None: - """Adds a Term object to the registry with the specified key. - - Parameters - ---------- - key : str - The unique string key to associate with the term. - term : soundevent.data.Term - The soundevent.data.Term object to register. - - Raises - ------ - KeyError - If a term with the provided key already exists in the - registry. - """ - if key in self._terms: - raise KeyError("A term with the provided key already exists.") - - self._terms[key] = term - - def get_term(self, key: str) -> data.Term: - """Retrieves a registered term by its unique key. - - Parameters - ---------- - key : str - The unique string key of the term to retrieve. - - Returns - ------- - soundevent.data.Term - The corresponding soundevent.data.Term object. - - Raises - ------ - KeyError - If no term with the specified key is found, with a - helpful message suggesting listing available keys. - """ - try: - return self._terms[key] - except KeyError as err: - raise KeyError( - "No term found for key " - f"'{key}'. Ensure it is registered or loaded. " - f"Available keys: {', '.join(self.get_keys())}" - ) from err - - def add_custom_term( - self, - key: str, - name: Optional[str] = None, - uri: Optional[str] = None, - label: Optional[str] = None, - definition: Optional[str] = None, - ) -> data.Term: - """Creates a new Term from attributes and adds it to the registry. - - This is useful for defining terms directly in code or when loading - from configuration files where only attributes are provided. - - If optional fields (`name`, `label`, `definition`) are not provided, - reasonable defaults are used (`key` for name/label, "Unknown" for - definition). - - Parameters - ---------- - key : str - The unique string key for the new term. - name : str, optional - The name for the new term (defaults to `key`). - uri : str, optional - The URI for the new term (optional). - label : str, optional - The display label for the new term (defaults to `key`). - definition : str, optional - The definition for the new term (defaults to "Unknown"). - - Returns - ------- - soundevent.data.Term - The newly created and registered soundevent.data.Term object. - - Raises - ------ - KeyError - If a term with the provided key already exists. - """ - term = data.Term( - name=name or key, - label=label or key, - uri=uri, - definition=definition or "Unknown", - ) - self.add_term(key, term) - return term - - def get_keys(self) -> List[str]: - """Returns a list of all keys currently registered. - - Returns - ------- - list[str] - A list of strings representing the keys of all registered terms. - """ - return list(self._terms.keys()) - - def get_terms(self) -> List[data.Term]: - """Returns a list of all registered terms. - - Returns - ------- - list[soundevent.data.Term] - A list containing all registered Term objects. - """ - return list(self._terms.values()) - - def remove_key(self, key: str) -> None: - del self._terms[key] - - -default_term_registry = TermRegistry( - terms=dict( - [ - *getmembers(terms, lambda x: isinstance(x, data.Term)), - ("event", call_type), - ("species", terms.scientific_name), - ("individual", individual), - ("data_source", data_source), - (GENERIC_CLASS_KEY, generic_class), - ] - ) -) -"""The default, globally accessible TermRegistry instance. - -It is pre-populated with standard terms from `soundevent.terms` and common -terms defined in this module (`call_type`, `individual`, `generic_class`). -Functions in this module use this registry by default unless another instance -is explicitly passed. -""" - - -def get_term_from_key( - key: str, - term_registry: Optional[TermRegistry] = None, -) -> data.Term: - """Convenience function to retrieve a term by key from a registry. - - Uses the global default registry unless a specific `term_registry` - instance is provided. - - Parameters - ---------- - key : str - The unique key of the term to retrieve. - term_registry : TermRegistry, optional - The TermRegistry instance to search in. Defaults to the global - `registry`. - - Returns - ------- - soundevent.data.Term - The corresponding soundevent.data.Term object. - - Raises - ------ - KeyError - If the key is not found in the specified registry. - """ - term = terms.get_term(key) - - if term: - return term - - term_registry = term_registry or default_term_registry - return term_registry.get_term(key) - - -def get_term_keys( - term_registry: TermRegistry = default_term_registry, -) -> List[str]: - """Convenience function to get all registered keys from a registry. - - Uses the global default registry unless a specific `term_registry` - instance is provided. - - Parameters - ---------- - term_registry : TermRegistry, optional - The TermRegistry instance to query. Defaults to the global `registry`. - - Returns - ------- - list[str] - A list of strings representing the keys of all registered terms. - """ - return term_registry.get_keys() - - -def get_terms( - term_registry: TermRegistry = default_term_registry, -) -> List[data.Term]: - """Convenience function to get all registered terms from a registry. - - Uses the global default registry unless a specific `term_registry` - instance is provided. - - Parameters - ---------- - term_registry : TermRegistry, optional - The TermRegistry instance to query. Defaults to the global `registry`. - - Returns - ------- - list[soundevent.data.Term] - A list containing all registered Term objects. - """ - return term_registry.get_terms() - - class TagInfo(BaseModel): """Represents information needed to define a specific Tag. @@ -360,31 +100,25 @@ class TagInfo(BaseModel): value : str The value of the tag (e.g., "Myotis myotis", "Echolocation"). key : str, default="class" - The key (alias) of the term associated with this tag, as - registered in the TermRegistry. Defaults to "class", implying - it represents a classification target label by default. + The key (alias) of the term associated with this tag. Defaults to + "class", implying it represents a classification target label by + default. """ value: str key: str = GENERIC_CLASS_KEY -def get_tag_from_info( - tag_info: TagInfo, - term_registry: Optional[TermRegistry] = None, -) -> data.Tag: +def get_tag_from_info(tag_info: TagInfo) -> data.Tag: """Creates a soundevent.data.Tag object from TagInfo data. - Looks up the term using the key in the provided `tag_info` from the - specified registry and constructs a Tag object. + Looks up the term using the key in the provided `tag_info` and constructs a + Tag object. Parameters ---------- tag_info : TagInfo The TagInfo object containing the value and term key. - term_registry : TermRegistry, optional - The TermRegistry instance to use for term lookup. Defaults to the - global `registry`. Returns ------- @@ -394,132 +128,11 @@ def get_tag_from_info( Raises ------ KeyError - If the term key specified in `tag_info.key` is not found - in the registry. + If the term key specified in `tag_info.key` is not found. """ - term_registry = term_registry or default_term_registry - term = get_term_from_key(tag_info.key, term_registry=term_registry) + term = terms.get_term(tag_info.key) + + if not term: + raise KeyError(f"Key {tag_info.key} not found") + return data.Tag(term=term, value=tag_info.value) - - -class TermInfo(BaseModel): - """Represents the definition of a Term within a configuration file. - - This model allows users to define custom terms directly in configuration - files (e.g., YAML) which can then be loaded into the TermRegistry. - It mirrors the parameters of `TermRegistry.add_custom_term`. - - Attributes - ---------- - key : str - The unique key (alias) that will be used to register and - reference this term. - label : str, optional - The optional display label for the term. Defaults to `key` - if not provided during registration. - name : str, optional - The optional formal name for the term. Defaults to `key` - if not provided during registration. - uri : str, optional - The optional URI identifying the term (e.g., from a standard - vocabulary). - definition : str, optional - The optional textual definition of the term. Defaults to - "Unknown" if not provided during registration. - """ - - key: str - label: Optional[str] = None - name: Optional[str] = None - uri: Optional[str] = None - definition: Optional[str] = None - - -class TermConfig(BaseModel): - """Pydantic schema for loading a list of term definitions from config. - - This model typically corresponds to a section in a configuration file - (e.g., YAML) containing a list of term definitions to be registered. - - Attributes - ---------- - terms : list[TermInfo] - A list of TermInfo objects, each defining a term to be - registered. Defaults to an empty list. - - Examples - -------- - Example YAML structure: - - ```yaml - terms: - - key: species - uri: dwc:scientificName - label: Scientific Name - - key: my_custom_term - name: My Custom Term - definition: Describes a specific project attribute. - # ... more TermInfo definitions - ``` - """ - - terms: List[TermInfo] = Field(default_factory=list) - - -def load_terms_from_config( - path: data.PathLike, - field: Optional[str] = None, - term_registry: TermRegistry = default_term_registry, -) -> Dict[str, data.Term]: - """Loads term definitions from a configuration file and registers them. - - Parses a configuration file (e.g., YAML) using the TermConfig schema, - extracts the list of TermInfo definitions, and adds each one as a - custom term to the specified TermRegistry instance. - - Parameters - ---------- - path : data.PathLike - The path to the configuration file. - field : str, optional - Optional key indicating a specific section within the config - file where the 'terms' list is located. If None, expects the - list directly at the top level or within a structure matching - TermConfig schema. - term_registry : TermRegistry, optional - The TermRegistry instance to add the loaded terms to. Defaults to - the global `registry`. - - Returns - ------- - dict[str, soundevent.data.Term] - A dictionary mapping the keys of the newly added terms to their - corresponding Term objects. - - Raises - ------ - FileNotFoundError - If the config file path does not exist. - pydantic.ValidationError - If the config file structure does not match the TermConfig schema. - KeyError - If a term key loaded from the config conflicts with a key - already present in the registry. - """ - data = load_config(path, schema=TermConfig, field=field) - return { - info.key: term_registry.add_custom_term( - info.key, - name=info.name, - uri=info.uri, - label=info.label, - definition=info.definition, - ) - for info in data.terms - } - - -def register_term( - key: str, term: data.Term, registry: TermRegistry = default_term_registry -) -> None: - registry.add_term(key, term) diff --git a/src/batdetect2/targets/transform.py b/src/batdetect2/targets/transform.py index 29056a7..b71e658 100644 --- a/src/batdetect2/targets/transform.py +++ b/src/batdetect2/targets/transform.py @@ -12,15 +12,10 @@ from typing import ( ) from pydantic import Field -from soundevent import data +from soundevent import data, terms from batdetect2.configs import BaseConfig, load_config -from batdetect2.targets.terms import ( - TagInfo, - TermRegistry, - get_tag_from_info, - get_term_from_key, -) +from batdetect2.targets.terms import TagInfo, get_tag_from_info __all__ = [ "DerivationRegistry", @@ -466,7 +461,6 @@ TranformationRule = Annotated[ def build_transform_from_rule( rule: TranformationRule, derivation_registry: Optional[DerivationRegistry] = None, - term_registry: Optional[TermRegistry] = None, ) -> SoundEventTransformation: """Build a specific SoundEventTransformation function from a rule config. @@ -497,29 +491,21 @@ def build_transform_from_rule( If dynamic import of a derivation function fails. """ if rule.rule_type == "replace": - source = get_tag_from_info( - rule.original, - term_registry=term_registry, - ) - target = get_tag_from_info( - rule.replacement, - term_registry=term_registry, - ) + source = get_tag_from_info(rule.original) + target = get_tag_from_info(rule.replacement) return partial(replace_tag_transform, source=source, target=target) if rule.rule_type == "derive_tag": - source_term = get_term_from_key( - rule.source_term_key, - term_registry=term_registry, - ) + source_term = terms.get_term(rule.source_term_key) target_term = ( - get_term_from_key( - rule.target_term_key, - term_registry=term_registry, - ) + terms.get_term(rule.target_term_key) if rule.target_term_key else source_term ) + + if source_term is None or target_term is None: + raise KeyError("Terms not found") + derivation = get_derivation( key=rule.derivation_function, import_derivation=rule.import_derivation, @@ -534,18 +520,16 @@ def build_transform_from_rule( ) if rule.rule_type == "map_value": - source_term = get_term_from_key( - rule.source_term_key, - term_registry=term_registry, - ) + source_term = terms.get_term(rule.source_term_key) target_term = ( - get_term_from_key( - rule.target_term_key, - term_registry=term_registry, - ) + terms.get_term(rule.target_term_key) if rule.target_term_key else source_term ) + + if source_term is None or target_term is None: + raise KeyError("Terms not found") + return partial( map_value_transform, source_term=source_term, @@ -555,6 +539,7 @@ def build_transform_from_rule( # Handle unknown rule type valid_options = ["replace", "derive_tag", "map_value"] + # Should be caught by Pydantic validation, but good practice raise ValueError( f"Invalid transform rule type '{getattr(rule, 'rule_type', 'N/A')}'. " @@ -565,7 +550,6 @@ def build_transform_from_rule( def build_transformation_from_config( config: TransformConfig, derivation_registry: Optional[DerivationRegistry] = None, - term_registry: Optional[TermRegistry] = None, ) -> SoundEventTransformation: """Build a composite transformation function from a TransformConfig. @@ -591,7 +575,6 @@ def build_transformation_from_config( build_transform_from_rule( rule, derivation_registry=derivation_registry, - term_registry=term_registry, ) for rule in config.rules ] @@ -640,7 +623,6 @@ def load_transformation_from_config( path: data.PathLike, field: Optional[str] = None, derivation_registry: Optional[DerivationRegistry] = None, - term_registry: Optional[TermRegistry] = None, ) -> SoundEventTransformation: """Load transformation config from a file and build the final function. @@ -678,7 +660,6 @@ def load_transformation_from_config( return build_transformation_from_config( config, derivation_registry=derivation_registry, - term_registry=term_registry, ) diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index 6e33ed9..bc9edd3 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -31,7 +31,7 @@ class TrainingModule(L.LightningModule): self.save_hyperparameters(logger=False) def forward(self, spec: torch.Tensor) -> ModelOutput: - return self.model(spec) + return self.model.detector(spec) def training_step(self, batch: TrainExample): outputs = self.model.detector(batch.spec) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 62cd7a9..aaa67f8 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -99,7 +99,7 @@ def train( module = build_training_module( model, config, - batches_per_epoch=len(train_dataloader), + t_max=config.train.t_max * len(train_dataloader), ) logger.info("Starting main training loop...") @@ -113,15 +113,16 @@ def train( def build_training_module( model: Model, - config: FullTrainingConfig, - batches_per_epoch: int, + config: Optional[FullTrainingConfig] = None, + t_max: int = 200, ) -> TrainingModule: + config = config or FullTrainingConfig() loss = build_loss(config=config.train.loss) return TrainingModule( model=model, loss=loss, learning_rate=config.train.learning_rate, - t_max=config.train.t_max * batches_per_epoch, + t_max=t_max, ) diff --git a/tests/conftest.py b/tests/conftest.py index d036c62..36c5c9a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,6 @@ from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess.audio import build_audio_loader from batdetect2.targets import ( TargetConfig, - TermRegistry, build_targets, call_type, ) @@ -355,18 +354,6 @@ def create_annotation_project(): return factory -@pytest.fixture -def sample_term_registry() -> TermRegistry: - """Fixture for a sample TermRegistry.""" - registry = TermRegistry() - registry.add_custom_term("class") - registry.add_custom_term("order") - registry.add_custom_term("species") - registry.add_custom_term("call_type") - registry.add_custom_term("quality") - return registry - - @pytest.fixture def sample_preprocessor() -> PreprocessorProtocol: return build_preprocessor() @@ -399,7 +386,6 @@ def pippip_tag() -> TagInfo: @pytest.fixture def sample_target_config( - sample_term_registry: TermRegistry, bat_tag: TagInfo, noise_tag: TagInfo, myomyo_tag: TagInfo, @@ -422,12 +408,8 @@ def sample_target_config( @pytest.fixture def sample_targets( sample_target_config: TargetConfig, - sample_term_registry: TermRegistry, ) -> TargetProtocol: - return build_targets( - sample_target_config, - term_registry=sample_term_registry, - ) + return build_targets(sample_target_config) @pytest.fixture @@ -443,10 +425,8 @@ def sample_labeller( @pytest.fixture -def sample_clipper( - sample_preprocessor: PreprocessorProtocol, -) -> ClipperProtocol: - return build_clipper(preprocessor=sample_preprocessor) +def sample_clipper() -> ClipperProtocol: + return build_clipper() @pytest.fixture diff --git a/tests/test_targets/test_targets.py b/tests/test_targets/test_targets.py index bb4d00f..8324807 100644 --- a/tests/test_targets/test_targets.py +++ b/tests/test_targets/test_targets.py @@ -1,16 +1,14 @@ from collections.abc import Callable from pathlib import Path -from soundevent import data +from soundevent import data, terms from batdetect2.targets import build_targets, load_target_config -from batdetect2.targets.terms import get_term_from_key def test_can_override_default_roi_mapper_per_class( create_temp_yaml: Callable[..., Path], recording: data.Recording, - sample_term_registry, ): yaml_content = """ roi: @@ -36,11 +34,13 @@ def test_can_override_default_roi_mapper_per_class( config_path = create_temp_yaml(yaml_content) config = load_target_config(config_path) - targets = build_targets(config, term_registry=sample_term_registry) + targets = build_targets(config) geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000]) - species = get_term_from_key("species", term_registry=sample_term_registry) + species = terms.get_term("species") + assert species is not None + se1 = data.SoundEventAnnotation( sound_event=data.SoundEvent(recording=recording, geometry=geometry), tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")], @@ -62,7 +62,6 @@ def test_can_override_default_roi_mapper_per_class( # TODO: rename this test function def test_roi_is_recovered_roundtrip_even_with_overriders( create_temp_yaml, - sample_term_registry, recording, ): yaml_content = """ @@ -89,11 +88,12 @@ def test_roi_is_recovered_roundtrip_even_with_overriders( config_path = create_temp_yaml(yaml_content) config = load_target_config(config_path) - targets = build_targets(config, term_registry=sample_term_registry) + targets = build_targets(config) geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000]) - species = get_term_from_key("species", term_registry=sample_term_registry) + species = terms.get_term("species") + assert species is not None se1 = data.SoundEventAnnotation( sound_event=data.SoundEvent(recording=recording, geometry=geometry), tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")], diff --git a/tests/test_targets/test_terms.py b/tests/test_targets/test_terms.py index 37a997e..74fa927 100644 --- a/tests/test_targets/test_terms.py +++ b/tests/test_targets/test_terms.py @@ -1,98 +1,7 @@ import pytest -import yaml -from soundevent import data from batdetect2.targets import terms -from batdetect2.targets.terms import ( - TagInfo, - TermRegistry, - load_terms_from_config, -) - - -def test_term_registry_initialization(): - registry = TermRegistry() - assert registry._terms == {} - - initial_terms = { - "test_term": data.Term(name="test", label="Test", definition="test") - } - registry = TermRegistry(terms=initial_terms) - assert registry._terms == initial_terms - - -def test_term_registry_add_term(): - registry = TermRegistry() - term = data.Term(name="test", label="Test", definition="test") - registry.add_term("test_key", term) - assert registry._terms["test_key"] == term - - -def test_term_registry_get_term(): - registry = TermRegistry() - term = data.Term(name="test", label="Test", definition="test") - registry.add_term("test_key", term) - retrieved_term = registry.get_term("test_key") - assert retrieved_term == term - - -def test_term_registry_add_custom_term(): - registry = TermRegistry() - term = registry.add_custom_term( - "custom_key", name="custom", label="Custom", definition="A custom term" - ) - assert registry._terms["custom_key"] == term - assert term.name == "custom" - assert term.label == "Custom" - assert term.definition == "A custom term" - - -def test_term_registry_add_duplicate_term(): - registry = TermRegistry() - term = data.Term(name="test", label="Test", definition="test") - registry.add_term("test_key", term) - with pytest.raises(KeyError): - registry.add_term("test_key", term) - - -def test_term_registry_get_term_not_found(): - registry = TermRegistry() - with pytest.raises(KeyError): - registry.get_term("non_existent_key") - - -def test_term_registry_get_keys(): - registry = TermRegistry() - term1 = data.Term(name="test1", label="Test1", definition="test") - term2 = data.Term(name="test2", label="Test2", definition="test") - registry.add_term("key1", term1) - registry.add_term("key2", term2) - keys = registry.get_keys() - assert set(keys) == {"key1", "key2"} - - -def test_get_term_from_key(): - term = terms.get_term_from_key("event") - assert term == terms.call_type - - custom_registry = TermRegistry() - custom_term = data.Term(name="custom", label="Custom", definition="test") - custom_registry.add_term("custom_key", custom_term) - term = terms.get_term_from_key("custom_key", term_registry=custom_registry) - assert term == custom_term - - -def test_get_term_keys(): - keys = terms.get_term_keys() - assert "event" in keys - assert "individual" in keys - assert terms.GENERIC_CLASS_KEY in keys - - custom_registry = TermRegistry() - custom_term = data.Term(name="custom", label="Custom", definition="test") - custom_registry.add_term("custom_key", custom_term) - keys = terms.get_term_keys(term_registry=custom_registry) - assert "custom_key" in keys +from batdetect2.targets.terms import TagInfo def test_tag_info_and_get_tag_from_info(): @@ -106,74 +15,3 @@ def test_get_tag_from_info_key_not_found(): tag_info = TagInfo(value="test", key="non_existent_key") with pytest.raises(KeyError): terms.get_tag_from_info(tag_info) - - -def test_load_terms_from_config(tmp_path): - term_registry = TermRegistry() - config_data = { - "terms": [ - { - "key": "species", - "name": "dwc:scientificName", - "label": "Scientific Name", - }, - { - "key": "my_custom_term", - "name": "soundevent:custom_term", - "definition": "Describes a specific project attribute", - }, - ] - } - config_file = tmp_path / "config.yaml" - with open(config_file, "w") as f: - yaml.dump(config_data, f) - - loaded_terms = load_terms_from_config( - config_file, - term_registry=term_registry, - ) - assert "species" in loaded_terms - assert "my_custom_term" in loaded_terms - assert loaded_terms["species"].name == "dwc:scientificName" - assert loaded_terms["my_custom_term"].name == "soundevent:custom_term" - - -def test_load_terms_from_config_file_not_found(): - with pytest.raises(FileNotFoundError): - load_terms_from_config("non_existent_file.yaml") - - -def test_load_terms_from_config_validation_error(tmp_path): - config_data = { - "terms": [ - { - "key": "species", - "uri": "dwc:scientificName", - "label": 123, - }, # Invalid label type - ] - } - config_file = tmp_path / "config.yaml" - with open(config_file, "w") as f: - yaml.dump(config_data, f) - - with pytest.raises(ValueError): - load_terms_from_config(config_file) - - -def test_load_terms_from_config_key_already_exists(tmp_path): - config_data = { - "terms": [ - { - "key": "event", - "uri": "dwc:scientificName", - "label": "Scientific Name", - }, # Duplicate key - ] - } - config_file = tmp_path / "config.yaml" - with open(config_file, "w") as f: - yaml.dump(config_data, f) - - with pytest.raises(KeyError): - load_terms_from_config(config_file) diff --git a/tests/test_targets/test_transform.py b/tests/test_targets/test_transform.py index ababd8d..92c6698 100644 --- a/tests/test_targets/test_transform.py +++ b/tests/test_targets/test_transform.py @@ -1,7 +1,7 @@ from pathlib import Path import pytest -from soundevent import data +from soundevent import data, terms from batdetect2.targets import ( DeriveTagRule, @@ -11,31 +11,36 @@ from batdetect2.targets import ( TransformConfig, build_transformation_from_config, ) -from batdetect2.targets.terms import TermRegistry from batdetect2.targets.transform import ( DerivationRegistry, build_transform_from_rule, ) -@pytest.fixture -def term_registry(): - return TermRegistry() - - @pytest.fixture def derivation_registry(): return DerivationRegistry() @pytest.fixture -def term1(term_registry: TermRegistry) -> data.Term: - return term_registry.add_custom_term(key="term1") +def term1() -> data.Term: + term = data.Term(label="Term 1", definition="unknown", name="test:term1") + terms.add_term(term, key="term1", force=True) + return term @pytest.fixture -def term2(term_registry: TermRegistry) -> data.Term: - return term_registry.add_custom_term(key="term2") +def term2() -> data.Term: + term = data.Term(label="Term 2", definition="unknown", name="test:term2") + terms.add_term(term, key="term2", force=True) + return term + + +@pytest.fixture +def term3() -> data.Term: + term = data.Term(label="Term 3", definition="unknown", name="test:term3") + terms.add_term(term, key="term3", force=True) + return term @pytest.fixture @@ -47,46 +52,45 @@ def annotation( sound_event=sound_event, tags=[data.Tag(term=term1, value="value1")] ) +@pytest.fixture +def annotation2( + sound_event: data.SoundEvent, + term2: data.Term, +) -> data.SoundEventAnnotation: + return data.SoundEventAnnotation( + sound_event=sound_event, tags=[data.Tag(term=term2, value="value2")] + ) -def test_map_value_rule( - annotation: data.SoundEventAnnotation, - term_registry: TermRegistry, -): + +def test_map_value_rule(annotation: data.SoundEventAnnotation): rule = MapValueRule( rule_type="map_value", source_term_key="term1", value_mapping={"value1": "value2"}, ) - transform_fn = build_transform_from_rule(rule, term_registry=term_registry) + transform_fn = build_transform_from_rule(rule) transformed_annotation = transform_fn(annotation) assert transformed_annotation.tags[0].value == "value2" -def test_map_value_rule_no_match( - annotation: data.SoundEventAnnotation, - term_registry: TermRegistry, -): +def test_map_value_rule_no_match(annotation: data.SoundEventAnnotation): rule = MapValueRule( rule_type="map_value", source_term_key="term1", value_mapping={"other_value": "value2"}, ) - transform_fn = build_transform_from_rule(rule, term_registry=term_registry) + transform_fn = build_transform_from_rule(rule) transformed_annotation = transform_fn(annotation) assert transformed_annotation.tags[0].value == "value1" -def test_replace_rule( - annotation: data.SoundEventAnnotation, - term2: data.Term, - term_registry: TermRegistry, -): +def test_replace_rule(annotation: data.SoundEventAnnotation, term2: data.Term): rule = ReplaceRule( rule_type="replace", original=TagInfo(key="term1", value="value1"), replacement=TagInfo(key="term2", value="value2"), ) - transform_fn = build_transform_from_rule(rule, term_registry=term_registry) + transform_fn = build_transform_from_rule(rule) transformed_annotation = transform_fn(annotation) assert transformed_annotation.tags[0].term == term2 assert transformed_annotation.tags[0].value == "value2" @@ -94,7 +98,7 @@ def test_replace_rule( def test_replace_rule_no_match( annotation: data.SoundEventAnnotation, - term_registry: TermRegistry, + term1: data.Term, term2: data.Term, ): rule = ReplaceRule( @@ -102,16 +106,19 @@ def test_replace_rule_no_match( original=TagInfo(key="term1", value="wrong_value"), replacement=TagInfo(key="term2", value="value2"), ) - transform_fn = build_transform_from_rule(rule, term_registry=term_registry) + transform_fn = build_transform_from_rule(rule) transformed_annotation = transform_fn(annotation) - assert transformed_annotation.tags[0].key == "term1" + assert transformed_annotation.tags[0].term == term1 assert transformed_annotation.tags[0].term != term2 assert transformed_annotation.tags[0].value == "value1" def test_build_transformation_from_config( annotation: data.SoundEventAnnotation, - term_registry: TermRegistry, + annotation2: data.SoundEventAnnotation, + term1: data.Term, + term2: data.Term, + term3: data.Term, ): config = TransformConfig( rules=[ @@ -127,20 +134,20 @@ def test_build_transformation_from_config( ), ] ) - term_registry.add_custom_term("term2") - term_registry.add_custom_term("term3") - transform = build_transformation_from_config( - config, - term_registry=term_registry, - ) + transform = build_transformation_from_config(config) + transformed_annotation = transform(annotation) - assert transformed_annotation.tags[0].key == "term1" + assert transformed_annotation.tags[0].term == term1 + assert transformed_annotation.tags[0].term != term2 assert transformed_annotation.tags[0].value == "value2" + transformed_annotation = transform(annotation2) + assert transformed_annotation.tags[0].term == term3 + assert transformed_annotation.tags[0].value == "value3" + def test_derive_tag_rule( annotation: data.SoundEventAnnotation, - term_registry: TermRegistry, derivation_registry: DerivationRegistry, term1: data.Term, ): @@ -156,7 +163,6 @@ def test_derive_tag_rule( ) transform_fn = build_transform_from_rule( rule, - term_registry=term_registry, derivation_registry=derivation_registry, ) transformed_annotation = transform_fn(annotation) @@ -170,7 +176,6 @@ def test_derive_tag_rule( def test_derive_tag_rule_keep_source_false( annotation: data.SoundEventAnnotation, - term_registry: TermRegistry, derivation_registry: DerivationRegistry, term1: data.Term, ): @@ -187,7 +192,6 @@ def test_derive_tag_rule_keep_source_false( ) transform_fn = build_transform_from_rule( rule, - term_registry=term_registry, derivation_registry=derivation_registry, ) transformed_annotation = transform_fn(annotation) @@ -199,7 +203,6 @@ def test_derive_tag_rule_keep_source_false( def test_derive_tag_rule_target_term( annotation: data.SoundEventAnnotation, - term_registry: TermRegistry, derivation_registry: DerivationRegistry, term1: data.Term, term2: data.Term, @@ -217,7 +220,6 @@ def test_derive_tag_rule_target_term( ) transform_fn = build_transform_from_rule( rule, - term_registry=term_registry, derivation_registry=derivation_registry, ) transformed_annotation = transform_fn(annotation) @@ -231,7 +233,6 @@ def test_derive_tag_rule_target_term( def test_derive_tag_rule_import_derivation( annotation: data.SoundEventAnnotation, - term_registry: TermRegistry, term1: data.Term, tmp_path: Path, ): @@ -256,7 +257,7 @@ def my_imported_derivation(x: str) -> str: derivation_function="temp_derivation.my_imported_derivation", import_derivation=True, ) - transform_fn = build_transform_from_rule(rule, term_registry=term_registry) + transform_fn = build_transform_from_rule(rule) transformed_annotation = transform_fn(annotation) assert len(transformed_annotation.tags) == 2 @@ -269,14 +270,14 @@ def my_imported_derivation(x: str) -> str: sys.path.remove(str(tmp_path)) -def test_derive_tag_rule_invalid_derivation(term_registry: TermRegistry): +def test_derive_tag_rule_invalid_derivation(): rule = DeriveTagRule( rule_type="derive_tag", source_term_key="term1", derivation_function="nonexistent_derivation", ) with pytest.raises(KeyError): - build_transform_from_rule(rule, term_registry=term_registry) + build_transform_from_rule(rule) def test_build_transform_from_rule_invalid_rule_type(): @@ -291,7 +292,6 @@ def test_build_transform_from_rule_invalid_rule_type(): def test_map_value_rule_target_term( annotation: data.SoundEventAnnotation, - term_registry: TermRegistry, term2: data.Term, ): rule = MapValueRule( @@ -300,7 +300,7 @@ def test_map_value_rule_target_term( value_mapping={"value1": "value2"}, target_term_key="term2", ) - transform_fn = build_transform_from_rule(rule, term_registry=term_registry) + transform_fn = build_transform_from_rule(rule) transformed_annotation = transform_fn(annotation) assert transformed_annotation.tags[0].term == term2 assert transformed_annotation.tags[0].value == "value2" @@ -308,7 +308,6 @@ def test_map_value_rule_target_term( def test_map_value_rule_target_term_none( annotation: data.SoundEventAnnotation, - term_registry: TermRegistry, term1: data.Term, ): rule = MapValueRule( @@ -317,7 +316,7 @@ def test_map_value_rule_target_term_none( value_mapping={"value1": "value2"}, target_term_key=None, ) - transform_fn = build_transform_from_rule(rule, term_registry=term_registry) + transform_fn = build_transform_from_rule(rule) transformed_annotation = transform_fn(annotation) assert transformed_annotation.tags[0].term == term1 assert transformed_annotation.tags[0].value == "value2" @@ -325,7 +324,6 @@ def test_map_value_rule_target_term_none( def test_derive_tag_rule_target_term_none( annotation: data.SoundEventAnnotation, - term_registry: TermRegistry, derivation_registry: DerivationRegistry, term1: data.Term, ): @@ -342,7 +340,6 @@ def test_derive_tag_rule_target_term_none( ) transform_fn = build_transform_from_rule( rule, - term_registry=term_registry, derivation_registry=derivation_registry, ) transformed_annotation = transform_fn(annotation) diff --git a/tests/test_train/test_augmentations.py b/tests/test_train/test_augmentations.py deleted file mode 100644 index 344df17..0000000 --- a/tests/test_train/test_augmentations.py +++ /dev/null @@ -1,170 +0,0 @@ -from collections.abc import Callable - -import pytest -import torch -from soundevent import data - -from batdetect2.train.augmentations import ( - add_echo, - mix_audio, -) -from batdetect2.train.clips import select_subclip -from batdetect2.train.preprocess import generate_train_example -from batdetect2.typing import AudioLoader, ClipLabeller, PreprocessorProtocol - - -def test_mix_examples( - sample_preprocessor: PreprocessorProtocol, - sample_audio_loader: AudioLoader, - sample_labeller: ClipLabeller, - create_recording: Callable[..., data.Recording], -): - recording1 = create_recording() - recording2 = create_recording() - - clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7) - clip2 = data.Clip(recording=recording2, start_time=0.3, end_time=0.8) - - clip_annotation_1 = data.ClipAnnotation(clip=clip1) - clip_annotation_2 = data.ClipAnnotation(clip=clip2) - - example1 = generate_train_example( - clip_annotation_1, - audio_loader=sample_audio_loader, - preprocessor=sample_preprocessor, - labeller=sample_labeller, - ) - example2 = generate_train_example( - clip_annotation_2, - audio_loader=sample_audio_loader, - preprocessor=sample_preprocessor, - labeller=sample_labeller, - ) - - mixed = mix_audio( - example1, - example2, - weight=0.3, - preprocessor=sample_preprocessor, - ) - - assert mixed.spectrogram.shape == example1.spectrogram.shape - assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape - assert mixed.size_heatmap.shape == example1.size_heatmap.shape - assert mixed.class_heatmap.shape == example1.class_heatmap.shape - - -@pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7]) -@pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7]) -def test_mix_examples_of_different_durations( - sample_preprocessor: PreprocessorProtocol, - sample_audio_loader: AudioLoader, - sample_labeller: ClipLabeller, - create_recording: Callable[..., data.Recording], - duration1: float, - duration2: float, -): - recording1 = create_recording() - recording2 = create_recording() - - clip1 = data.Clip(recording=recording1, start_time=0, end_time=duration1) - clip2 = data.Clip(recording=recording2, start_time=0, end_time=duration2) - - clip_annotation_1 = data.ClipAnnotation(clip=clip1) - clip_annotation_2 = data.ClipAnnotation(clip=clip2) - - example1 = generate_train_example( - clip_annotation_1, - audio_loader=sample_audio_loader, - preprocessor=sample_preprocessor, - labeller=sample_labeller, - ) - example2 = generate_train_example( - clip_annotation_2, - audio_loader=sample_audio_loader, - preprocessor=sample_preprocessor, - labeller=sample_labeller, - ) - - mixed = mix_audio( - example1, - example2, - weight=0.3, - preprocessor=sample_preprocessor, - ) - - assert mixed.spectrogram.shape == example1.spectrogram.shape - assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape - assert mixed.size_heatmap.shape == example1.size_heatmap.shape - assert mixed.class_heatmap.shape == example1.class_heatmap.shape - - -def test_add_echo( - sample_preprocessor: PreprocessorProtocol, - sample_audio_loader: AudioLoader, - sample_labeller: ClipLabeller, - create_recording: Callable[..., data.Recording], -): - recording1 = create_recording() - clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7) - clip_annotation_1 = data.ClipAnnotation(clip=clip1) - - original = generate_train_example( - clip_annotation_1, - audio_loader=sample_audio_loader, - preprocessor=sample_preprocessor, - labeller=sample_labeller, - ) - with_echo = add_echo( - original, - preprocessor=sample_preprocessor, - delay=0.1, - weight=0.3, - ) - - assert with_echo.spectrogram.shape == original.spectrogram.shape - torch.testing.assert_close( - with_echo.size_heatmap, - original.size_heatmap, - atol=0, - rtol=0, - ) - torch.testing.assert_close( - with_echo.class_heatmap, - original.class_heatmap, - atol=0, - rtol=0, - ) - torch.testing.assert_close( - with_echo.detection_heatmap, - original.detection_heatmap, - atol=0, - rtol=0, - ) - - -def test_selected_random_subclip_has_the_correct_width( - sample_preprocessor: PreprocessorProtocol, - sample_audio_loader: AudioLoader, - sample_labeller: ClipLabeller, - create_recording: Callable[..., data.Recording], -): - recording1 = create_recording() - clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7) - clip_annotation_1 = data.ClipAnnotation(clip=clip1) - - original = generate_train_example( - clip_annotation_1, - audio_loader=sample_audio_loader, - preprocessor=sample_preprocessor, - labeller=sample_labeller, - ) - - subclip = select_subclip( - original, - input_samplerate=256_000, - output_samplerate=1000, - start=0, - duration=0.512, - ) - assert subclip.spectrogram.shape[1] == 512 diff --git a/tests/test_train/test_clips.py b/tests/test_train/test_clips.py deleted file mode 100644 index b6f7953..0000000 --- a/tests/test_train/test_clips.py +++ /dev/null @@ -1,27 +0,0 @@ -from soundevent import data - -from batdetect2.train import generate_train_example -from batdetect2.typing import ( - AudioLoader, - ClipLabeller, - ClipperProtocol, - PreprocessorProtocol, -) - - -def test_default_clip_size_is_correct( - sample_clipper: ClipperProtocol, - sample_labeller: ClipLabeller, - sample_audio_loader: AudioLoader, - clip_annotation: data.ClipAnnotation, - sample_preprocessor: PreprocessorProtocol, -): - example = generate_train_example( - clip_annotation=clip_annotation, - audio_loader=sample_audio_loader, - preprocessor=sample_preprocessor, - labeller=sample_labeller, - ) - - clip, _, _ = sample_clipper(example) - assert clip.spectrogram.shape == (1, 128, 256) diff --git a/tests/test_train/test_labels.py b/tests/test_train/test_labels.py index 213d7ba..15e15e9 100644 --- a/tests/test_train/test_labels.py +++ b/tests/test_train/test_labels.py @@ -56,7 +56,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions( detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps( clip_annotation, - torch.rand([100, 100]), + torch.rand([1, 100, 100]), min_freq=0, max_freq=100, targets=targets, @@ -67,4 +67,4 @@ def test_generated_heatmap_are_non_zero_at_correct_positions( assert size_heatmap[1, 10, 10] == 20 assert class_heatmap[pippip_index, 10, 10] == 1.0 assert class_heatmap[myomyo_index, 10, 10] == 0.0 - assert detection_heatmap[10, 10] == 1.0 + assert detection_heatmap[0, 10, 10] == 1.0 diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py index af4bcad..830e1e4 100644 --- a/tests/test_train/test_lightning.py +++ b/tests/test_train/test_lightning.py @@ -4,14 +4,16 @@ import lightning as L import torch from soundevent import data +from batdetect2.models import build_model from batdetect2.train import FullTrainingConfig, TrainingModule from batdetect2.train.train import build_training_module from batdetect2.typing.preprocess import AudioLoader def build_default_module(): + model = build_model() config = FullTrainingConfig() - return build_training_module(config) + return build_training_module(model, config=config) def test_can_initialize_default_module(): @@ -32,14 +34,14 @@ def test_can_save_checkpoint( recovered = TrainingModule.load_from_checkpoint(path) - wav = torch.tensor(sample_audio_loader.load_clip(clip)) + wav = torch.tensor(sample_audio_loader.load_clip(clip)).unsqueeze(0) spec1 = module.model.preprocessor(wav) spec2 = recovered.model.preprocessor(wav) torch.testing.assert_close(spec1, spec2, rtol=0, atol=0) - output1 = module(spec1.unsqueeze(0).unsqueeze(0)) - output2 = recovered(spec2.unsqueeze(0).unsqueeze(0)) + output1 = module(spec1.unsqueeze(0)) + output2 = recovered(spec2.unsqueeze(0)) torch.testing.assert_close(output1, output2, rtol=0, atol=0) diff --git a/tests/test_train/test_preprocessing.py b/tests/test_train/test_preprocessing.py deleted file mode 100644 index 0660705..0000000 --- a/tests/test_train/test_preprocessing.py +++ /dev/null @@ -1,230 +0,0 @@ -import pytest -from soundevent import data -from soundevent.terms import get_term - -from batdetect2.postprocess import build_postprocessor, load_postprocess_config -from batdetect2.preprocess import build_preprocessor, load_preprocessing_config -from batdetect2.targets import build_targets, load_target_config -from batdetect2.train.labels import build_clip_labeler, load_label_config -from batdetect2.train.preprocess import generate_train_example -from batdetect2.typing import ModelOutput -from batdetect2.typing.preprocess import AudioLoader - - -@pytest.fixture -def build_from_config( - create_temp_yaml, -): - def build(yaml_content): - config_path = create_temp_yaml(yaml_content) - - targets_config = load_target_config(config_path, field="targets") - preprocessing_config = load_preprocessing_config( - config_path, - field="preprocessing", - ) - labels_config = load_label_config(config_path, field="labels") - postprocessing_config = load_postprocess_config( - config_path, - field="postprocessing", - ) - - targets = build_targets(targets_config) - preprocessor = build_preprocessor(preprocessing_config) - labeller = build_clip_labeler( - targets=targets, - config=labels_config, - min_freq=preprocessor.min_freq, - max_freq=preprocessor.max_freq, - ) - postprocessor = build_postprocessor( - preprocessor=preprocessor, - config=postprocessing_config, - ) - - return targets, preprocessor, labeller, postprocessor - - return build - - -def test_encoding_decoding_roundtrip_recovers_object( - sample_audio_loader: AudioLoader, - build_from_config, - recording, -): - yaml_content = """ - labels: - targets: - roi: - name: anchor_bbox - anchor: bottom-left - classes: - classes: - - name: pippip - tags: - - key: species - value: Pipistrellus pipistrellus - generic_class: - - key: order - value: Chiroptera - preprocessing: - """ - _, preprocessor, labeller, postprocessor = build_from_config(yaml_content) - - geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000]) - se1 = data.SoundEventAnnotation( - sound_event=data.SoundEvent(recording=recording, geometry=geometry), - tags=[ - data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore - ], - ) - clip = data.Clip(start_time=0, end_time=0.5, recording=recording) - clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1]) - - encoded = generate_train_example( - clip_annotation, - sample_audio_loader, - preprocessor, - labeller, - ) - predictions = postprocessor.get_predictions( - ModelOutput( - detection_probs=encoded.detection_heatmap.unsqueeze(0).unsqueeze( - 0 - ), - size_preds=encoded.size_heatmap.unsqueeze(0), - class_probs=encoded.class_heatmap.unsqueeze(0), - features=encoded.spectrogram.unsqueeze(0).unsqueeze(0), - ), - [clip], - )[0] - - assert isinstance(predictions, data.ClipPrediction) - assert len(predictions.sound_events) == 1 - - recovered = predictions.sound_events[0] - assert recovered.sound_event.geometry is not None - assert isinstance(recovered.sound_event.geometry, data.BoundingBox) - start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = ( - recovered.sound_event.geometry.coordinates - ) - start_time_or, low_freq_or, end_time_or, high_freq_or = ( - geometry.coordinates - ) - - assert start_time_rec == pytest.approx(start_time_or, abs=0.01) - assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000) - assert end_time_rec == pytest.approx(end_time_or, abs=0.01) - assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000) - - assert len(recovered.tags) == 2 - - predicted_species_tag = next( - iter(t for t in recovered.tags if t.tag.term == get_term("species")), - None, - ) - assert predicted_species_tag is not None - assert predicted_species_tag.score == 1 - assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus" - - predicted_order_tag = next( - iter(t for t in recovered.tags if t.tag.term == get_term("order")), - None, - ) - assert predicted_order_tag is not None - assert predicted_order_tag.score == 1 - assert predicted_order_tag.tag.value == "Chiroptera" - - -def test_encoding_decoding_roundtrip_recovers_object_with_roi_override( - sample_audio_loader: AudioLoader, - build_from_config, - recording, -): - yaml_content = """ - labels: - targets: - roi: - name: anchor_bbox - anchor: bottom-left - classes: - classes: - - name: pippip - tags: - - key: species - value: Pipistrellus pipistrellus - - name: myomyo - tags: - - key: species - value: Myotis myotis - roi: - name: anchor_bbox - anchor: top-left - generic_class: - - key: order - value: Chiroptera - preprocessing: - """ - _, preprocessor, labeller, postprocessor = build_from_config(yaml_content) - - geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000]) - se1 = data.SoundEventAnnotation( - sound_event=data.SoundEvent(recording=recording, geometry=geometry), - tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore - ) - clip = data.Clip(start_time=0, end_time=0.5, recording=recording) - clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1]) - - encoded = generate_train_example( - clip_annotation, - sample_audio_loader, - preprocessor, - labeller, - ) - predictions = postprocessor.get_predictions( - ModelOutput( - detection_probs=encoded.detection_heatmap.unsqueeze(0).unsqueeze( - 0 - ), - size_preds=encoded.size_heatmap.unsqueeze(0), - class_probs=encoded.class_heatmap.unsqueeze(0), - features=encoded.spectrogram.unsqueeze(0).unsqueeze(0), - ), - [clip], - )[0] - - assert isinstance(predictions, data.ClipPrediction) - assert len(predictions.sound_events) == 1 - - recovered = predictions.sound_events[0] - assert recovered.sound_event.geometry is not None - assert isinstance(recovered.sound_event.geometry, data.BoundingBox) - start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = ( - recovered.sound_event.geometry.coordinates - ) - start_time_or, low_freq_or, end_time_or, high_freq_or = ( - geometry.coordinates - ) - - assert start_time_rec == pytest.approx(start_time_or, abs=0.01) - assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000) - assert end_time_rec == pytest.approx(end_time_or, abs=0.01) - assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000) - - assert len(recovered.tags) == 2 - - predicted_species_tag = next( - iter(t for t in recovered.tags if t.tag.term == get_term("species")), - None, - ) - assert predicted_species_tag is not None - assert predicted_species_tag.score == 1 - assert predicted_species_tag.tag.value == "Myotis myotis" - - predicted_order_tag = next( - iter(t for t in recovered.tags if t.tag.term == get_term("order")), - None, - ) - assert predicted_order_tag is not None - assert predicted_order_tag.score == 1 - assert predicted_order_tag.tag.value == "Chiroptera"