mirror of https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00

Compare commits ad5293e0d0...8628133fd7 (2 commits)

| Author | SHA1 | Date |
| --- | --- | --- |
|  | 8628133fd7 |  |
|  | d80377981e |  |

Taken together, the two commits refactor the postprocessing API to accept per-window start times instead of full `soundevent` Clip objects, rework the classification mAP metric to always report per-class AP, and make the evaluation plots tolerant of missing audio files.
```diff
@@ -47,7 +47,7 @@ def evaluate(
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
-        config=config.train,
+        config=config.train.val_loader,
         num_workers=num_workers,
     )
```
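The fix above narrows the config object handed to the validation loader: `config.train.val_loader` rather than the whole `config.train` block. A minimal sketch of the pattern with hypothetical config classes (this diff does not show batdetect2's real config types):

```python
# Hypothetical config layout, for illustration only; the diff shows only
# that the loader-specific sub-config is what gets passed through.
from dataclasses import dataclass, field


@dataclass
class ValLoaderConfig:
    batch_size: int = 8


@dataclass
class TrainConfig:
    learning_rate: float = 1e-3
    val_loader: ValLoaderConfig = field(default_factory=ValLoaderConfig)


def build_val_loader(config: ValLoaderConfig) -> dict:
    # The builder only ever reads loader-specific fields.
    return {"batch_size": config.batch_size}


config_train = TrainConfig()
print(build_val_loader(config_train.val_loader))  # {'batch_size': 8}
```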
```diff
@@ -66,8 +66,9 @@ def evaluate(
 
     predictions = get_raw_predictions(
         outputs,
-        clips=[
-            clip_annotation.clip for clip_annotation in clip_annotations
+        start_times=[
+            clip_annotation.clip.start_time
+            for clip_annotation in clip_annotations
         ],
         targets=targets,
         postprocessor=model.postprocessor,
```
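This is the recurring call-site pattern of the refactor: callers that previously handed over `Clip` objects now extract just the start times. A self-contained sketch with stub dataclasses standing in for the `soundevent` types:

```python
# Stub types standing in for soundevent's Clip / ClipAnnotation.
from dataclasses import dataclass
from typing import List


@dataclass
class Clip:
    start_time: float
    end_time: float


@dataclass
class ClipAnnotation:
    clip: Clip


clip_annotations = [ClipAnnotation(Clip(0.0, 3.0)), ClipAnnotation(Clip(3.0, 6.0))]

# Before: clips=[ca.clip for ca in clip_annotations]
# After: only the start times cross the API boundary.
start_times: List[float] = [
    clip_annotation.clip.start_time for clip_annotation in clip_annotations
]
print(start_times)  # [0.0, 3.0]
```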
```diff
@@ -1,5 +1,6 @@
 from typing import Dict, List
 
+import numpy as np
 import pandas as pd
 from sklearn import metrics
 from sklearn.preprocessing import label_binarize
```
```diff
@@ -19,9 +20,8 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
 
 
 class ClassificationMeanAveragePrecision(MetricsProtocol):
-    def __init__(self, class_names: List[str], per_class: bool = True):
+    def __init__(self, class_names: List[str]):
        self.class_names = class_names
-        self.per_class = per_class
 
     def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
         y_true = label_binarize(
```
```diff
@@ -40,14 +40,8 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
                 for match in matches
             ]
         ).fillna(0)
-        mAP = metrics.average_precision_score(y_true, y_pred[self.class_names])
 
-        ret = {
-            "classification_mAP": float(mAP),
-        }
-
-        if not self.per_class:
-            return ret
+        ret = {}
 
         for class_index, class_name in enumerate(self.class_names):
             y_true_class = y_true[:, class_index]
```
```diff
@@ -58,6 +52,10 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
             )
             ret[f"classification_AP/{class_name}"] = float(class_ap)
 
+        ret["classification_mAP"] = np.mean(
+            [value for value in ret.values() if value != 0]
+        )
+
         return ret
```
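Net effect of the metric change: the single `average_precision_score` call over all classes is gone, the `per_class` switch is gone, and `classification_mAP` is now the mean of the per-class APs, skipping classes whose AP is exactly 0 (the surrounding `.fillna(0)` suggests these are classes with no scored examples). A worked example of the new aggregation, with made-up class names and AP values:

```python
import numpy as np

# Per-class APs as the metric would have accumulated them just before the
# final aggregation (class names and values are illustrative).
ret = {
    "classification_AP/class_a": 0.75,
    "classification_AP/class_b": 0.25,
    "classification_AP/class_c": 0.0,  # zero AP -> excluded from the mean
}

ret["classification_mAP"] = float(
    np.mean([value for value in ret.values() if value != 0])
)
print(ret["classification_mAP"])  # 0.5, not (0.75 + 0.25 + 0.0) / 3 ≈ 0.33
```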
```diff
@@ -100,7 +100,7 @@ def plot_class_examples(
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
             continue
 
     for index, match in enumerate(false_positives[:n_examples]):
```
```diff
@@ -112,7 +112,7 @@ def plot_class_examples(
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
             continue
 
     for index, match in enumerate(false_negatives[:n_examples]):
```
```diff
@@ -124,7 +124,7 @@ def plot_class_examples(
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
             continue
 
     for index, match in enumerate(cross_triggers[:n_examples]):
```
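The same one-token change is applied to all three plotting loops: adding `FileNotFoundError` to the except tuple lets the report skip examples whose audio file is missing instead of crashing. A minimal sketch of the pattern (the file names and plotting stub are illustrative):

```python
def plot_example(path: str) -> None:
    # Stand-in for the real spectrogram plotting; opening the file is the
    # step that can raise FileNotFoundError.
    with open(path, "rb"):
        pass


for path in ["missing_a.wav", "missing_b.wav"]:
    try:
        plot_example(path)
    except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
        print(f"skipped {path}")  # keep going instead of aborting the report
        continue
```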
```diff
@@ -191,7 +191,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
     def get_detections(
         self,
         output: ModelOutput,
-        clips: Optional[List[data.Clip]] = None,
+        start_times: Optional[List[float]] = None,
     ) -> List[DetectionsTensor]:
         width = output.detection_probs.shape[-1]
         duration = width / self.samplerate
```
```diff
@@ -203,7 +203,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
             threshold=self.detection_threshold,
         )
 
-        if clips is None:
+        if start_times is None:
             return detections
 
         width = output.detection_probs.shape[-1]
```
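The early return preserves the old contract: with no `start_times` (previously no `clips`), detections stay in window-relative coordinates. A runnable stand-in, simplified to one detection time per window:

```python
from typing import List, Optional


def get_detections_stub(
    window_times: List[float],
    start_times: Optional[List[float]] = None,
) -> List[float]:
    if start_times is None:
        return window_times  # leave times relative to each window
    # Otherwise shift each window's time onto the recording timeline.
    return [t + s for t, s in zip(window_times, start_times)]


print(get_detections_stub([0.1, 0.2]))              # [0.1, 0.2]
print(get_detections_stub([0.1, 0.2], [3.0, 6.0]))  # [3.1, 6.2]
```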
```diff
@@ -211,12 +211,12 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
         return [
             map_detection_to_clip(
                 detection,
-                start_time=clip.start_time,
-                end_time=clip.start_time + duration,
+                start_time=start_time,
+                end_time=start_time + duration,
                 min_freq=self.min_freq,
                 max_freq=self.max_freq,
             )
-            for detection, clip in zip(detections, clips)
+            for detection, start_time in zip(detections, start_times)
         ]
```
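Note that even before this change the window's end time was derived as `start_time + duration` rather than read off the clip, so the scalar start time is the only piece of `Clip` the remapping actually consumed. A small sketch of the pairing (assuming, as the code suggests, detection times are seconds relative to each window's start):

```python
def map_window(times, start_time, duration):
    # Mirrors the diff: the window's end is derived, not looked up.
    end_time = start_time + duration
    assert all(0.0 <= t <= end_time - start_time for t in times)
    return [start_time + t for t in times]


detections = [[0.10, 0.42], [0.05]]  # per-window relative times (seconds)
start_times = [0.0, 3.0]
print([map_window(d, s, duration=3.0) for d, s in zip(detections, start_times)])
# [[0.1, 0.42], [3.05]]
```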
```diff
@@ -224,28 +224,10 @@ def get_raw_predictions(
     output: ModelOutput,
     targets: TargetProtocol,
     postprocessor: PostprocessorProtocol,
-    clips: Optional[List[data.Clip]] = None,
+    start_times: Optional[List[float]] = None,
 ) -> List[List[RawPrediction]]:
-    """Extract intermediate RawPrediction objects for a batch.
-
-    Processes raw model output through remapping, NMS, detection, data
-    extraction, and geometry recovery via the configured
-    `targets.recover_roi`.
-
-    Parameters
-    ----------
-    output : ModelOutput
-        Raw output from the neural network model for a batch.
-    clips : List[data.Clip]
-        List of `soundevent.data.Clip` objects corresponding to the batch.
-
-    Returns
-    -------
-    List[List[RawPrediction]]
-        List of lists (one inner list per input clip). Each inner list
-        contains `RawPrediction` objects for detections in that clip.
-    """
-    detections = postprocessor.get_detections(output, clips)
+    """Extract intermediate RawPrediction objects for a batch."""
+    detections = postprocessor.get_detections(output, start_times)
     return [
         to_raw_predictions(detection.numpy(), targets=targets)
         for detection in detections
```
```diff
@@ -254,16 +236,16 @@ def get_raw_predictions(
 
 def get_sound_event_predictions(
     output: ModelOutput,
-    clips: List[data.Clip],
     targets: TargetProtocol,
     postprocessor: PostprocessorProtocol,
+    clips: List[data.Clip],
     classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
 ) -> List[List[BatDetect2Prediction]]:
     raw_predictions = get_raw_predictions(
         output,
         targets=targets,
         postprocessor=postprocessor,
-        clips=clips,
+        start_times=[clip.start_time for clip in clips],
     )
     return [
         [
```
```diff
@@ -312,7 +294,7 @@ def get_predictions(
         output,
         targets=targets,
         postprocessor=postprocessor,
-        clips=clips,
+        start_times=[clip.start_time for clip in clips],
     )
     return [
         convert_raw_predictions_to_clip_prediction(
```
```diff
@@ -126,14 +126,15 @@ class ValidationMetrics(Callback):
         dataset = self.get_dataset(trainer)
 
         clip_annotations = [
-            dataset.get_clip_annotation(int(example_idx))
+            dataset.clip_annotations[int(example_idx)]
             for example_idx in batch.idx
         ]
 
         predictions = get_raw_predictions(
             outputs,
-            clips=[
-                clip_annotation.clip for clip_annotation in clip_annotations
+            start_times=[
+                clip_annotation.clip.start_time
+                for clip_annotation in clip_annotations
             ],
             targets=targets,
             postprocessor=postprocessor,
```
```diff
@@ -101,8 +101,15 @@ class ValidationDataset(Dataset):
         return len(self.clip_annotations)
 
     def __getitem__(self, idx) -> TrainExample:
-        wav, clip_annotation = self.load_audio(idx)
+        clip_annotation = self.clip_annotations[idx]
+
+        if self.clipper is not None:
+            clip_annotation = self.clipper(clip_annotation)
+
         clip = clip_annotation.clip
+        wav = torch.tensor(
+            self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
+        ).unsqueeze(0)
 
         spectrogram = self.preprocessor(wav)
```
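The removed `load_audio`/`get_clip_annotation` helpers (see the next hunk) are now inlined here, so the clipper runs before the audio is loaded and the waveform is shaped in one place. A quick check of the tensor convention used above, with an illustrative sample rate (real bat recordings are typically sampled far higher):

```python
import numpy as np
import torch

# Assume the audio loader returns a 1-D float array of samples.
samples = np.zeros(16_000, dtype=np.float32)  # 1 s at an illustrative 16 kHz
wav = torch.tensor(samples).unsqueeze(0)      # prepend a channel dimension
print(wav.shape)  # torch.Size([1, 16000])
```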
```diff
@@ -117,17 +124,3 @@ class ValidationDataset(Dataset):
             start_time=torch.tensor(clip.start_time),
             end_time=torch.tensor(clip.end_time),
         )
-
-    def get_clip_annotation(self, idx: int) -> data.ClipAnnotation:
-        clip_annotation = self.clip_annotations[idx]
-
-        if self.clipper is not None:
-            clip_annotation = self.clipper(clip_annotation)
-
-        return clip_annotation
-
-    def load_audio(self, idx: int) -> Tuple[torch.Tensor, data.ClipAnnotation]:
-        clip_annotation = self.get_clip_annotation(idx)
-        clip = clip_annotation.clip
-        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
-        return torch.tensor(wav).unsqueeze(0), clip_annotation
```
```diff
@@ -97,5 +97,5 @@ class PostprocessorProtocol(Protocol):
     def get_detections(
         self,
         output: ModelOutput,
-        clips: Optional[List[data.Clip]] = None,
+        start_times: Optional[List[float]] = None,
     ) -> List[DetectionsTensor]: ...
```
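Because `PostprocessorProtocol` is a `typing.Protocol`, conformance is structural: any object exposing this `get_detections` signature satisfies it, which is why the `Postprocessor` class and the protocol had to change in lockstep. A self-contained sketch with placeholder types (the real `ModelOutput` and `DetectionsTensor` are not shown in this diff):

```python
from typing import Any, List, Optional, Protocol

ModelOutput = Any        # placeholder for the real model output type
DetectionsTensor = Any   # placeholder for the real detections tensor type


class PostprocessorProtocolSketch(Protocol):
    def get_detections(
        self,
        output: ModelOutput,
        start_times: Optional[List[float]] = None,
    ) -> List[DetectionsTensor]: ...


class NullPostprocessor:
    """Trivially conforming implementation."""

    def get_detections(
        self,
        output: ModelOutput,
        start_times: Optional[List[float]] = None,
    ) -> List[DetectionsTensor]:
        return []


pp: PostprocessorProtocolSketch = NullPostprocessor()  # structural typing
print(pp.get_detections(None))  # []
```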