Update tests

This commit is contained in:
mbsantiago 2025-09-08 18:00:17 +01:00
parent b7ae526071
commit 95a884ea16
2 changed files with 15 additions and 9 deletions

View File

@ -33,7 +33,8 @@ class TargetClassConfig(BaseConfig):
alias="match_if", alias="match_if",
default=None, default=None,
) )
tags: Optional[List[data.Tag]] = None
tags: Optional[List[data.Tag]] = Field(default=None, exclude=True)
assign_tags: List[data.Tag] = Field(default_factory=list) assign_tags: List[data.Tag] = Field(default_factory=list)
@ -42,20 +43,24 @@ class TargetClassConfig(BaseConfig):
_match_if: SoundEventConditionConfig = PrivateAttr() _match_if: SoundEventConditionConfig = PrivateAttr()
@model_validator(mode="after") @model_validator(mode="after")
def _process_shorthands(self) -> "TargetClassConfig": def _process_tags(self) -> "TargetClassConfig":
if self.tags and self.condition_input: if self.tags and self.condition_input:
raise ValueError("Use either 'tags' or 'match_if', not both.") raise ValueError("Use either 'tags' or 'match_if', not both.")
if self.condition_input: if self.condition_input is not None:
final_condition = self.condition_input self._match_if = self.condition_input
elif self.tags: return self
final_condition = HasAllTagsConfig(tags=self.tags)
else: if self.tags is None:
raise ValueError( raise ValueError(
f"Class '{self.name}' must have a 'tags' or 'match_if' rule." f"Class '{self.name}' must have a 'tags' or 'match_if' rule."
) )
self._match_if = final_condition self._match_if = HasAllTagsConfig(tags=self.tags)
if not self.assign_tags:
self.assign_tags = self.tags
return self return self
@computed_field @computed_field

View File

@ -26,6 +26,7 @@ clip = data.Clip(
def test_generated_heatmap_are_non_zero_at_correct_positions( def test_generated_heatmap_are_non_zero_at_correct_positions(
sample_target_config: TargetConfig, sample_target_config: TargetConfig,
pippip_tag: data.Tag, pippip_tag: data.Tag,
bat_tag: data.Tag,
): ):
config = sample_target_config.model_copy( config = sample_target_config.model_copy(
update=dict( update=dict(
@ -48,7 +49,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
coordinates=[10, 10, 20, 30], coordinates=[10, 10, 20, 30],
), ),
), ),
tags=[data.Tag(key=pippip_tag.key, value=pippip_tag.value)], # type: ignore tags=[pippip_tag, bat_tag],
) )
], ],
) )