mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
feat: support target config roundtrips
This commit is contained in:
parent
eec126a502
commit
20a7c058fc
@ -16,6 +16,7 @@ from batdetect2.data.conditions.common import (
|
|||||||
IdInListConfig,
|
IdInListConfig,
|
||||||
JsonList,
|
JsonList,
|
||||||
ListFormatConfig,
|
ListFormatConfig,
|
||||||
|
TagInfo,
|
||||||
TxtList,
|
TxtList,
|
||||||
)
|
)
|
||||||
from batdetect2.data.conditions.recordings import (
|
from batdetect2.data.conditions.recordings import (
|
||||||
@ -63,16 +64,17 @@ __all__ = [
|
|||||||
"NotConfig",
|
"NotConfig",
|
||||||
"Operator",
|
"Operator",
|
||||||
"PathInListConfig",
|
"PathInListConfig",
|
||||||
|
"RecordingAllOfConfig",
|
||||||
|
"RecordingAnyOfConfig",
|
||||||
"RecordingCondition",
|
"RecordingCondition",
|
||||||
"RecordingConditionConfig",
|
"RecordingConditionConfig",
|
||||||
"RecordingConditionImportConfig",
|
"RecordingConditionImportConfig",
|
||||||
"RecordingAllOfConfig",
|
|
||||||
"RecordingAnyOfConfig",
|
|
||||||
"RecordingNotConfig",
|
"RecordingNotConfig",
|
||||||
"RecordingSatisfiesConfig",
|
"RecordingSatisfiesConfig",
|
||||||
"SoundEventCondition",
|
"SoundEventCondition",
|
||||||
"SoundEventConditionConfig",
|
"SoundEventConditionConfig",
|
||||||
"SoundEventConditionImportConfig",
|
"SoundEventConditionImportConfig",
|
||||||
|
"TagInfo",
|
||||||
"TxtList",
|
"TxtList",
|
||||||
"build_clip_annotation_condition",
|
"build_clip_annotation_condition",
|
||||||
"build_recording_condition",
|
"build_recording_condition",
|
||||||
|
|||||||
@ -2,10 +2,23 @@ import csv
|
|||||||
import json
|
import json
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Generic, Literal, ParamSpec, Protocol, TypeVar
|
from typing import (
|
||||||
|
Annotated,
|
||||||
|
Any,
|
||||||
|
Generic,
|
||||||
|
Literal,
|
||||||
|
ParamSpec,
|
||||||
|
Protocol,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
Field,
|
||||||
|
PlainSerializer,
|
||||||
|
model_validator,
|
||||||
|
)
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
@ -138,19 +151,26 @@ class IdInList(Generic[UUIDObject]):
|
|||||||
return obj.uuid in self.ids
|
return obj.uuid in self.ids
|
||||||
|
|
||||||
|
|
||||||
|
def dump_tag(tag: data.Tag) -> dict[str, Any]:
|
||||||
|
return {"key": tag.term.name, "value": tag.value}
|
||||||
|
|
||||||
|
|
||||||
|
TagInfo = Annotated[data.Tag, PlainSerializer(dump_tag)]
|
||||||
|
|
||||||
|
|
||||||
class HasTagConfig(BaseConfig):
|
class HasTagConfig(BaseConfig):
|
||||||
name: Literal["has_tag"] = "has_tag"
|
name: Literal["has_tag"] = "has_tag"
|
||||||
tag: data.Tag
|
tag: TagInfo
|
||||||
|
|
||||||
|
|
||||||
class HasAllTagsConfig(BaseConfig):
|
class HasAllTagsConfig(BaseConfig):
|
||||||
name: Literal["has_all_tags"] = "has_all_tags"
|
name: Literal["has_all_tags"] = "has_all_tags"
|
||||||
tags: list[data.Tag]
|
tags: list[TagInfo]
|
||||||
|
|
||||||
|
|
||||||
class HasAnyTagConfig(BaseConfig):
|
class HasAnyTagConfig(BaseConfig):
|
||||||
name: Literal["has_any_tag"] = "has_any_tag"
|
name: Literal["has_any_tag"] = "has_any_tag"
|
||||||
tags: list[data.Tag]
|
tags: list[TagInfo]
|
||||||
|
|
||||||
|
|
||||||
class JsonList(BaseConfig):
|
class JsonList(BaseConfig):
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from batdetect2.data.conditions import (
|
|||||||
NotConfig,
|
NotConfig,
|
||||||
SoundEventCondition,
|
SoundEventCondition,
|
||||||
SoundEventConditionConfig,
|
SoundEventConditionConfig,
|
||||||
|
TagInfo,
|
||||||
build_sound_event_condition,
|
build_sound_event_condition,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.terms import call_type, generic_class
|
from batdetect2.targets.terms import call_type, generic_class
|
||||||
@ -32,11 +33,12 @@ class TargetClassConfig(BaseConfig):
|
|||||||
condition_input: SoundEventConditionConfig | None = Field(
|
condition_input: SoundEventConditionConfig | None = Field(
|
||||||
alias="match_if",
|
alias="match_if",
|
||||||
default=None,
|
default=None,
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
tags: List[data.Tag] | None = Field(default=None, exclude=True)
|
tags: List[data.Tag] | None = Field(default=None, exclude=True)
|
||||||
|
|
||||||
assign_tags: List[data.Tag] = Field(default_factory=list)
|
assign_tags: List[TagInfo] = Field(default_factory=list)
|
||||||
|
|
||||||
_match_if: SoundEventConditionConfig = PrivateAttr()
|
_match_if: SoundEventConditionConfig = PrivateAttr()
|
||||||
|
|
||||||
|
|||||||
@ -50,21 +50,31 @@ class Targets(TargetProtocol):
|
|||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self._filter_fn = build_sound_event_condition(
|
self._filter_fn = build_sound_event_condition(
|
||||||
config.detection_target.match_if
|
self.config.detection_target.match_if
|
||||||
)
|
)
|
||||||
self._encode_fn = build_sound_event_encoder(
|
self._encode_fn = build_sound_event_encoder(
|
||||||
config.classification_targets
|
self.config.classification_targets
|
||||||
)
|
)
|
||||||
self._decode_fn = build_sound_event_decoder(
|
self._decode_fn = build_sound_event_decoder(
|
||||||
config.classification_targets
|
self.config.classification_targets
|
||||||
)
|
)
|
||||||
|
|
||||||
self.class_names = get_class_names_from_config(
|
self.class_names = get_class_names_from_config(
|
||||||
config.classification_targets
|
self.config.classification_targets
|
||||||
)
|
)
|
||||||
|
|
||||||
self.detection_class_name = config.detection_target.name
|
self.detection_class_name = self.config.detection_target.name
|
||||||
self.detection_class_tags = config.detection_target.assign_tags
|
self.detection_class_tags = self.config.detection_target.assign_tags
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: dict) -> "Targets":
|
||||||
|
"""Build a Targets object from a serialized config dictionary."""
|
||||||
|
validated_config = TargetConfig.model_validate(config)
|
||||||
|
return cls(config=validated_config)
|
||||||
|
|
||||||
|
def get_config(self) -> dict:
|
||||||
|
"""Return the serialized target config used to build this object."""
|
||||||
|
return self.config.model_dump(mode="json")
|
||||||
|
|
||||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
||||||
"""Apply the configured filter to a sound event annotation.
|
"""Apply the configured filter to a sound event annotation.
|
||||||
@ -131,7 +141,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_targets(config: TargetConfig | None = None) -> Targets:
|
def build_targets(config: TargetConfig | dict | None = None) -> Targets:
|
||||||
"""Build a Targets object from a loaded TargetConfig.
|
"""Build a Targets object from a loaded TargetConfig.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -153,6 +163,10 @@ def build_targets(config: TargetConfig | None = None) -> Targets:
|
|||||||
If dynamic import of a derivation function fails (when configured).
|
If dynamic import of a derivation function fails (when configured).
|
||||||
"""
|
"""
|
||||||
config = config or DEFAULT_TARGET_CONFIG
|
config = config or DEFAULT_TARGET_CONFIG
|
||||||
|
|
||||||
|
if not isinstance(config, TargetConfig):
|
||||||
|
config = TargetConfig.model_validate(config)
|
||||||
|
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Building targets with config: \n{}",
|
"Building targets with config: \n{}",
|
||||||
lambda: config.to_yaml_string(),
|
lambda: config.to_yaml_string(),
|
||||||
|
|||||||
@ -28,6 +28,11 @@ class TargetProtocol(Protocol):
|
|||||||
detection_class_tags: list[data.Tag]
|
detection_class_tags: list[data.Tag]
|
||||||
detection_class_name: str
|
detection_class_name: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: dict) -> "TargetProtocol": ...
|
||||||
|
|
||||||
|
def get_config(self) -> dict: ...
|
||||||
|
|
||||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ...
|
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ...
|
||||||
|
|
||||||
def encode_class(
|
def encode_class(
|
||||||
|
|||||||
@ -1,9 +1,34 @@
|
|||||||
|
import json
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from soundevent import data, terms
|
from soundevent import data, terms
|
||||||
|
|
||||||
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
|
from batdetect2.targets import (
|
||||||
|
TargetConfig,
|
||||||
|
Targets,
|
||||||
|
build_roi_mapping,
|
||||||
|
build_targets,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_targets_get_config_returns_a_json_serializable_dict() -> None:
|
||||||
|
targets = build_targets(TargetConfig())
|
||||||
|
|
||||||
|
config_dict = targets.get_config()
|
||||||
|
assert isinstance(config_dict, dict)
|
||||||
|
assert json.dumps(config_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def test_targets_from_config_rebuilds_equivalent_targets() -> None:
|
||||||
|
original = build_targets(TargetConfig())
|
||||||
|
|
||||||
|
rebuilt = Targets.from_config(original.get_config())
|
||||||
|
|
||||||
|
assert rebuilt.class_names == original.class_names
|
||||||
|
assert rebuilt.detection_class_name == original.detection_class_name
|
||||||
|
assert rebuilt.detection_class_tags == original.detection_class_tags
|
||||||
|
assert rebuilt.get_config() == original.get_config()
|
||||||
|
|
||||||
|
|
||||||
def test_can_override_default_roi_mapper_per_class(
|
def test_can_override_default_roi_mapper_per_class(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user