diff --git a/src/batdetect2/data/conditions/__init__.py b/src/batdetect2/data/conditions/__init__.py index 6721393..1d451ba 100644 --- a/src/batdetect2/data/conditions/__init__.py +++ b/src/batdetect2/data/conditions/__init__.py @@ -16,6 +16,7 @@ from batdetect2.data.conditions.common import ( IdInListConfig, JsonList, ListFormatConfig, + TagInfo, TxtList, ) from batdetect2.data.conditions.recordings import ( @@ -63,16 +64,17 @@ __all__ = [ "NotConfig", "Operator", "PathInListConfig", + "RecordingAllOfConfig", + "RecordingAnyOfConfig", "RecordingCondition", "RecordingConditionConfig", "RecordingConditionImportConfig", - "RecordingAllOfConfig", - "RecordingAnyOfConfig", "RecordingNotConfig", "RecordingSatisfiesConfig", "SoundEventCondition", "SoundEventConditionConfig", "SoundEventConditionImportConfig", + "TagInfo", "TxtList", "build_clip_annotation_condition", "build_recording_condition", diff --git a/src/batdetect2/data/conditions/common.py b/src/batdetect2/data/conditions/common.py index f340ba2..b2fdf21 100644 --- a/src/batdetect2/data/conditions/common.py +++ b/src/batdetect2/data/conditions/common.py @@ -2,10 +2,23 @@ import csv import json from collections.abc import Callable, Sequence from pathlib import Path -from typing import Annotated, Generic, Literal, ParamSpec, Protocol, TypeVar +from typing import ( + Annotated, + Any, + Generic, + Literal, + ParamSpec, + Protocol, + TypeVar, +) from uuid import UUID -from pydantic import BaseModel, Field, model_validator +from pydantic import ( + BaseModel, + Field, + PlainSerializer, + model_validator, +) from soundevent import data from batdetect2.core.configs import BaseConfig @@ -138,19 +151,26 @@ class IdInList(Generic[UUIDObject]): return obj.uuid in self.ids +def dump_tag(tag: data.Tag) -> dict[str, Any]: + return {"key": tag.term.name, "value": tag.value} + + +TagInfo = Annotated[data.Tag, PlainSerializer(dump_tag)] + + class HasTagConfig(BaseConfig): name: Literal["has_tag"] = "has_tag" - tag: data.Tag + tag: TagInfo class HasAllTagsConfig(BaseConfig): name: Literal["has_all_tags"] = "has_all_tags" - tags: list[data.Tag] + tags: list[TagInfo] class HasAnyTagConfig(BaseConfig): name: Literal["has_any_tag"] = "has_any_tag" - tags: list[data.Tag] + tags: list[TagInfo] class JsonList(BaseConfig): diff --git a/src/batdetect2/targets/classes.py b/src/batdetect2/targets/classes.py index 7639cf0..ee6e0a3 100644 --- a/src/batdetect2/targets/classes.py +++ b/src/batdetect2/targets/classes.py @@ -12,6 +12,7 @@ from batdetect2.data.conditions import ( NotConfig, SoundEventCondition, SoundEventConditionConfig, + TagInfo, build_sound_event_condition, ) from batdetect2.targets.terms import call_type, generic_class @@ -32,11 +33,12 @@ class TargetClassConfig(BaseConfig): condition_input: SoundEventConditionConfig | None = Field( alias="match_if", default=None, + exclude=True, ) tags: List[data.Tag] | None = Field(default=None, exclude=True) - assign_tags: List[data.Tag] = Field(default_factory=list) + assign_tags: List[TagInfo] = Field(default_factory=list) _match_if: SoundEventConditionConfig = PrivateAttr() diff --git a/src/batdetect2/targets/targets.py b/src/batdetect2/targets/targets.py index 72e0262..9c8bb1e 100644 --- a/src/batdetect2/targets/targets.py +++ b/src/batdetect2/targets/targets.py @@ -50,21 +50,31 @@ class Targets(TargetProtocol): self.config = config self._filter_fn = build_sound_event_condition( - config.detection_target.match_if + self.config.detection_target.match_if ) self._encode_fn = build_sound_event_encoder( - config.classification_targets + self.config.classification_targets ) self._decode_fn = build_sound_event_decoder( - config.classification_targets + self.config.classification_targets ) self.class_names = get_class_names_from_config( - config.classification_targets + self.config.classification_targets ) - self.detection_class_name = config.detection_target.name - self.detection_class_tags = config.detection_target.assign_tags + self.detection_class_name = self.config.detection_target.name + self.detection_class_tags = self.config.detection_target.assign_tags + + @classmethod + def from_config(cls, config: dict) -> "Targets": + """Build a Targets object from a serialized config dictionary.""" + validated_config = TargetConfig.model_validate(config) + return cls(config=validated_config) + + def get_config(self) -> dict: + """Return the serialized target config used to build this object.""" + return self.config.model_dump(mode="json") def filter(self, sound_event: data.SoundEventAnnotation) -> bool: """Apply the configured filter to a sound event annotation. @@ -131,7 +141,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig( ) -def build_targets(config: TargetConfig | None = None) -> Targets: +def build_targets(config: TargetConfig | dict | None = None) -> Targets: """Build a Targets object from a loaded TargetConfig. Parameters @@ -153,6 +163,10 @@ def build_targets(config: TargetConfig | None = None) -> Targets: If dynamic import of a derivation function fails (when configured). """ config = config or DEFAULT_TARGET_CONFIG + + if not isinstance(config, TargetConfig): + config = TargetConfig.model_validate(config) + logger.opt(lazy=True).debug( "Building targets with config: \n{}", lambda: config.to_yaml_string(), diff --git a/src/batdetect2/targets/types.py b/src/batdetect2/targets/types.py index 4f435ba..0558b3c 100644 --- a/src/batdetect2/targets/types.py +++ b/src/batdetect2/targets/types.py @@ -28,6 +28,11 @@ class TargetProtocol(Protocol): detection_class_tags: list[data.Tag] detection_class_name: str + @classmethod + def from_config(cls, config: dict) -> "TargetProtocol": ... + + def get_config(self) -> dict: ... + def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ... def encode_class( diff --git a/tests/test_targets/test_targets.py b/tests/test_targets/test_targets.py index 72c9c7a..37d1f3a 100644 --- a/tests/test_targets/test_targets.py +++ b/tests/test_targets/test_targets.py @@ -1,9 +1,34 @@ +import json from collections.abc import Callable from pathlib import Path from soundevent import data, terms -from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets +from batdetect2.targets import ( + TargetConfig, + Targets, + build_roi_mapping, + build_targets, +) + + +def test_targets_get_config_returns_a_json_serializable_dict() -> None: + targets = build_targets(TargetConfig()) + + config_dict = targets.get_config() + assert isinstance(config_dict, dict) + assert json.dumps(config_dict) + + +def test_targets_from_config_rebuilds_equivalent_targets() -> None: + original = build_targets(TargetConfig()) + + rebuilt = Targets.from_config(original.get_config()) + + assert rebuilt.class_names == original.class_names + assert rebuilt.detection_class_name == original.detection_class_name + assert rebuilt.detection_class_tags == original.detection_class_tags + assert rebuilt.get_config() == original.get_config() def test_can_override_default_roi_mapper_per_class(