diff --git a/src/batdetect2/data/annotations/batdetect2.py b/src/batdetect2/data/annotations/batdetect2.py index fe5697a..1d17d81 100644 --- a/src/batdetect2/data/annotations/batdetect2.py +++ b/src/batdetect2/data/annotations/batdetect2.py @@ -301,7 +301,8 @@ def load_batdetect2_merged_annotated_dataset( for ann in content: try: ann = FileAnnotation.model_validate(ann) - except ValueError: + except ValueError as err: + logger.warning(f"Invalid annotation file: {err}") continue if ( @@ -309,14 +310,17 @@ def load_batdetect2_merged_annotated_dataset( and dataset.filter.only_annotated and not ann.annotated ): + logger.debug(f"Skipping incomplete annotation {ann.id}") continue if dataset.filter and dataset.filter.exclude_issues and ann.issues: + logger.debug(f"Skipping annotation with issues {ann.id}") continue try: clip = file_annotation_to_clip(ann, audio_dir=audio_dir) - except FileNotFoundError: + except FileNotFoundError as err: + logger.warning(f"Error loading annotations: {err}") continue annotations.append(file_annotation_to_clip_annotation(ann, clip)) diff --git a/src/batdetect2/data/summary.py b/src/batdetect2/data/summary.py index f7828b5..f4994d5 100644 --- a/src/batdetect2/data/summary.py +++ b/src/batdetect2/data/summary.py @@ -100,8 +100,11 @@ def extract_sound_events_df( class_name = targets.encode_class(sound_event) - if class_name is None and exclude_generic: - continue + if class_name is None: + if exclude_generic: + continue + else: + class_name = targets.detection_class_name start_time, low_freq, end_time, high_freq = compute_bounds( sound_event.sound_event.geometry @@ -153,7 +156,7 @@ def compute_class_summary( sound_events = extract_sound_events_df( dataset, targets, - exclude_generic=True, + exclude_generic=False, exclude_non_target=True, ) recordings = extract_recordings_df(dataset) diff --git a/src/batdetect2/postprocess/decoding.py b/src/batdetect2/postprocess/decoding.py index 9180d34..6802aa8 100644 --- a/src/batdetect2/postprocess/decoding.py +++ b/src/batdetect2/postprocess/decoding.py @@ -103,7 +103,7 @@ def convert_raw_prediction_to_sound_event_prediction( tags = [ *get_generic_tags( raw_prediction.detection_score, - generic_class_tags=targets.generic_class_tags, + generic_class_tags=targets.detection_class_tags, ), *get_class_tags( raw_prediction.class_scores, diff --git a/src/batdetect2/targets/__init__.py b/src/batdetect2/targets/__init__.py index b6eb193..1da4163 100644 --- a/src/batdetect2/targets/__init__.py +++ b/src/batdetect2/targets/__init__.py @@ -140,16 +140,18 @@ class Targets(TargetProtocol): """ class_names: List[str] - generic_class_tags: List[data.Tag] + detection_class_tags: List[data.Tag] dimension_names: List[str] + detection_class_name: str def __init__( self, + detection_class_name: str, encode_fn: SoundEventEncoder, decode_fn: SoundEventDecoder, roi_mapper: ROITargetMapper, class_names: list[str], - generic_class_tags: List[data.Tag], + detection_class_tags: List[data.Tag], filter_fn: Optional[SoundEventCondition] = None, roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None, ): @@ -175,8 +177,9 @@ class Targets(TargetProtocol): transform_fn : SoundEventTransformation, optional Configured function to transform annotation tags. Defaults to None. """ + self.detection_class_name = detection_class_name self.class_names = class_names - self.generic_class_tags = generic_class_tags + self.detection_class_tags = detection_class_tags self.dimension_names = roi_mapper.dimension_names self._roi_mapper = roi_mapper @@ -381,7 +384,8 @@ def build_targets(config: Optional[TargetConfig] = None) -> Targets: decode_fn=decode_fn, class_names=class_names, roi_mapper=roi_mapper, - generic_class_tags=generic_class_tags, + detection_class_name=config.detection_target.name, + detection_class_tags=generic_class_tags, roi_mapper_overrides=roi_overrides, ) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 98b8808..7d0f21d 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -158,25 +158,13 @@ def build_trainer_callbacks( if run_name is not None: checkpoint_dir = checkpoint_dir / run_name - filename = "best-{epoch:02d}-{val_loss:.0f}" - - if run_name is not None: - filename = f"run_{run_name}_{filename}" - - if experiment_name is not None: - filename = f"experiment_{experiment_name}_{filename}" - - model_checkpoint = ModelCheckpoint( - dirpath=str(checkpoint_dir), - save_top_k=1, - filename=filename, - monitor="total_loss/val", - ) - - model_checkpoint.CHECKPOINT_EQUALS_CHAR = "_" # type: ignore - return [ - model_checkpoint, + ModelCheckpoint( + dirpath=str(checkpoint_dir), + save_top_k=1, + filename="best-{epoch:02d}-{val_loss:.0f}", + monitor="total_loss/val", + ), ValidationMetrics( metrics=[ DetectionAveragePrecision(), @@ -226,7 +214,8 @@ def build_trainer( config=conf.evaluation, preprocessor=build_preprocessor(conf.preprocess), checkpoint_dir=checkpoint_dir, - experiment_name=train_logger.name, + experiment_name=experiment_name, + run_name=run_name, ), ) diff --git a/src/batdetect2/typing/targets.py b/src/batdetect2/typing/targets.py index f3148a6..b86573c 100644 --- a/src/batdetect2/typing/targets.py +++ b/src/batdetect2/typing/targets.py @@ -94,8 +94,10 @@ class TargetProtocol(Protocol): class_names: List[str] """Ordered list of unique names for the specific target classes.""" - generic_class_tags: List[data.Tag] - """List of tags representing the generic (unclassified) category.""" + detection_class_tags: List[data.Tag] + """List of tags representing the detection category (unclassified).""" + + detection_class_name: str dimension_names: List[str] """Names of the size dimensions (e.g., ['width', 'height'])."""