From 95a884ea16229127da83074b2987624ade067b55 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 8 Sep 2025 18:00:17 +0100 Subject: [PATCH] Update tests --- src/batdetect2/targets/classes.py | 21 +++++++++++++-------- tests/test_train/test_labels.py | 3 ++- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/batdetect2/targets/classes.py b/src/batdetect2/targets/classes.py index 96c78f6..4c73e1a 100644 --- a/src/batdetect2/targets/classes.py +++ b/src/batdetect2/targets/classes.py @@ -33,7 +33,8 @@ class TargetClassConfig(BaseConfig): alias="match_if", default=None, ) - tags: Optional[List[data.Tag]] = None + + tags: Optional[List[data.Tag]] = Field(default=None, exclude=True) assign_tags: List[data.Tag] = Field(default_factory=list) @@ -42,20 +43,24 @@ class TargetClassConfig(BaseConfig): _match_if: SoundEventConditionConfig = PrivateAttr() @model_validator(mode="after") - def _process_shorthands(self) -> "TargetClassConfig": + def _process_tags(self) -> "TargetClassConfig": if self.tags and self.condition_input: raise ValueError("Use either 'tags' or 'match_if', not both.") - if self.condition_input: - final_condition = self.condition_input - elif self.tags: - final_condition = HasAllTagsConfig(tags=self.tags) - else: + if self.condition_input is not None: + self._match_if = self.condition_input + return self + + if self.tags is None: raise ValueError( f"Class '{self.name}' must have a 'tags' or 'match_if' rule." ) - self._match_if = final_condition + self._match_if = HasAllTagsConfig(tags=self.tags) + + if not self.assign_tags: + self.assign_tags = self.tags + return self @computed_field diff --git a/tests/test_train/test_labels.py b/tests/test_train/test_labels.py index 9c27941..19c39ab 100644 --- a/tests/test_train/test_labels.py +++ b/tests/test_train/test_labels.py @@ -26,6 +26,7 @@ clip = data.Clip( def test_generated_heatmap_are_non_zero_at_correct_positions( sample_target_config: TargetConfig, pippip_tag: data.Tag, + bat_tag: data.Tag, ): config = sample_target_config.model_copy( update=dict( @@ -48,7 +49,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions( coordinates=[10, 10, 20, 30], ), ), - tags=[data.Tag(key=pippip_tag.key, value=pippip_tag.value)], # type: ignore + tags=[pippip_tag, bat_tag], ) ], )