Compare commits

...

2 Commits

Author       SHA1        Message       Date
mbsantiago   8628133fd7  Compute mAP   2025-09-14 10:08:51 +01:00
mbsantiago   d80377981e  Fix plotting  2025-09-14 09:38:45 +01:00
7 changed files with 38 additions and 63 deletions

View File

@@ -47,7 +47,7 @@ def evaluate(
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
-        config=config.train,
+        config=config.train.val_loader,
         num_workers=num_workers,
     )
@@ -66,8 +66,9 @@ def evaluate(
     predictions = get_raw_predictions(
         outputs,
-        clips=[
-            clip_annotation.clip for clip_annotation in clip_annotations
+        start_times=[
+            clip_annotation.clip.start_time
+            for clip_annotation in clip_annotations
         ],
         targets=targets,
         postprocessor=model.postprocessor,

View File

@@ -1,5 +1,6 @@
 from typing import Dict, List
 
+import numpy as np
 import pandas as pd
 from sklearn import metrics
 from sklearn.preprocessing import label_binarize
@ -19,9 +20,8 @@ class DetectionAveragePrecision(MetricsProtocol):
class ClassificationMeanAveragePrecision(MetricsProtocol): class ClassificationMeanAveragePrecision(MetricsProtocol):
def __init__(self, class_names: List[str], per_class: bool = True): def __init__(self, class_names: List[str]):
self.class_names = class_names self.class_names = class_names
self.per_class = per_class
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
y_true = label_binarize( y_true = label_binarize(
@@ -40,14 +40,8 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
                 for match in matches
             ]
         ).fillna(0)
 
-        mAP = metrics.average_precision_score(y_true, y_pred[self.class_names])
-
-        ret = {
-            "classification_mAP": float(mAP),
-        }
-
-        if not self.per_class:
-            return ret
+        ret = {}
 
         for class_index, class_name in enumerate(self.class_names):
             y_true_class = y_true[:, class_index]
@@ -58,6 +52,10 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
             )
             ret[f"classification_AP/{class_name}"] = float(class_ap)
 
+        ret["classification_mAP"] = np.mean(
+            [value for value in ret.values() if value != 0]
+        )
+
         return ret
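
Note on the new aggregation: classification_mAP is now the mean of the per-class APs already collected in ret, skipping entries that are exactly 0, rather than a single average_precision_score call over all classes at once. A minimal sketch of the behaviour, with hypothetical class names and AP values:

    import numpy as np

    # Per-class APs as they would sit in `ret` just before the mean is taken.
    ret = {
        "classification_AP/Myotis": 0.82,
        "classification_AP/Pipistrellus": 0.64,
        "classification_AP/Nyctalus": 0.0,  # e.g. class absent from this batch
    }

    # Mean over non-zero APs: (0.82 + 0.64) / 2 = 0.73.
    # Caveat: if every AP is 0, this is np.mean([]) -> nan with a RuntimeWarning.
    ret["classification_mAP"] = np.mean(
        [value for value in ret.values() if value != 0]
    )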

View File

@@ -100,7 +100,7 @@ def plot_class_examples(
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
            continue
 
     for index, match in enumerate(false_positives[:n_examples]):
@@ -112,7 +112,7 @@ def plot_class_examples(
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
            continue
 
     for index, match in enumerate(false_negatives[:n_examples]):
@@ -124,7 +124,7 @@ def plot_class_examples(
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
            continue
 
     for index, match in enumerate(cross_triggers[:n_examples]):

View File

@@ -191,7 +191,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
     def get_detections(
         self,
         output: ModelOutput,
-        clips: Optional[List[data.Clip]] = None,
+        start_times: Optional[List[float]] = None,
     ) -> List[DetectionsTensor]:
         width = output.detection_probs.shape[-1]
         duration = width / self.samplerate
@@ -203,7 +203,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
             threshold=self.detection_threshold,
         )
 
-        if clips is None:
+        if start_times is None:
             return detections
 
         width = output.detection_probs.shape[-1]
@@ -211,12 +211,12 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
         return [
             map_detection_to_clip(
                 detection,
-                start_time=clip.start_time,
-                end_time=clip.start_time + duration,
+                start_time=start_time,
+                end_time=start_time + duration,
                 min_freq=self.min_freq,
                 max_freq=self.max_freq,
             )
-            for detection, clip in zip(detections, clips)
+            for detection, start_time in zip(detections, start_times)
         ]
@@ -224,28 +224,10 @@ def get_raw_predictions(
     output: ModelOutput,
     targets: TargetProtocol,
     postprocessor: PostprocessorProtocol,
-    clips: Optional[List[data.Clip]] = None,
+    start_times: Optional[List[float]] = None,
 ) -> List[List[RawPrediction]]:
-    """Extract intermediate RawPrediction objects for a batch.
-
-    Processes raw model output through remapping, NMS, detection, data
-    extraction, and geometry recovery via the configured
-    `targets.recover_roi`.
-
-    Parameters
-    ----------
-    output : ModelOutput
-        Raw output from the neural network model for a batch.
-    clips : List[data.Clip]
-        List of `soundevent.data.Clip` objects corresponding to the batch.
-
-    Returns
-    -------
-    List[List[RawPrediction]]
-        List of lists (one inner list per input clip). Each inner list
-        contains `RawPrediction` objects for detections in that clip.
-    """
-    detections = postprocessor.get_detections(output, clips)
+    """Extract intermediate RawPrediction objects for a batch."""
+    detections = postprocessor.get_detections(output, start_times)
     return [
         to_raw_predictions(detection.numpy(), targets=targets)
         for detection in detections
@@ -254,16 +236,16 @@ def get_raw_predictions(
 def get_sound_event_predictions(
     output: ModelOutput,
-    clips: List[data.Clip],
     targets: TargetProtocol,
     postprocessor: PostprocessorProtocol,
+    clips: List[data.Clip],
     classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
 ) -> List[List[BatDetect2Prediction]]:
     raw_predictions = get_raw_predictions(
         output,
         targets=targets,
         postprocessor=postprocessor,
-        clips=clips,
+        start_times=[clip.start_time for clip in clips],
     )
     return [
         [
@@ -312,7 +294,7 @@ def get_predictions(
         output,
         targets=targets,
         postprocessor=postprocessor,
-        clips=clips,
+        start_times=[clip.start_time for clip in clips],
     )
     return [
         convert_raw_predictions_to_clip_prediction(
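
With this rename, postprocessing no longer needs full soundevent.data.Clip objects; a caller that still holds clips just passes their start times, as the two call sites above now do. A minimal sketch of the new calling convention (the outputs, clips, targets, and model names are stand-ins for whatever the caller has in scope):

    # Derive plain float start times from the clips (hypothetical call site).
    start_times = [clip.start_time for clip in clips]

    raw_predictions = get_raw_predictions(
        outputs,
        targets=targets,
        postprocessor=model.postprocessor,
        start_times=start_times,  # previously: clips=clips
    )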

View File

@@ -126,14 +126,15 @@ class ValidationMetrics(Callback):
         dataset = self.get_dataset(trainer)
         clip_annotations = [
-            dataset.get_clip_annotation(int(example_idx))
+            dataset.clip_annotations[int(example_idx)]
             for example_idx in batch.idx
         ]
 
         predictions = get_raw_predictions(
             outputs,
-            clips=[
-                clip_annotation.clip for clip_annotation in clip_annotations
+            start_times=[
+                clip_annotation.clip.start_time
+                for clip_annotation in clip_annotations
             ],
             targets=targets,
             postprocessor=postprocessor,

View File

@@ -101,8 +101,15 @@ class ValidationDataset(Dataset):
         return len(self.clip_annotations)
 
     def __getitem__(self, idx) -> TrainExample:
-        wav, clip_annotation = self.load_audio(idx)
+        clip_annotation = self.clip_annotations[idx]
+
+        if self.clipper is not None:
+            clip_annotation = self.clipper(clip_annotation)
+
         clip = clip_annotation.clip
+        wav = torch.tensor(
+            self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
+        ).unsqueeze(0)
         spectrogram = self.preprocessor(wav)
@@ -117,17 +124,3 @@ class ValidationDataset(Dataset):
             start_time=torch.tensor(clip.start_time),
             end_time=torch.tensor(clip.end_time),
         )
-
-    def get_clip_annotation(self, idx: int) -> data.ClipAnnotation:
-        clip_annotation = self.clip_annotations[idx]
-
-        if self.clipper is not None:
-            clip_annotation = self.clipper(clip_annotation)
-
-        return clip_annotation
-
-    def load_audio(self, idx: int) -> Tuple[torch.Tensor, data.ClipAnnotation]:
-        clip_annotation = self.get_clip_annotation(idx)
-        clip = clip_annotation.clip
-        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
-        return torch.tensor(wav).unsqueeze(0), clip_annotation
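
Note the asymmetry this refactor introduces: __getitem__ still applies the optional clipper before loading audio, but with get_clip_annotation removed, external readers such as the ValidationMetrics callback above now get the raw, unclipped annotation from dataset.clip_annotations[idx]. A minimal usage sketch (the constructor arguments are assumptions inferred from the attributes used in this diff):

    # Hypothetical usage of the refactored ValidationDataset.
    dataset = ValidationDataset(
        clip_annotations=clip_annotations,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        clipper=clipper,
        audio_dir=audio_dir,
    )

    annotation = dataset.clip_annotations[0]  # raw annotation, clipper not applied
    example = dataset[0]  # clipped, audio loaded, preprocessed TrainExample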

View File

@@ -97,5 +97,5 @@ class PostprocessorProtocol(Protocol):
     def get_detections(
         self,
         output: ModelOutput,
-        clips: Optional[List[data.Clip]] = None,
+        start_times: Optional[List[float]] = None,
     ) -> List[DetectionsTensor]: ...