feat: support target config roundtrips

This commit is contained in:
mbsantiago 2026-05-04 22:31:32 +01:00
parent eec126a502
commit 20a7c058fc
6 changed files with 84 additions and 16 deletions

View File

@ -16,6 +16,7 @@ from batdetect2.data.conditions.common import (
IdInListConfig,
JsonList,
ListFormatConfig,
TagInfo,
TxtList,
)
from batdetect2.data.conditions.recordings import (
@ -63,16 +64,17 @@ __all__ = [
"NotConfig",
"Operator",
"PathInListConfig",
"RecordingAllOfConfig",
"RecordingAnyOfConfig",
"RecordingCondition",
"RecordingConditionConfig",
"RecordingConditionImportConfig",
"RecordingAllOfConfig",
"RecordingAnyOfConfig",
"RecordingNotConfig",
"RecordingSatisfiesConfig",
"SoundEventCondition",
"SoundEventConditionConfig",
"SoundEventConditionImportConfig",
"TagInfo",
"TxtList",
"build_clip_annotation_condition",
"build_recording_condition",

View File

@ -2,10 +2,23 @@ import csv
import json
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Annotated, Generic, Literal, ParamSpec, Protocol, TypeVar
from typing import (
Annotated,
Any,
Generic,
Literal,
ParamSpec,
Protocol,
TypeVar,
)
from uuid import UUID
from pydantic import BaseModel, Field, model_validator
from pydantic import (
BaseModel,
Field,
PlainSerializer,
model_validator,
)
from soundevent import data
from batdetect2.core.configs import BaseConfig
@ -138,19 +151,26 @@ class IdInList(Generic[UUIDObject]):
return obj.uuid in self.ids
def dump_tag(tag: data.Tag) -> dict[str, Any]:
return {"key": tag.term.name, "value": tag.value}
TagInfo = Annotated[data.Tag, PlainSerializer(dump_tag)]
class HasTagConfig(BaseConfig):
name: Literal["has_tag"] = "has_tag"
tag: data.Tag
tag: TagInfo
class HasAllTagsConfig(BaseConfig):
name: Literal["has_all_tags"] = "has_all_tags"
tags: list[data.Tag]
tags: list[TagInfo]
class HasAnyTagConfig(BaseConfig):
name: Literal["has_any_tag"] = "has_any_tag"
tags: list[data.Tag]
tags: list[TagInfo]
class JsonList(BaseConfig):

View File

@ -12,6 +12,7 @@ from batdetect2.data.conditions import (
NotConfig,
SoundEventCondition,
SoundEventConditionConfig,
TagInfo,
build_sound_event_condition,
)
from batdetect2.targets.terms import call_type, generic_class
@ -32,11 +33,12 @@ class TargetClassConfig(BaseConfig):
condition_input: SoundEventConditionConfig | None = Field(
alias="match_if",
default=None,
exclude=True,
)
tags: List[data.Tag] | None = Field(default=None, exclude=True)
assign_tags: List[data.Tag] = Field(default_factory=list)
assign_tags: List[TagInfo] = Field(default_factory=list)
_match_if: SoundEventConditionConfig = PrivateAttr()

View File

@ -50,21 +50,31 @@ class Targets(TargetProtocol):
self.config = config
self._filter_fn = build_sound_event_condition(
config.detection_target.match_if
self.config.detection_target.match_if
)
self._encode_fn = build_sound_event_encoder(
config.classification_targets
self.config.classification_targets
)
self._decode_fn = build_sound_event_decoder(
config.classification_targets
self.config.classification_targets
)
self.class_names = get_class_names_from_config(
config.classification_targets
self.config.classification_targets
)
self.detection_class_name = config.detection_target.name
self.detection_class_tags = config.detection_target.assign_tags
self.detection_class_name = self.config.detection_target.name
self.detection_class_tags = self.config.detection_target.assign_tags
@classmethod
def from_config(cls, config: dict) -> "Targets":
"""Build a Targets object from a serialized config dictionary."""
validated_config = TargetConfig.model_validate(config)
return cls(config=validated_config)
def get_config(self) -> dict:
"""Return the serialized target config used to build this object."""
return self.config.model_dump(mode="json")
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
"""Apply the configured filter to a sound event annotation.
@ -131,7 +141,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
)
def build_targets(config: TargetConfig | None = None) -> Targets:
def build_targets(config: TargetConfig | dict | None = None) -> Targets:
"""Build a Targets object from a loaded TargetConfig.
Parameters
@ -153,6 +163,10 @@ def build_targets(config: TargetConfig | None = None) -> Targets:
If dynamic import of a derivation function fails (when configured).
"""
config = config or DEFAULT_TARGET_CONFIG
if not isinstance(config, TargetConfig):
config = TargetConfig.model_validate(config)
logger.opt(lazy=True).debug(
"Building targets with config: \n{}",
lambda: config.to_yaml_string(),

View File

@ -28,6 +28,11 @@ class TargetProtocol(Protocol):
detection_class_tags: list[data.Tag]
detection_class_name: str
@classmethod
def from_config(cls, config: dict) -> "TargetProtocol": ...
def get_config(self) -> dict: ...
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ...
def encode_class(

View File

@ -1,9 +1,34 @@
import json
from collections.abc import Callable
from pathlib import Path
from soundevent import data, terms
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
from batdetect2.targets import (
TargetConfig,
Targets,
build_roi_mapping,
build_targets,
)
def test_targets_get_config_returns_a_json_serializable_dict() -> None:
targets = build_targets(TargetConfig())
config_dict = targets.get_config()
assert isinstance(config_dict, dict)
assert json.dumps(config_dict)
def test_targets_from_config_rebuilds_equivalent_targets() -> None:
original = build_targets(TargetConfig())
rebuilt = Targets.from_config(original.get_config())
assert rebuilt.class_names == original.class_names
assert rebuilt.detection_class_name == original.detection_class_name
assert rebuilt.detection_class_tags == original.detection_class_tags
assert rebuilt.get_config() == original.get_config()
def test_can_override_default_roi_mapper_per_class(