Compare commits

...

2 Commits

Author      SHA1        Message       Date
mbsantiago  8628133fd7  Compute mAP   2025-09-14 10:08:51 +01:00
mbsantiago  d80377981e  Fix plotting  2025-09-14 09:38:45 +01:00
7 changed files with 38 additions and 63 deletions

View File

@@ -47,7 +47,7 @@ def evaluate(
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
-        config=config.train,
+        config=config.train.val_loader,
         num_workers=num_workers,
     )
@@ -66,8 +66,9 @@ def evaluate(

     predictions = get_raw_predictions(
         outputs,
-        clips=[
-            clip_annotation.clip for clip_annotation in clip_annotations
+        start_times=[
+            clip_annotation.clip.start_time
+            for clip_annotation in clip_annotations
         ],
         targets=targets,
         postprocessor=model.postprocessor,

View File

@ -1,5 +1,6 @@
from typing import Dict, List
import numpy as np
import pandas as pd
from sklearn import metrics
from sklearn.preprocessing import label_binarize
@ -19,9 +20,8 @@ class DetectionAveragePrecision(MetricsProtocol):
class ClassificationMeanAveragePrecision(MetricsProtocol):
def __init__(self, class_names: List[str], per_class: bool = True):
def __init__(self, class_names: List[str]):
self.class_names = class_names
self.per_class = per_class
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
y_true = label_binarize(
@ -40,14 +40,8 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
for match in matches
]
).fillna(0)
mAP = metrics.average_precision_score(y_true, y_pred[self.class_names])
ret = {
"classification_mAP": float(mAP),
}
if not self.per_class:
return ret
ret = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
@ -58,6 +52,10 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
)
ret[f"classification_AP/{class_name}"] = float(class_ap)
ret["classification_mAP"] = np.mean(
[value for value in ret.values() if value != 0]
)
return ret
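
Taken together, the hunks above drop the pooled multilabel `average_precision_score` call (and the `per_class` flag) and instead always compute one AP per class, reporting the mean as `classification_mAP`. A minimal self-contained sketch of the new flow, with invented class names and scores standing in for the real `MatchEvaluation`-derived data:

    # Sketch of the per-class AP / mAP computation; class names and
    # scores are invented for illustration.
    import numpy as np
    from sklearn import metrics
    from sklearn.preprocessing import label_binarize

    class_names = ["myotis", "nyctalus", "pipistrellus"]
    y_true = label_binarize(
        ["myotis", "pipistrellus", "myotis", "nyctalus"],
        classes=class_names,
    )
    y_pred = np.array([
        [0.9, 0.05, 0.05],
        [0.1, 0.20, 0.70],
        [0.6, 0.30, 0.10],
        [0.2, 0.60, 0.20],
    ])

    ret = {}
    for class_index, class_name in enumerate(class_names):
        class_ap = metrics.average_precision_score(
            y_true[:, class_index], y_pred[:, class_index]
        )
        ret[f"classification_AP/{class_name}"] = float(class_ap)

    # mAP: mean of per-class APs, skipping zero values (presumably
    # classes without support in the batch).
    ret["classification_mAP"] = np.mean(
        [value for value in ret.values() if value != 0]
    )

Note that the new `classification_mAP` averages only nonzero per-class APs, so its value is not directly comparable to the old pooled multilabel score.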

View File

@@ -100,7 +100,7 @@ def plot_class_examples(
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
             continue

     for index, match in enumerate(false_positives[:n_examples]):
@@ -112,7 +112,7 @@ def plot_class_examples(
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
             continue

     for index, match in enumerate(false_negatives[:n_examples]):
@@ -124,7 +124,7 @@ def plot_class_examples(
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
             continue

     for index, match in enumerate(cross_triggers[:n_examples]):

View File

@@ -191,7 +191,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
     def get_detections(
         self,
         output: ModelOutput,
-        clips: Optional[List[data.Clip]] = None,
+        start_times: Optional[List[float]] = None,
     ) -> List[DetectionsTensor]:
         width = output.detection_probs.shape[-1]
         duration = width / self.samplerate
@@ -203,7 +203,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
             threshold=self.detection_threshold,
         )

-        if clips is None:
+        if start_times is None:
             return detections

         width = output.detection_probs.shape[-1]
@@ -211,12 +211,12 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
         return [
             map_detection_to_clip(
                 detection,
-                start_time=clip.start_time,
-                end_time=clip.start_time + duration,
+                start_time=start_time,
+                end_time=start_time + duration,
                 min_freq=self.min_freq,
                 max_freq=self.max_freq,
             )
-            for detection, clip in zip(detections, clips)
+            for detection, start_time in zip(detections, start_times)
         ]
@@ -224,28 +224,10 @@ def get_raw_predictions(
     output: ModelOutput,
     targets: TargetProtocol,
     postprocessor: PostprocessorProtocol,
-    clips: Optional[List[data.Clip]] = None,
+    start_times: Optional[List[float]] = None,
 ) -> List[List[RawPrediction]]:
-    """Extract intermediate RawPrediction objects for a batch.
-
-    Processes raw model output through remapping, NMS, detection, data
-    extraction, and geometry recovery via the configured
-    `targets.recover_roi`.
-
-    Parameters
-    ----------
-    output : ModelOutput
-        Raw output from the neural network model for a batch.
-    clips : List[data.Clip]
-        List of `soundevent.data.Clip` objects corresponding to the batch.
-
-    Returns
-    -------
-    List[List[RawPrediction]]
-        List of lists (one inner list per input clip). Each inner list
-        contains `RawPrediction` objects for detections in that clip.
-    """
-    detections = postprocessor.get_detections(output, clips)
+    """Extract intermediate RawPrediction objects for a batch."""
+    detections = postprocessor.get_detections(output, start_times)
     return [
         to_raw_predictions(detection.numpy(), targets=targets)
         for detection in detections
@@ -254,16 +236,16 @@ def get_raw_predictions(

 def get_sound_event_predictions(
     output: ModelOutput,
-    clips: List[data.Clip],
     targets: TargetProtocol,
     postprocessor: PostprocessorProtocol,
+    clips: List[data.Clip],
     classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
 ) -> List[List[BatDetect2Prediction]]:
     raw_predictions = get_raw_predictions(
         output,
         targets=targets,
         postprocessor=postprocessor,
-        clips=clips,
+        start_times=[clip.start_time for clip in clips],
     )
     return [
         [
@@ -312,7 +294,7 @@ def get_predictions(
         output,
         targets=targets,
         postprocessor=postprocessor,
-        clips=clips,
+        start_times=[clip.start_time for clip in clips],
     )
     return [
         convert_raw_predictions_to_clip_prediction(
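
Net effect of this file's changes: postprocessing no longer needs `soundevent.data.Clip` objects, only their start times, and callers that still hold clips convert at the boundary. A hedged sketch of the call-site pattern (`output`, `targets`, `postprocessor`, and `clips` are assumed to come from the surrounding pipeline):

    # Callers holding Clip objects derive start_times up front.
    raw_predictions = get_raw_predictions(
        output,
        targets=targets,
        postprocessor=postprocessor,
        start_times=[clip.start_time for clip in clips],
    )

    # Passing start_times=None skips the clip remapping step and
    # returns detections in the model's own coordinates.
    detections = postprocessor.get_detections(output, start_times=None)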

View File

@@ -126,14 +126,15 @@ class ValidationMetrics(Callback):
         dataset = self.get_dataset(trainer)

         clip_annotations = [
-            dataset.get_clip_annotation(int(example_idx))
+            dataset.clip_annotations[int(example_idx)]
             for example_idx in batch.idx
         ]

         predictions = get_raw_predictions(
             outputs,
-            clips=[
-                clip_annotation.clip for clip_annotation in clip_annotations
+            start_times=[
+                clip_annotation.clip.start_time
+                for clip_annotation in clip_annotations
             ],
             targets=targets,
             postprocessor=postprocessor,

View File

@@ -101,8 +101,15 @@ class ValidationDataset(Dataset):
         return len(self.clip_annotations)

     def __getitem__(self, idx) -> TrainExample:
-        wav, clip_annotation = self.load_audio(idx)
+        clip_annotation = self.clip_annotations[idx]
+
+        if self.clipper is not None:
+            clip_annotation = self.clipper(clip_annotation)
+
         clip = clip_annotation.clip
+        wav = torch.tensor(
+            self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
+        ).unsqueeze(0)

         spectrogram = self.preprocessor(wav)
@@ -117,17 +124,3 @@ class ValidationDataset(Dataset):
             start_time=torch.tensor(clip.start_time),
             end_time=torch.tensor(clip.end_time),
         )
-
-    def get_clip_annotation(self, idx: int) -> data.ClipAnnotation:
-        clip_annotation = self.clip_annotations[idx]
-
-        if self.clipper is not None:
-            clip_annotation = self.clipper(clip_annotation)
-
-        return clip_annotation
-
-    def load_audio(self, idx: int) -> Tuple[torch.Tensor, data.ClipAnnotation]:
-        clip_annotation = self.get_clip_annotation(idx)
-        clip = clip_annotation.clip
-        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
-        return torch.tensor(wav).unsqueeze(0), clip_annotation

View File

@@ -97,5 +97,5 @@ class PostprocessorProtocol(Protocol):
     def get_detections(
         self,
         output: ModelOutput,
-        clips: Optional[List[data.Clip]] = None,
+        start_times: Optional[List[float]] = None,
     ) -> List[DetectionsTensor]: ...
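
Because `PostprocessorProtocol` is a structural `typing.Protocol`, implementations only have to match the updated signature. A minimal conformance sketch; `Output` and `Detections` are hypothetical stand-ins for the package's real `ModelOutput` and `DetectionsTensor` types:

    from typing import List, Optional, Protocol

    class Output: ...        # hypothetical stand-in for ModelOutput
    class Detections: ...    # hypothetical stand-in for DetectionsTensor

    class Processor(Protocol):
        def get_detections(
            self,
            output: Output,
            start_times: Optional[List[float]] = None,
        ) -> List[Detections]: ...

    class NoopPostprocessor:
        # Trivially satisfies the protocol: returns no detections.
        def get_detections(
            self,
            output: Output,
            start_times: Optional[List[float]] = None,
        ) -> List[Detections]:
            return []

    processor: Processor = NoopPostprocessor()  # structural match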