Mirror of https://github.com/macaodha/batdetect2.git, synced 2026-01-10 17:19:34 +01:00.
Compare commits: ad5293e0d0...8628133fd7 (2 commits)

- 8628133fd7
- d80377981e
```diff
@@ -47,7 +47,7 @@ def evaluate(
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
-        config=config.train,
+        config=config.train.val_loader,
         num_workers=num_workers,
     )
@@ -66,8 +66,9 @@ def evaluate(
 
     predictions = get_raw_predictions(
         outputs,
-        clips=[
-            clip_annotation.clip for clip_annotation in clip_annotations
+        start_times=[
+            clip_annotation.clip.start_time
+            for clip_annotation in clip_annotations
         ],
         targets=targets,
         postprocessor=model.postprocessor,
```
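These two hunks change the `evaluate` entry point in tandem: the validation loader is now built from the dedicated `config.train.val_loader` section, and `get_raw_predictions` receives bare start times instead of whole clip objects. A minimal sketch of the new call shape, assuming `outputs`, `clip_annotations`, `targets`, and `model` are in scope as in the hunk context:

```python
# Sketch of the updated call site; all names are taken from the hunk
# context and assumed to be in scope.
start_times = [
    clip_annotation.clip.start_time  # only the start time is needed now
    for clip_annotation in clip_annotations
]
predictions = get_raw_predictions(
    outputs,
    start_times=start_times,
    targets=targets,
    postprocessor=model.postprocessor,
)
```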
```diff
@@ -1,5 +1,6 @@
 from typing import Dict, List
 
+import numpy as np
 import pandas as pd
 from sklearn import metrics
 from sklearn.preprocessing import label_binarize
@@ -19,9 +20,8 @@ class DetectionAveragePrecision(MetricsProtocol):
 
 
 class ClassificationMeanAveragePrecision(MetricsProtocol):
-    def __init__(self, class_names: List[str], per_class: bool = True):
+    def __init__(self, class_names: List[str]):
         self.class_names = class_names
-        self.per_class = per_class
 
     def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
         y_true = label_binarize(
@@ -40,14 +40,8 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
                 for match in matches
             ]
         ).fillna(0)
-        mAP = metrics.average_precision_score(y_true, y_pred[self.class_names])
-
-        ret = {
-            "classification_mAP": float(mAP),
-        }
-
-        if not self.per_class:
-            return ret
+        ret = {}
 
         for class_index, class_name in enumerate(self.class_names):
             y_true_class = y_true[:, class_index]
@@ -58,6 +52,10 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
             )
             ret[f"classification_AP/{class_name}"] = float(class_ap)
 
+        ret["classification_mAP"] = np.mean(
+            [value for value in ret.values() if value != 0]
+        )
+
         return ret
```
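Behaviourally, `classification_mAP` is no longer a single aggregate sklearn call over the full score matrix; it is now the mean of the per-class AP entries already collected in `ret`, with exact zeros filtered out (e.g. classes that never contribute to the evaluated matches). A self-contained sketch of the two conventions on synthetic data; the arrays and class count are illustrative, not taken from the repository:

```python
# Illustrative comparison of the old and new mAP conventions.
import numpy as np
from sklearn import metrics

y_true = np.array([[1, 0], [0, 1], [1, 0], [0, 1]])  # binarized labels
y_pred = np.array([[0.9, 0.1], [0.2, 0.7], [0.6, 0.3], [0.1, 0.2]])

# Old behaviour: one aggregate call (sklearn's default macro average).
old_map = metrics.average_precision_score(y_true, y_pred)

# New behaviour: per-class APs first, then a plain mean of the
# non-zero values, mirroring the patched __call__.
per_class = [
    metrics.average_precision_score(y_true[:, i], y_pred[:, i])
    for i in range(y_true.shape[1])
]
new_map = np.mean([ap for ap in per_class if ap != 0])

# The two agree while every per-class AP is non-zero; they diverge
# once a class contributes an AP of exactly 0 and is dropped.
print(old_map, new_map)
```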
```diff
@@ -100,7 +100,7 @@ def plot_class_examples(
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
             continue
 
     for index, match in enumerate(false_positives[:n_examples]):
@@ -112,7 +112,7 @@ def plot_class_examples(
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
             continue
 
     for index, match in enumerate(false_negatives[:n_examples]):
@@ -124,7 +124,7 @@ def plot_class_examples(
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
             continue
 
     for index, match in enumerate(cross_triggers[:n_examples]):
```
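All three hunks widen the same guard in `plot_class_examples`: example plotting now also tolerates `FileNotFoundError`, so a missing recording skips that example instead of aborting the run. A self-contained, hypothetical stand-in for the pattern (`plot_match` and the file names are not repo code):

```python
# Hypothetical stand-in for the widened guard: a missing file now skips
# the example instead of crashing the loop.
def plot_match(path: str) -> None:
    with open(path, "rb"):  # stand-in for loading the recording to plot
        pass

examples = ["missing1.wav", "missing2.wav"]
for index, match in enumerate(examples[:2]):
    try:
        plot_match(match)
    except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
        continue  # skip this example rather than fail the whole report
```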
```diff
@@ -191,7 +191,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
     def get_detections(
         self,
         output: ModelOutput,
-        clips: Optional[List[data.Clip]] = None,
+        start_times: Optional[List[float]] = None,
     ) -> List[DetectionsTensor]:
         width = output.detection_probs.shape[-1]
         duration = width / self.samplerate
@@ -203,7 +203,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
             threshold=self.detection_threshold,
         )
 
-        if clips is None:
+        if start_times is None:
             return detections
 
         width = output.detection_probs.shape[-1]
@@ -211,12 +211,12 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
         return [
             map_detection_to_clip(
                 detection,
-                start_time=clip.start_time,
-                end_time=clip.start_time + duration,
+                start_time=start_time,
+                end_time=start_time + duration,
                 min_freq=self.min_freq,
                 max_freq=self.max_freq,
             )
-            for detection, clip in zip(detections, clips)
+            for detection, start_time in zip(detections, start_times)
         ]
 
 
@@ -224,28 +224,10 @@ def get_raw_predictions(
     output: ModelOutput,
     targets: TargetProtocol,
     postprocessor: PostprocessorProtocol,
-    clips: Optional[List[data.Clip]] = None,
+    start_times: Optional[List[float]] = None,
 ) -> List[List[RawPrediction]]:
-    """Extract intermediate RawPrediction objects for a batch.
-
-    Processes raw model output through remapping, NMS, detection, data
-    extraction, and geometry recovery via the configured
-    `targets.recover_roi`.
-
-    Parameters
-    ----------
-    output : ModelOutput
-        Raw output from the neural network model for a batch.
-    clips : List[data.Clip]
-        List of `soundevent.data.Clip` objects corresponding to the batch.
-
-    Returns
-    -------
-    List[List[RawPrediction]]
-        List of lists (one inner list per input clip). Each inner list
-        contains `RawPrediction` objects for detections in that clip.
-    """
-    detections = postprocessor.get_detections(output, clips)
+    """Extract intermediate RawPrediction objects for a batch."""
+    detections = postprocessor.get_detections(output, start_times)
     return [
         to_raw_predictions(detection.numpy(), targets=targets)
        for detection in detections
@@ -254,16 +236,16 @@ def get_raw_predictions(
 
 def get_sound_event_predictions(
     output: ModelOutput,
-    clips: List[data.Clip],
     targets: TargetProtocol,
     postprocessor: PostprocessorProtocol,
+    clips: List[data.Clip],
     classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
 ) -> List[List[BatDetect2Prediction]]:
     raw_predictions = get_raw_predictions(
         output,
         targets=targets,
         postprocessor=postprocessor,
-        clips=clips,
+        start_times=[clip.start_time for clip in clips],
     )
     return [
         [
@@ -312,7 +294,7 @@ def get_predictions(
         output,
         targets=targets,
         postprocessor=postprocessor,
-        clips=clips,
+        start_times=[clip.start_time for clip in clips],
     )
     return [
         convert_raw_predictions_to_clip_prediction(
```
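The net effect of these hunks is that the postprocessing layer no longer needs `soundevent.data.Clip` objects at all: a list of start times suffices, because the end time is reconstructed as `start_time + width / samplerate` from the model output itself, and `get_sound_event_predictions` / `get_predictions` derive that list from their clips at the boundary. A hedged sketch of the caller's side, assuming `postprocessor`, `output`, and `clips` exist as in the diff:

```python
# Sketch of the updated caller contract; names are assumed from the diff.
start_times = [clip.start_time for clip in clips]

# With start times, each DetectionsTensor is remapped into absolute
# coordinates [start_time, start_time + duration].
detections = postprocessor.get_detections(output, start_times=start_times)

# With start_times=None (the default), detections stay in raw
# model-output coordinates.
raw_detections = postprocessor.get_detections(output)
```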
```diff
@@ -126,14 +126,15 @@ class ValidationMetrics(Callback):
         dataset = self.get_dataset(trainer)
 
         clip_annotations = [
-            dataset.get_clip_annotation(int(example_idx))
+            dataset.clip_annotations[int(example_idx)]
             for example_idx in batch.idx
         ]
 
         predictions = get_raw_predictions(
             outputs,
-            clips=[
-                clip_annotation.clip for clip_annotation in clip_annotations
+            start_times=[
+                clip_annotation.clip.start_time
+                for clip_annotation in clip_annotations
             ],
             targets=targets,
             postprocessor=postprocessor,
```
```diff
@@ -101,8 +101,15 @@ class ValidationDataset(Dataset):
         return len(self.clip_annotations)
 
     def __getitem__(self, idx) -> TrainExample:
-        wav, clip_annotation = self.load_audio(idx)
+        clip_annotation = self.clip_annotations[idx]
+
+        if self.clipper is not None:
+            clip_annotation = self.clipper(clip_annotation)
+
         clip = clip_annotation.clip
+        wav = torch.tensor(
+            self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
+        ).unsqueeze(0)
+
         spectrogram = self.preprocessor(wav)
 
@@ -117,17 +124,3 @@ class ValidationDataset(Dataset):
             start_time=torch.tensor(clip.start_time),
             end_time=torch.tensor(clip.end_time),
         )
-
-    def get_clip_annotation(self, idx: int) -> data.ClipAnnotation:
-        clip_annotation = self.clip_annotations[idx]
-
-        if self.clipper is not None:
-            clip_annotation = self.clipper(clip_annotation)
-
-        return clip_annotation
-
-    def load_audio(self, idx: int) -> Tuple[torch.Tensor, data.ClipAnnotation]:
-        clip_annotation = self.get_clip_annotation(idx)
-        clip = clip_annotation.clip
-        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
-        return torch.tensor(wav).unsqueeze(0), clip_annotation
```
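One subtlety worth flagging: with `get_clip_annotation` and `load_audio` removed, the optional clipper now runs only inside `__getitem__`, while outside callers (like the `ValidationMetrics` hunk above, which switched to `dataset.clip_annotations[int(example_idx)]`) read the raw, unclipped annotations. A self-contained toy sketch of that pattern; none of these names are from the repository:

```python
# Toy illustration: a transform applied in __getitem__ is bypassed by
# direct access to the underlying list.
from typing import Callable, List, Optional

class TinyDataset:
    def __init__(
        self,
        items: List[str],
        clipper: Optional[Callable[[str], str]] = None,
    ):
        self.items = items
        self.clipper = clipper

    def __getitem__(self, idx: int) -> str:
        item = self.items[idx]
        if self.clipper is not None:
            item = self.clipper(item)  # applied only on indexed access
        return item

ds = TinyDataset(["a" * 10], clipper=lambda s: s[:3])
assert ds[0] == "aaa"            # __getitem__ returns the clipped view
assert ds.items[0] == "a" * 10   # direct attribute access bypasses the clipper
```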
```diff
@@ -97,5 +97,5 @@ class PostprocessorProtocol(Protocol):
     def get_detections(
         self,
         output: ModelOutput,
-        clips: Optional[List[data.Clip]] = None,
+        start_times: Optional[List[float]] = None,
     ) -> List[DetectionsTensor]: ...
```
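Finally, the protocol is updated to match the concrete `Postprocessor`, so any conforming implementation must accept optional start times rather than clips. A minimal illustrative stub; the `Protocol` shape follows the diff, while the stub class and placeholder types are assumptions:

```python
# Illustrative conforming stub; "ModelOutput" and "DetectionsTensor" are
# forward-reference placeholders for the repo's real types.
from typing import List, Optional, Protocol

class PostprocessorProtocol(Protocol):
    def get_detections(
        self,
        output: "ModelOutput",
        start_times: Optional[List[float]] = None,
    ) -> List["DetectionsTensor"]: ...

class NullPostprocessor:
    """Structurally satisfies the protocol (duck typing); returns nothing."""

    def get_detections(self, output, start_times=None):
        return []
```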