mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
feat: support target config roundtrips
This commit is contained in:
parent
eec126a502
commit
20a7c058fc
@ -16,6 +16,7 @@ from batdetect2.data.conditions.common import (
|
||||
IdInListConfig,
|
||||
JsonList,
|
||||
ListFormatConfig,
|
||||
TagInfo,
|
||||
TxtList,
|
||||
)
|
||||
from batdetect2.data.conditions.recordings import (
|
||||
@ -63,16 +64,17 @@ __all__ = [
|
||||
"NotConfig",
|
||||
"Operator",
|
||||
"PathInListConfig",
|
||||
"RecordingAllOfConfig",
|
||||
"RecordingAnyOfConfig",
|
||||
"RecordingCondition",
|
||||
"RecordingConditionConfig",
|
||||
"RecordingConditionImportConfig",
|
||||
"RecordingAllOfConfig",
|
||||
"RecordingAnyOfConfig",
|
||||
"RecordingNotConfig",
|
||||
"RecordingSatisfiesConfig",
|
||||
"SoundEventCondition",
|
||||
"SoundEventConditionConfig",
|
||||
"SoundEventConditionImportConfig",
|
||||
"TagInfo",
|
||||
"TxtList",
|
||||
"build_clip_annotation_condition",
|
||||
"build_recording_condition",
|
||||
|
||||
@ -2,10 +2,23 @@ import csv
|
||||
import json
|
||||
from collections.abc import Callable, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Generic, Literal, ParamSpec, Protocol, TypeVar
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
)
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
PlainSerializer,
|
||||
model_validator,
|
||||
)
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
@ -138,19 +151,26 @@ class IdInList(Generic[UUIDObject]):
|
||||
return obj.uuid in self.ids
|
||||
|
||||
|
||||
def dump_tag(tag: data.Tag) -> dict[str, Any]:
|
||||
return {"key": tag.term.name, "value": tag.value}
|
||||
|
||||
|
||||
TagInfo = Annotated[data.Tag, PlainSerializer(dump_tag)]
|
||||
|
||||
|
||||
class HasTagConfig(BaseConfig):
|
||||
name: Literal["has_tag"] = "has_tag"
|
||||
tag: data.Tag
|
||||
tag: TagInfo
|
||||
|
||||
|
||||
class HasAllTagsConfig(BaseConfig):
|
||||
name: Literal["has_all_tags"] = "has_all_tags"
|
||||
tags: list[data.Tag]
|
||||
tags: list[TagInfo]
|
||||
|
||||
|
||||
class HasAnyTagConfig(BaseConfig):
|
||||
name: Literal["has_any_tag"] = "has_any_tag"
|
||||
tags: list[data.Tag]
|
||||
tags: list[TagInfo]
|
||||
|
||||
|
||||
class JsonList(BaseConfig):
|
||||
|
||||
@ -12,6 +12,7 @@ from batdetect2.data.conditions import (
|
||||
NotConfig,
|
||||
SoundEventCondition,
|
||||
SoundEventConditionConfig,
|
||||
TagInfo,
|
||||
build_sound_event_condition,
|
||||
)
|
||||
from batdetect2.targets.terms import call_type, generic_class
|
||||
@ -32,11 +33,12 @@ class TargetClassConfig(BaseConfig):
|
||||
condition_input: SoundEventConditionConfig | None = Field(
|
||||
alias="match_if",
|
||||
default=None,
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
tags: List[data.Tag] | None = Field(default=None, exclude=True)
|
||||
|
||||
assign_tags: List[data.Tag] = Field(default_factory=list)
|
||||
assign_tags: List[TagInfo] = Field(default_factory=list)
|
||||
|
||||
_match_if: SoundEventConditionConfig = PrivateAttr()
|
||||
|
||||
|
||||
@ -50,21 +50,31 @@ class Targets(TargetProtocol):
|
||||
self.config = config
|
||||
|
||||
self._filter_fn = build_sound_event_condition(
|
||||
config.detection_target.match_if
|
||||
self.config.detection_target.match_if
|
||||
)
|
||||
self._encode_fn = build_sound_event_encoder(
|
||||
config.classification_targets
|
||||
self.config.classification_targets
|
||||
)
|
||||
self._decode_fn = build_sound_event_decoder(
|
||||
config.classification_targets
|
||||
self.config.classification_targets
|
||||
)
|
||||
|
||||
self.class_names = get_class_names_from_config(
|
||||
config.classification_targets
|
||||
self.config.classification_targets
|
||||
)
|
||||
|
||||
self.detection_class_name = config.detection_target.name
|
||||
self.detection_class_tags = config.detection_target.assign_tags
|
||||
self.detection_class_name = self.config.detection_target.name
|
||||
self.detection_class_tags = self.config.detection_target.assign_tags
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "Targets":
|
||||
"""Build a Targets object from a serialized config dictionary."""
|
||||
validated_config = TargetConfig.model_validate(config)
|
||||
return cls(config=validated_config)
|
||||
|
||||
def get_config(self) -> dict:
|
||||
"""Return the serialized target config used to build this object."""
|
||||
return self.config.model_dump(mode="json")
|
||||
|
||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
||||
"""Apply the configured filter to a sound event annotation.
|
||||
@ -131,7 +141,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
||||
)
|
||||
|
||||
|
||||
def build_targets(config: TargetConfig | None = None) -> Targets:
|
||||
def build_targets(config: TargetConfig | dict | None = None) -> Targets:
|
||||
"""Build a Targets object from a loaded TargetConfig.
|
||||
|
||||
Parameters
|
||||
@ -153,6 +163,10 @@ def build_targets(config: TargetConfig | None = None) -> Targets:
|
||||
If dynamic import of a derivation function fails (when configured).
|
||||
"""
|
||||
config = config or DEFAULT_TARGET_CONFIG
|
||||
|
||||
if not isinstance(config, TargetConfig):
|
||||
config = TargetConfig.model_validate(config)
|
||||
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building targets with config: \n{}",
|
||||
lambda: config.to_yaml_string(),
|
||||
|
||||
@ -28,6 +28,11 @@ class TargetProtocol(Protocol):
|
||||
detection_class_tags: list[data.Tag]
|
||||
detection_class_name: str
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "TargetProtocol": ...
|
||||
|
||||
def get_config(self) -> dict: ...
|
||||
|
||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ...
|
||||
|
||||
def encode_class(
|
||||
|
||||
@ -1,9 +1,34 @@
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
from soundevent import data, terms
|
||||
|
||||
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
|
||||
from batdetect2.targets import (
|
||||
TargetConfig,
|
||||
Targets,
|
||||
build_roi_mapping,
|
||||
build_targets,
|
||||
)
|
||||
|
||||
|
||||
def test_targets_get_config_returns_a_json_serializable_dict() -> None:
|
||||
targets = build_targets(TargetConfig())
|
||||
|
||||
config_dict = targets.get_config()
|
||||
assert isinstance(config_dict, dict)
|
||||
assert json.dumps(config_dict)
|
||||
|
||||
|
||||
def test_targets_from_config_rebuilds_equivalent_targets() -> None:
|
||||
original = build_targets(TargetConfig())
|
||||
|
||||
rebuilt = Targets.from_config(original.get_config())
|
||||
|
||||
assert rebuilt.class_names == original.class_names
|
||||
assert rebuilt.detection_class_name == original.detection_class_name
|
||||
assert rebuilt.detection_class_tags == original.detection_class_tags
|
||||
assert rebuilt.get_config() == original.get_config()
|
||||
|
||||
|
||||
def test_can_override_default_roi_mapper_per_class(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user