feat: support target config roundtrips

This commit is contained in:
mbsantiago 2026-05-04 22:31:32 +01:00
parent eec126a502
commit 20a7c058fc
6 changed files with 84 additions and 16 deletions

View File

@ -16,6 +16,7 @@ from batdetect2.data.conditions.common import (
IdInListConfig, IdInListConfig,
JsonList, JsonList,
ListFormatConfig, ListFormatConfig,
TagInfo,
TxtList, TxtList,
) )
from batdetect2.data.conditions.recordings import ( from batdetect2.data.conditions.recordings import (
@ -63,16 +64,17 @@ __all__ = [
"NotConfig", "NotConfig",
"Operator", "Operator",
"PathInListConfig", "PathInListConfig",
"RecordingAllOfConfig",
"RecordingAnyOfConfig",
"RecordingCondition", "RecordingCondition",
"RecordingConditionConfig", "RecordingConditionConfig",
"RecordingConditionImportConfig", "RecordingConditionImportConfig",
"RecordingAllOfConfig",
"RecordingAnyOfConfig",
"RecordingNotConfig", "RecordingNotConfig",
"RecordingSatisfiesConfig", "RecordingSatisfiesConfig",
"SoundEventCondition", "SoundEventCondition",
"SoundEventConditionConfig", "SoundEventConditionConfig",
"SoundEventConditionImportConfig", "SoundEventConditionImportConfig",
"TagInfo",
"TxtList", "TxtList",
"build_clip_annotation_condition", "build_clip_annotation_condition",
"build_recording_condition", "build_recording_condition",

View File

@ -2,10 +2,23 @@ import csv
import json import json
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from pathlib import Path from pathlib import Path
from typing import Annotated, Generic, Literal, ParamSpec, Protocol, TypeVar from typing import (
Annotated,
Any,
Generic,
Literal,
ParamSpec,
Protocol,
TypeVar,
)
from uuid import UUID from uuid import UUID
from pydantic import BaseModel, Field, model_validator from pydantic import (
BaseModel,
Field,
PlainSerializer,
model_validator,
)
from soundevent import data from soundevent import data
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
@ -138,19 +151,26 @@ class IdInList(Generic[UUIDObject]):
return obj.uuid in self.ids return obj.uuid in self.ids
def dump_tag(tag: data.Tag) -> dict[str, Any]:
return {"key": tag.term.name, "value": tag.value}
TagInfo = Annotated[data.Tag, PlainSerializer(dump_tag)]
class HasTagConfig(BaseConfig): class HasTagConfig(BaseConfig):
name: Literal["has_tag"] = "has_tag" name: Literal["has_tag"] = "has_tag"
tag: data.Tag tag: TagInfo
class HasAllTagsConfig(BaseConfig): class HasAllTagsConfig(BaseConfig):
name: Literal["has_all_tags"] = "has_all_tags" name: Literal["has_all_tags"] = "has_all_tags"
tags: list[data.Tag] tags: list[TagInfo]
class HasAnyTagConfig(BaseConfig): class HasAnyTagConfig(BaseConfig):
name: Literal["has_any_tag"] = "has_any_tag" name: Literal["has_any_tag"] = "has_any_tag"
tags: list[data.Tag] tags: list[TagInfo]
class JsonList(BaseConfig): class JsonList(BaseConfig):

View File

@ -12,6 +12,7 @@ from batdetect2.data.conditions import (
NotConfig, NotConfig,
SoundEventCondition, SoundEventCondition,
SoundEventConditionConfig, SoundEventConditionConfig,
TagInfo,
build_sound_event_condition, build_sound_event_condition,
) )
from batdetect2.targets.terms import call_type, generic_class from batdetect2.targets.terms import call_type, generic_class
@ -32,11 +33,12 @@ class TargetClassConfig(BaseConfig):
condition_input: SoundEventConditionConfig | None = Field( condition_input: SoundEventConditionConfig | None = Field(
alias="match_if", alias="match_if",
default=None, default=None,
exclude=True,
) )
tags: List[data.Tag] | None = Field(default=None, exclude=True) tags: List[data.Tag] | None = Field(default=None, exclude=True)
assign_tags: List[data.Tag] = Field(default_factory=list) assign_tags: List[TagInfo] = Field(default_factory=list)
_match_if: SoundEventConditionConfig = PrivateAttr() _match_if: SoundEventConditionConfig = PrivateAttr()

View File

@ -50,21 +50,31 @@ class Targets(TargetProtocol):
self.config = config self.config = config
self._filter_fn = build_sound_event_condition( self._filter_fn = build_sound_event_condition(
config.detection_target.match_if self.config.detection_target.match_if
) )
self._encode_fn = build_sound_event_encoder( self._encode_fn = build_sound_event_encoder(
config.classification_targets self.config.classification_targets
) )
self._decode_fn = build_sound_event_decoder( self._decode_fn = build_sound_event_decoder(
config.classification_targets self.config.classification_targets
) )
self.class_names = get_class_names_from_config( self.class_names = get_class_names_from_config(
config.classification_targets self.config.classification_targets
) )
self.detection_class_name = config.detection_target.name self.detection_class_name = self.config.detection_target.name
self.detection_class_tags = config.detection_target.assign_tags self.detection_class_tags = self.config.detection_target.assign_tags
@classmethod
def from_config(cls, config: dict) -> "Targets":
"""Build a Targets object from a serialized config dictionary."""
validated_config = TargetConfig.model_validate(config)
return cls(config=validated_config)
def get_config(self) -> dict:
"""Return the serialized target config used to build this object."""
return self.config.model_dump(mode="json")
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
"""Apply the configured filter to a sound event annotation. """Apply the configured filter to a sound event annotation.
@ -131,7 +141,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
) )
def build_targets(config: TargetConfig | None = None) -> Targets: def build_targets(config: TargetConfig | dict | None = None) -> Targets:
"""Build a Targets object from a loaded TargetConfig. """Build a Targets object from a loaded TargetConfig.
Parameters Parameters
@ -153,6 +163,10 @@ def build_targets(config: TargetConfig | None = None) -> Targets:
If dynamic import of a derivation function fails (when configured). If dynamic import of a derivation function fails (when configured).
""" """
config = config or DEFAULT_TARGET_CONFIG config = config or DEFAULT_TARGET_CONFIG
if not isinstance(config, TargetConfig):
config = TargetConfig.model_validate(config)
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(
"Building targets with config: \n{}", "Building targets with config: \n{}",
lambda: config.to_yaml_string(), lambda: config.to_yaml_string(),

View File

@ -28,6 +28,11 @@ class TargetProtocol(Protocol):
detection_class_tags: list[data.Tag] detection_class_tags: list[data.Tag]
detection_class_name: str detection_class_name: str
@classmethod
def from_config(cls, config: dict) -> "TargetProtocol": ...
def get_config(self) -> dict: ...
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ... def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ...
def encode_class( def encode_class(

View File

@ -1,9 +1,34 @@
import json
from collections.abc import Callable from collections.abc import Callable
from pathlib import Path from pathlib import Path
from soundevent import data, terms from soundevent import data, terms
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets from batdetect2.targets import (
TargetConfig,
Targets,
build_roi_mapping,
build_targets,
)
def test_targets_get_config_returns_a_json_serializable_dict() -> None:
targets = build_targets(TargetConfig())
config_dict = targets.get_config()
assert isinstance(config_dict, dict)
assert json.dumps(config_dict)
def test_targets_from_config_rebuilds_equivalent_targets() -> None:
original = build_targets(TargetConfig())
rebuilt = Targets.from_config(original.get_config())
assert rebuilt.class_names == original.class_names
assert rebuilt.detection_class_name == original.detection_class_name
assert rebuilt.detection_class_tags == original.detection_class_tags
assert rebuilt.get_config() == original.get_config()
def test_can_override_default_roi_mapper_per_class( def test_can_override_default_roi_mapper_per_class(