Add detection_class_name to targets protocol

This commit is contained in:
mbsantiago 2025-09-09 20:30:20 +01:00
parent 41b18c3f0a
commit 615c811bb4
6 changed files with 33 additions and 31 deletions

View File

@ -301,7 +301,8 @@ def load_batdetect2_merged_annotated_dataset(
for ann in content:
try:
ann = FileAnnotation.model_validate(ann)
except ValueError:
except ValueError as err:
logger.warning(f"Invalid annotation file: {err}")
continue
if (
@ -309,14 +310,17 @@ def load_batdetect2_merged_annotated_dataset(
and dataset.filter.only_annotated
and not ann.annotated
):
logger.debug(f"Skipping incomplete annotation {ann.id}")
continue
if dataset.filter and dataset.filter.exclude_issues and ann.issues:
logger.debug(f"Skipping annotation with issues {ann.id}")
continue
try:
clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
except FileNotFoundError:
except FileNotFoundError as err:
logger.warning(f"Error loading annotations: {err}")
continue
annotations.append(file_annotation_to_clip_annotation(ann, clip))

View File

@ -100,8 +100,11 @@ def extract_sound_events_df(
class_name = targets.encode_class(sound_event)
if class_name is None and exclude_generic:
continue
if class_name is None:
if exclude_generic:
continue
else:
class_name = targets.detection_class_name
start_time, low_freq, end_time, high_freq = compute_bounds(
sound_event.sound_event.geometry
@ -153,7 +156,7 @@ def compute_class_summary(
sound_events = extract_sound_events_df(
dataset,
targets,
exclude_generic=True,
exclude_generic=False,
exclude_non_target=True,
)
recordings = extract_recordings_df(dataset)

View File

@ -103,7 +103,7 @@ def convert_raw_prediction_to_sound_event_prediction(
tags = [
*get_generic_tags(
raw_prediction.detection_score,
generic_class_tags=targets.generic_class_tags,
generic_class_tags=targets.detection_class_tags,
),
*get_class_tags(
raw_prediction.class_scores,

View File

@ -140,16 +140,18 @@ class Targets(TargetProtocol):
"""
class_names: List[str]
generic_class_tags: List[data.Tag]
detection_class_tags: List[data.Tag]
dimension_names: List[str]
detection_class_name: str
def __init__(
self,
detection_class_name: str,
encode_fn: SoundEventEncoder,
decode_fn: SoundEventDecoder,
roi_mapper: ROITargetMapper,
class_names: list[str],
generic_class_tags: List[data.Tag],
detection_class_tags: List[data.Tag],
filter_fn: Optional[SoundEventCondition] = None,
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
):
@ -175,8 +177,9 @@ class Targets(TargetProtocol):
transform_fn : SoundEventTransformation, optional
Configured function to transform annotation tags. Defaults to None.
"""
self.detection_class_name = detection_class_name
self.class_names = class_names
self.generic_class_tags = generic_class_tags
self.detection_class_tags = detection_class_tags
self.dimension_names = roi_mapper.dimension_names
self._roi_mapper = roi_mapper
@ -381,7 +384,8 @@ def build_targets(config: Optional[TargetConfig] = None) -> Targets:
decode_fn=decode_fn,
class_names=class_names,
roi_mapper=roi_mapper,
generic_class_tags=generic_class_tags,
detection_class_name=config.detection_target.name,
detection_class_tags=generic_class_tags,
roi_mapper_overrides=roi_overrides,
)

View File

@ -158,25 +158,13 @@ def build_trainer_callbacks(
if run_name is not None:
checkpoint_dir = checkpoint_dir / run_name
filename = "best-{epoch:02d}-{val_loss:.0f}"
if run_name is not None:
filename = f"run_{run_name}_{filename}"
if experiment_name is not None:
filename = f"experiment_{experiment_name}_{filename}"
model_checkpoint = ModelCheckpoint(
dirpath=str(checkpoint_dir),
save_top_k=1,
filename=filename,
monitor="total_loss/val",
)
model_checkpoint.CHECKPOINT_EQUALS_CHAR = "_" # type: ignore
return [
model_checkpoint,
ModelCheckpoint(
dirpath=str(checkpoint_dir),
save_top_k=1,
filename="best-{epoch:02d}-{val_loss:.0f}",
monitor="total_loss/val",
),
ValidationMetrics(
metrics=[
DetectionAveragePrecision(),
@ -226,7 +214,8 @@ def build_trainer(
config=conf.evaluation,
preprocessor=build_preprocessor(conf.preprocess),
checkpoint_dir=checkpoint_dir,
experiment_name=train_logger.name,
experiment_name=experiment_name,
run_name=run_name,
),
)

View File

@ -94,8 +94,10 @@ class TargetProtocol(Protocol):
class_names: List[str]
"""Ordered list of unique names for the specific target classes."""
generic_class_tags: List[data.Tag]
"""List of tags representing the generic (unclassified) category."""
detection_class_tags: List[data.Tag]
"""List of tags representing the detection category (unclassified)."""
detection_class_name: str
dimension_names: List[str]
"""Names of the size dimensions (e.g., ['width', 'height'])."""