mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Adding evaluation callback
This commit is contained in:
parent
9106b9f408
commit
bc86c94f8e
@ -1,9 +1,13 @@
|
|||||||
from batdetect2.evaluate.evaluate import (
|
from batdetect2.evaluate.evaluate import (
|
||||||
compute_error_auc,
|
compute_error_auc,
|
||||||
|
)
|
||||||
|
from batdetect2.evaluate.match import (
|
||||||
match_predictions_and_annotations,
|
match_predictions_and_annotations,
|
||||||
|
match_sound_events_and_raw_predictions,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"compute_error_auc",
|
"compute_error_auc",
|
||||||
"match_predictions_and_annotations",
|
"match_predictions_and_annotations",
|
||||||
|
"match_sound_events_and_raw_predictions",
|
||||||
]
|
]
|
||||||
|
@ -1,51 +1,6 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.metrics import auc, roc_curve
|
from sklearn.metrics import auc, roc_curve
|
||||||
from soundevent import data
|
|
||||||
from soundevent.evaluation import match_geometries
|
|
||||||
|
|
||||||
|
|
||||||
def match_predictions_and_annotations(
|
|
||||||
clip_annotation: data.ClipAnnotation,
|
|
||||||
clip_prediction: data.ClipPrediction,
|
|
||||||
) -> List[data.Match]:
|
|
||||||
annotated_sound_events = [
|
|
||||||
sound_event_annotation
|
|
||||||
for sound_event_annotation in clip_annotation.sound_events
|
|
||||||
if sound_event_annotation.sound_event.geometry is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
predicted_sound_events = [
|
|
||||||
sound_event_prediction
|
|
||||||
for sound_event_prediction in clip_prediction.sound_events
|
|
||||||
if sound_event_prediction.sound_event.geometry is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
annotated_geometries: List[data.Geometry] = [
|
|
||||||
sound_event.sound_event.geometry
|
|
||||||
for sound_event in annotated_sound_events
|
|
||||||
if sound_event.sound_event.geometry is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
predicted_geometries: List[data.Geometry] = [
|
|
||||||
sound_event.sound_event.geometry
|
|
||||||
for sound_event in predicted_sound_events
|
|
||||||
if sound_event.sound_event.geometry is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
matches = []
|
|
||||||
for id1, id2, affinity in match_geometries(
|
|
||||||
annotated_geometries,
|
|
||||||
predicted_geometries,
|
|
||||||
):
|
|
||||||
target = annotated_sound_events[id1] if id1 is not None else None
|
|
||||||
source = predicted_sound_events[id2] if id2 is not None else None
|
|
||||||
matches.append(
|
|
||||||
data.Match(source=source, target=target, affinity=affinity)
|
|
||||||
)
|
|
||||||
|
|
||||||
return matches
|
|
||||||
|
|
||||||
|
|
||||||
def compute_error_auc(op_str, gt, pred, prob):
|
def compute_error_auc(op_str, gt, pred, prob):
|
||||||
|
111
batdetect2/evaluate/match.py
Normal file
111
batdetect2/evaluate/match.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from soundevent import data
|
||||||
|
from soundevent.evaluation import match_geometries
|
||||||
|
|
||||||
|
from batdetect2.evaluate.types import Match
|
||||||
|
from batdetect2.postprocess.types import RawPrediction
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
from batdetect2.utils.arrays import iterate_over_array
|
||||||
|
|
||||||
|
|
||||||
|
def match_sound_events_and_raw_predictions(
    sound_events: List[data.SoundEventAnnotation],
    raw_predictions: List[RawPrediction],
    targets: TargetProtocol,
) -> List[Match]:
    """Match annotated sound events against raw predictions.

    Annotations are first restricted to those accepted by
    ``targets.filter`` that also carry a geometry, and are passed through
    ``targets.transform``. The geometries of the remaining annotations are
    then paired with the geometries of the raw predictions using
    ``match_geometries``.

    Parameters
    ----------
    sound_events : List[data.SoundEventAnnotation]
        Ground-truth annotations to evaluate against.
    raw_predictions : List[RawPrediction]
        Predictions to be matched; only their ``geometry``,
        ``detection_score`` and ``class_scores`` are read here.
    targets : TargetProtocol
        Provides ``filter``, ``transform`` and ``encode`` for annotations.

    Returns
    -------
    List[Match]
        One entry per triple produced by ``match_geometries``. Entries
        without a target have ``gt_det=False`` and no ``gt_uuid`` or
        ``gt_class``; entries without a prediction have ``pred_score=0.0``
        and empty ``class_scores``.
    """
    target_sound_events = [
        targets.transform(sound_event_annotation)
        for sound_event_annotation in sound_events
        if targets.filter(sound_event_annotation)
        and sound_event_annotation.sound_event.geometry is not None
    ]

    # Index-aligned with `target_sound_events`, so the indices yielded by
    # `match_geometries` can be used to look up the annotation.
    target_geometries: List[data.Geometry] = [  # type: ignore
        sound_event_annotation.sound_event.geometry
        for sound_event_annotation in target_sound_events
    ]

    # Index-aligned with `raw_predictions`.
    predicted_geometries = [
        raw_prediction.geometry for raw_prediction in raw_predictions
    ]

    matches = []
    for id1, id2, affinity in match_geometries(
        target_geometries,
        predicted_geometries,
    ):
        # Either index may be None: the triple then has no target or no
        # prediction, respectively.
        target = target_sound_events[id1] if id1 is not None else None
        prediction = raw_predictions[id2] if id2 is not None else None

        gt_uuid = target.uuid if target is not None else None
        gt_det = target is not None
        gt_class = targets.encode(target) if target is not None else None

        # Explicit None check (not truthiness) and a float fallback so the
        # value always agrees with the `pred_score: float` field.
        pred_score = (
            float(prediction.detection_score)
            if prediction is not None
            else 0.0
        )

        class_scores = (
            {
                str(class_name): float(score)
                for class_name, score in iterate_over_array(
                    prediction.class_scores
                )
            }
            if prediction is not None
            else {}
        )

        matches.append(
            Match(
                gt_uuid=gt_uuid,
                gt_det=gt_det,
                gt_class=gt_class,
                pred_score=pred_score,
                affinity=affinity,
                class_scores=class_scores,
            )
        )

    return matches
|
||||||
|
|
||||||
|
|
||||||
|
def match_predictions_and_annotations(
    clip_annotation: data.ClipAnnotation,
    clip_prediction: data.ClipPrediction,
) -> List[data.Match]:
    """Pair the sound events of a clip prediction with its annotations.

    Sound events without a geometry are ignored on both sides. The
    remaining geometries are handed to ``match_geometries`` and every
    triple it yields is converted into a ``data.Match``.

    Parameters
    ----------
    clip_annotation : data.ClipAnnotation
        Ground-truth annotations of the clip.
    clip_prediction : data.ClipPrediction
        Predicted sound events for the same clip.

    Returns
    -------
    List[data.Match]
        One match per triple; ``source`` is the predicted sound event and
        ``target`` the annotated one, either of which may be ``None``.
    """
    # Collect each kept item together with its geometry so that both
    # lists stay index-aligned.
    annotated = []
    annotated_geometries: List[data.Geometry] = []
    for annotation in clip_annotation.sound_events:
        geometry = annotation.sound_event.geometry
        if geometry is None:
            continue
        annotated.append(annotation)
        annotated_geometries.append(geometry)

    predicted = []
    predicted_geometries: List[data.Geometry] = []
    for prediction in clip_prediction.sound_events:
        geometry = prediction.sound_event.geometry
        if geometry is None:
            continue
        predicted.append(prediction)
        predicted_geometries.append(geometry)

    return [
        data.Match(
            target=None if target_idx is None else annotated[target_idx],
            source=None if source_idx is None else predicted[source_idx],
            affinity=affinity,
        )
        for target_idx, source_idx, affinity in match_geometries(
            annotated_geometries,
            predicted_geometries,
        )
    ]
|
39
batdetect2/evaluate/metrics.py
Normal file
39
batdetect2/evaluate/metrics.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from sklearn import metrics
|
||||||
|
from sklearn.preprocessing import label_binarize
|
||||||
|
|
||||||
|
from batdetect2.evaluate.types import Match, MetricsProtocol
|
||||||
|
|
||||||
|
# Public API of this module: every metric class defined below.
__all__ = [
    "ClassificationMeanAveragePrecision",
    "DetectionAveragePrecision",
]
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionAveragePrecision(MetricsProtocol):
    """Average precision of the detection scores.

    Each match contributes one sample: ``gt_det`` is the binary label and
    ``pred_score`` is the score given to
    ``sklearn.metrics.average_precision_score``.
    """

    name: str = "detection/average_precision"

    def __call__(self, matches: List[Match]) -> float:
        """Compute the metric over ``matches``.

        Parameters
        ----------
        matches : List[Match]
            Matches accumulated during evaluation.

        Returns
        -------
        float
            The average precision, or ``0.0`` when ``matches`` is empty.
        """
        # With no matches there is nothing to score; previously
        # `zip(*[])` produced nothing to unpack and raised ValueError.
        if not matches:
            return 0.0

        y_true = [match.gt_det for match in matches]
        y_score = [match.pred_score for match in matches]
        return float(metrics.average_precision_score(y_true, y_score))
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationMeanAveragePrecision(MetricsProtocol):
    """Average precision of the per-class scores.

    Ground-truth classes are binarized against ``class_names`` and
    compared with the per-class scores stored on each match using
    ``sklearn.metrics.average_precision_score``.

    Parameters
    ----------
    class_names : List[str]
        Names of the classes to evaluate; also the column order used for
        both the labels and the scores.
    """

    name: str = "classification/average_precision"

    def __init__(self, class_names: List[str]):
        self.class_names = class_names

    def __call__(self, matches: List[Match]) -> float:
        """Compute the metric over ``matches``.

        Parameters
        ----------
        matches : List[Match]
            Matches accumulated during evaluation.

        Returns
        -------
        float
            The average precision, or ``0.0`` when ``matches`` is empty.
        """
        if not matches:
            return 0.0

        # Matches without a ground-truth class get a placeholder label
        # that is not in `class_names`.
        # NOTE(review): label_binarize is documented to return a single
        # column for two-class problems, which would not line up with the
        # two score columns below -- confirm behaviour for exactly two
        # classes.
        y_true = label_binarize(
            [
                match.gt_class if match.gt_class is not None else "__NONE__"
                for match in matches
            ],
            classes=self.class_names,
        )

        # Matches with no prediction carry an empty `class_scores` dict.
        # Reindexing guarantees one column per class (even if no match
        # has scores at all) and the missing entries are scored as 0 so
        # that no NaN reaches scikit-learn.
        y_pred = (
            pd.DataFrame([match.class_scores for match in matches])
            .reindex(columns=self.class_names)
            .fillna(0.0)
        )

        return float(metrics.average_precision_score(y_true, y_pred))
|
24
batdetect2/evaluate/types.py
Normal file
24
batdetect2/evaluate/types.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Protocol
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MetricsProtocol",
|
||||||
|
"Match",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Match:
    """Outcome of pairing one ground-truth sound event with one prediction.

    Either side of the pair may be absent; see the per-field comments for
    the values used in that case.
    """

    # UUID of the matched ground-truth annotation, None if there is none.
    gt_uuid: Optional[UUID]

    # True when the match has a ground-truth annotation.
    gt_det: bool

    # Encoded class of the ground-truth annotation, None if there is none.
    gt_class: Optional[str]

    # Detection score of the prediction.
    pred_score: float

    # Affinity between the matched geometries.
    affinity: float

    # Score per class name for the prediction.
    class_scores: Dict[str, float]
|
||||||
|
|
||||||
|
|
||||||
|
class MetricsProtocol(Protocol):
    """Structural interface for an evaluation metric.

    A metric exposes a ``name`` and is callable on a list of matches,
    returning a single float.
    """

    # Identifier of the metric.
    name: str

    def __call__(self, matches: List[Match]) -> float:
        """Compute the metric value for ``matches``."""
        ...
|
@ -170,8 +170,8 @@ def load_postprocess_config(
|
|||||||
def build_postprocessor(
|
def build_postprocessor(
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
config: Optional[PostprocessConfig] = None,
|
config: Optional[PostprocessConfig] = None,
|
||||||
max_freq: int = MAX_FREQ,
|
max_freq: float = MAX_FREQ,
|
||||||
min_freq: int = MIN_FREQ,
|
min_freq: float = MIN_FREQ,
|
||||||
) -> PostprocessorProtocol:
|
) -> PostprocessorProtocol:
|
||||||
"""Factory function to build the standard postprocessor.
|
"""Factory function to build the standard postprocessor.
|
||||||
|
|
||||||
@ -234,9 +234,9 @@ class Postprocessor(PostprocessorProtocol):
|
|||||||
recovery.
|
recovery.
|
||||||
config : PostprocessConfig
|
config : PostprocessConfig
|
||||||
Configuration object holding parameters for NMS, thresholds, etc.
|
Configuration object holding parameters for NMS, thresholds, etc.
|
||||||
min_freq : int
|
min_freq : float
|
||||||
Minimum frequency (Hz) assumed for the model output's frequency axis.
|
Minimum frequency (Hz) assumed for the model output's frequency axis.
|
||||||
max_freq : int
|
max_freq : float
|
||||||
Maximum frequency (Hz) assumed for the model output's frequency axis.
|
Maximum frequency (Hz) assumed for the model output's frequency axis.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -246,8 +246,8 @@ class Postprocessor(PostprocessorProtocol):
|
|||||||
self,
|
self,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
config: PostprocessConfig,
|
config: PostprocessConfig,
|
||||||
min_freq: int = MIN_FREQ,
|
min_freq: float = MIN_FREQ,
|
||||||
max_freq: int = MAX_FREQ,
|
max_freq: float = MAX_FREQ,
|
||||||
):
|
):
|
||||||
"""Initialize the Postprocessor.
|
"""Initialize the Postprocessor.
|
||||||
|
|
||||||
|
@ -32,10 +32,10 @@ from typing import List, Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.geometry import compute_bounds
|
|
||||||
|
|
||||||
from batdetect2.postprocess.types import GeometryBuilder, RawPrediction
|
from batdetect2.postprocess.types import GeometryBuilder, RawPrediction
|
||||||
from batdetect2.targets.classes import SoundEventDecoder
|
from batdetect2.targets.classes import SoundEventDecoder
|
||||||
|
from batdetect2.utils.arrays import iterate_over_array
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"convert_xr_dataset_to_raw_prediction",
|
"convert_xr_dataset_to_raw_prediction",
|
||||||
@ -97,18 +97,14 @@ def convert_xr_dataset_to_raw_prediction(
|
|||||||
det_info = detection_dataset.sel(detection=det_num)
|
det_info = detection_dataset.sel(detection=det_num)
|
||||||
|
|
||||||
geom = geometry_builder(
|
geom = geometry_builder(
|
||||||
(det_info.time, det_info.freq),
|
(det_info.time, det_info.frequency),
|
||||||
det_info.dimensions,
|
det_info.dimensions,
|
||||||
)
|
)
|
||||||
|
|
||||||
start_time, low_freq, end_time, high_freq = compute_bounds(geom)
|
|
||||||
detections.append(
|
detections.append(
|
||||||
RawPrediction(
|
RawPrediction(
|
||||||
detection_score=det_info.score,
|
detection_score=det_info.scores,
|
||||||
start_time=start_time,
|
geometry=geom,
|
||||||
end_time=end_time,
|
|
||||||
low_freq=low_freq,
|
|
||||||
high_freq=high_freq,
|
|
||||||
class_scores=det_info.classes,
|
class_scores=det_info.classes,
|
||||||
features=det_info.features,
|
features=det_info.features,
|
||||||
)
|
)
|
||||||
@ -244,14 +240,7 @@ def convert_raw_prediction_to_sound_event_prediction(
|
|||||||
"""
|
"""
|
||||||
sound_event = data.SoundEvent(
|
sound_event = data.SoundEvent(
|
||||||
recording=recording,
|
recording=recording,
|
||||||
geometry=data.BoundingBox(
|
geometry=raw_prediction.geometry,
|
||||||
coordinates=[
|
|
||||||
raw_prediction.start_time,
|
|
||||||
raw_prediction.low_freq,
|
|
||||||
raw_prediction.end_time,
|
|
||||||
raw_prediction.high_freq,
|
|
||||||
]
|
|
||||||
),
|
|
||||||
features=get_prediction_features(raw_prediction.features),
|
features=get_prediction_features(raw_prediction.features),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -333,7 +322,7 @@ def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
|
|||||||
),
|
),
|
||||||
value=value,
|
value=value,
|
||||||
)
|
)
|
||||||
for feat_name, value in _iterate_over_array(features)
|
for feat_name, value in iterate_over_array(features)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -394,13 +383,6 @@ def get_class_tags(
|
|||||||
return tags
|
return tags
|
||||||
|
|
||||||
|
|
||||||
def _iterate_over_array(array: xr.DataArray):
|
|
||||||
dim_name = array.dims[0]
|
|
||||||
coords = array.coords[dim_name]
|
|
||||||
for value, coord in zip(array.values, coords.values):
|
|
||||||
yield coord, float(value)
|
|
||||||
|
|
||||||
|
|
||||||
def _iterate_sorted(array: xr.DataArray):
|
def _iterate_sorted(array: xr.DataArray):
|
||||||
dim_name = array.dims[0]
|
dim_name = array.dims[0]
|
||||||
coords = array.coords[dim_name].values
|
coords = array.coords[dim_name].values
|
||||||
|
@ -47,14 +47,9 @@ class RawPrediction(NamedTuple):
|
|||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
start_time : float
|
geometry: data.Geometry
|
||||||
Start time of the recovered bounding box in seconds.
|
The recovered estimated geometry of the detected sound event.
|
||||||
end_time : float
|
Usually a bounding box.
|
||||||
End time of the recovered bounding box in seconds.
|
|
||||||
low_freq : float
|
|
||||||
Lowest frequency of the recovered bounding box in Hz.
|
|
||||||
high_freq : float
|
|
||||||
Highest frequency of the recovered bounding box in Hz.
|
|
||||||
detection_score : float
|
detection_score : float
|
||||||
The confidence score associated with this detection, typically from
|
The confidence score associated with this detection, typically from
|
||||||
the detection heatmap peak.
|
the detection heatmap peak.
|
||||||
@ -67,10 +62,7 @@ class RawPrediction(NamedTuple):
|
|||||||
detection location. Indexed by a 'feature' coordinate.
|
detection location. Indexed by a 'feature' coordinate.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
start_time: float
|
geometry: data.Geometry
|
||||||
end_time: float
|
|
||||||
low_freq: float
|
|
||||||
high_freq: float
|
|
||||||
detection_score: float
|
detection_score: float
|
||||||
class_scores: xr.DataArray
|
class_scores: xr.DataArray
|
||||||
features: xr.DataArray
|
features: xr.DataArray
|
||||||
|
@ -106,7 +106,7 @@ def contains_tags(
|
|||||||
False otherwise.
|
False otherwise.
|
||||||
"""
|
"""
|
||||||
sound_event_tags = set(sound_event_annotation.tags)
|
sound_event_tags = set(sound_event_annotation.tags)
|
||||||
return tags < sound_event_tags
|
return tags <= sound_event_tags
|
||||||
|
|
||||||
|
|
||||||
def does_not_have_tags(
|
def does_not_have_tags(
|
||||||
|
@ -20,14 +20,27 @@ scaling factors) is managed by the `ROIConfig`. This module separates the
|
|||||||
handled in `batdetect2.targets.classes`.
|
handled in `batdetect2.targets.classes`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional, Protocol, Tuple
|
from typing import List, Literal, Optional, Protocol, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from soundevent import data, geometry
|
from soundevent import data
|
||||||
from soundevent.geometry.operations import Positions
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
|
||||||
|
Positions = Literal[
|
||||||
|
"bottom-left",
|
||||||
|
"bottom-right",
|
||||||
|
"top-left",
|
||||||
|
"top-right",
|
||||||
|
"center-left",
|
||||||
|
"center-right",
|
||||||
|
"top-center",
|
||||||
|
"bottom-center",
|
||||||
|
"center",
|
||||||
|
"centroid",
|
||||||
|
"point_on_surface",
|
||||||
|
]
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ROITargetMapper",
|
"ROITargetMapper",
|
||||||
"ROIConfig",
|
"ROIConfig",
|
||||||
@ -242,6 +255,8 @@ class BBoxEncoder(ROITargetMapper):
|
|||||||
Tuple[float, float]
|
Tuple[float, float]
|
||||||
Reference position (time, frequency).
|
Reference position (time, frequency).
|
||||||
"""
|
"""
|
||||||
|
from soundevent import geometry
|
||||||
|
|
||||||
return geometry.get_geometry_point(geom, position=self.position)
|
return geometry.get_geometry_point(geom, position=self.position)
|
||||||
|
|
||||||
def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
|
def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
|
||||||
@ -260,6 +275,8 @@ class BBoxEncoder(ROITargetMapper):
|
|||||||
np.ndarray
|
np.ndarray
|
||||||
A 1D NumPy array: `[scaled_width, scaled_height]`.
|
A 1D NumPy array: `[scaled_width, scaled_height]`.
|
||||||
"""
|
"""
|
||||||
|
from soundevent import geometry
|
||||||
|
|
||||||
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
|
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
|
||||||
geom
|
geom
|
||||||
)
|
)
|
||||||
@ -308,8 +325,8 @@ class BBoxEncoder(ROITargetMapper):
|
|||||||
width, height = dims
|
width, height = dims
|
||||||
return _build_bounding_box(
|
return _build_bounding_box(
|
||||||
pos,
|
pos,
|
||||||
duration=width / self.time_scale,
|
duration=float(width) / self.time_scale,
|
||||||
bandwidth=height / self.frequency_scale,
|
bandwidth=float(height) / self.frequency_scale,
|
||||||
position=self.position,
|
position=self.position,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -421,14 +438,16 @@ def _build_bounding_box(
|
|||||||
ValueError
|
ValueError
|
||||||
If `position` is not a recognized value or format.
|
If `position` is not a recognized value or format.
|
||||||
"""
|
"""
|
||||||
time, freq = pos
|
time, freq = map(float, pos)
|
||||||
|
duration = max(0, duration)
|
||||||
|
bandwidth = max(0, bandwidth)
|
||||||
if position in ["center", "centroid", "point_on_surface"]:
|
if position in ["center", "centroid", "point_on_surface"]:
|
||||||
return data.BoundingBox(
|
return data.BoundingBox(
|
||||||
coordinates=[
|
coordinates=[
|
||||||
time - duration / 2,
|
max(time - duration / 2, 0),
|
||||||
freq - bandwidth / 2,
|
max(freq - bandwidth / 2, 0),
|
||||||
time + duration / 2,
|
max(time + duration / 2, 0),
|
||||||
freq + bandwidth / 2,
|
max(freq + bandwidth / 2, 0),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -454,9 +473,9 @@ def _build_bounding_box(
|
|||||||
|
|
||||||
return data.BoundingBox(
|
return data.BoundingBox(
|
||||||
coordinates=[
|
coordinates=[
|
||||||
start_time,
|
max(0, start_time),
|
||||||
low_freq,
|
max(0, low_freq),
|
||||||
start_time + duration,
|
max(0, start_time + duration),
|
||||||
low_freq + bandwidth,
|
max(0, low_freq + bandwidth),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -14,20 +14,28 @@ from batdetect2.train.augmentations import (
|
|||||||
warp_spectrogram,
|
warp_spectrogram,
|
||||||
)
|
)
|
||||||
from batdetect2.train.clips import build_clipper, select_subclip
|
from batdetect2.train.clips import build_clipper, select_subclip
|
||||||
from batdetect2.train.config import TrainingConfig, load_train_config
|
from batdetect2.train.config import (
|
||||||
|
TrainerConfig,
|
||||||
|
TrainingConfig,
|
||||||
|
load_train_config,
|
||||||
|
)
|
||||||
from batdetect2.train.dataset import (
|
from batdetect2.train.dataset import (
|
||||||
LabeledDataset,
|
LabeledDataset,
|
||||||
RandomExampleSource,
|
RandomExampleSource,
|
||||||
TrainExample,
|
TrainExample,
|
||||||
list_preprocessed_files,
|
list_preprocessed_files,
|
||||||
)
|
)
|
||||||
from batdetect2.train.labels import load_label_config
|
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
||||||
from batdetect2.train.losses import LossFunction, build_loss
|
from batdetect2.train.losses import LossFunction, build_loss
|
||||||
from batdetect2.train.preprocess import (
|
from batdetect2.train.preprocess import (
|
||||||
generate_train_example,
|
generate_train_example,
|
||||||
preprocess_annotations,
|
preprocess_annotations,
|
||||||
)
|
)
|
||||||
from batdetect2.train.train import TrainerConfig, load_trainer_config, train
|
from batdetect2.train.train import (
|
||||||
|
build_train_dataset,
|
||||||
|
build_val_dataset,
|
||||||
|
train,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AugmentationsConfig",
|
"AugmentationsConfig",
|
||||||
@ -44,13 +52,15 @@ __all__ = [
|
|||||||
"WarpAugmentationConfig",
|
"WarpAugmentationConfig",
|
||||||
"add_echo",
|
"add_echo",
|
||||||
"build_augmentations",
|
"build_augmentations",
|
||||||
|
"build_clip_labeler",
|
||||||
"build_clipper",
|
"build_clipper",
|
||||||
"build_loss",
|
"build_loss",
|
||||||
|
"build_train_dataset",
|
||||||
|
"build_val_dataset",
|
||||||
"generate_train_example",
|
"generate_train_example",
|
||||||
"list_preprocessed_files",
|
"list_preprocessed_files",
|
||||||
"load_label_config",
|
"load_label_config",
|
||||||
"load_train_config",
|
"load_train_config",
|
||||||
"load_trainer_config",
|
|
||||||
"mask_frequency",
|
"mask_frequency",
|
||||||
"mask_time",
|
"mask_time",
|
||||||
"mix_examples",
|
"mix_examples",
|
||||||
@ -58,5 +68,6 @@ __all__ = [
|
|||||||
"scale_volume",
|
"scale_volume",
|
||||||
"select_subclip",
|
"select_subclip",
|
||||||
"train",
|
"train",
|
||||||
|
"train",
|
||||||
"warp_spectrogram",
|
"warp_spectrogram",
|
||||||
]
|
]
|
||||||
|
@ -1,30 +1,51 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
from lightning import LightningModule, Trainer
|
from lightning import LightningModule, Trainer
|
||||||
from lightning.pytorch.callbacks import Callback
|
from lightning.pytorch.callbacks import Callback
|
||||||
|
from soundevent import data
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.postprocess import PostprocessorProtocol
|
from batdetect2.evaluate.match import match_sound_events_and_raw_predictions
|
||||||
|
from batdetect2.evaluate.types import Match, MetricsProtocol
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||||
from batdetect2.types import ModelOutput
|
from batdetect2.train.lightning import TrainingModule
|
||||||
|
from batdetect2.train.types import ModelOutput
|
||||||
|
|
||||||
|
|
||||||
class ValidationMetrics(Callback):
|
class ValidationMetrics(Callback):
|
||||||
def __init__(self, postprocessor: PostprocessorProtocol):
|
def __init__(self, metrics: List[MetricsProtocol]):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.postprocessor = postprocessor
|
|
||||||
self.predictions = []
|
if len(metrics) == 0:
|
||||||
|
raise ValueError("At least one metric needs to be provided")
|
||||||
|
|
||||||
|
self.matches: List[Match] = []
|
||||||
|
self.metrics = metrics
|
||||||
|
|
||||||
|
def on_validation_epoch_end(
|
||||||
|
self,
|
||||||
|
trainer: Trainer,
|
||||||
|
pl_module: LightningModule,
|
||||||
|
) -> None:
|
||||||
|
for metric in self.metrics:
|
||||||
|
value = metric(self.matches)
|
||||||
|
pl_module.log(f"val/metric/{metric.name}", value, prog_bar=True)
|
||||||
|
|
||||||
|
return super().on_validation_epoch_end(trainer, pl_module)
|
||||||
|
|
||||||
def on_validation_epoch_start(
|
def on_validation_epoch_start(
|
||||||
self,
|
self,
|
||||||
trainer: Trainer,
|
trainer: Trainer,
|
||||||
pl_module: LightningModule,
|
pl_module: LightningModule,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.predictions = []
|
self.matches = []
|
||||||
return super().on_validation_epoch_start(trainer, pl_module)
|
return super().on_validation_epoch_start(trainer, pl_module)
|
||||||
|
|
||||||
def on_validation_batch_end( # type: ignore
|
def on_validation_batch_end( # type: ignore
|
||||||
self,
|
self,
|
||||||
trainer: Trainer,
|
trainer: Trainer,
|
||||||
pl_module: LightningModule,
|
pl_module: TrainingModule,
|
||||||
outputs: ModelOutput,
|
outputs: ModelOutput,
|
||||||
batch: TrainExample,
|
batch: TrainExample,
|
||||||
batch_idx: int,
|
batch_idx: int,
|
||||||
@ -32,24 +53,73 @@ class ValidationMetrics(Callback):
|
|||||||
) -> None:
|
) -> None:
|
||||||
dataloaders = trainer.val_dataloaders
|
dataloaders = trainer.val_dataloaders
|
||||||
assert isinstance(dataloaders, DataLoader)
|
assert isinstance(dataloaders, DataLoader)
|
||||||
|
|
||||||
dataset = dataloaders.dataset
|
dataset = dataloaders.dataset
|
||||||
assert isinstance(dataset, LabeledDataset)
|
assert isinstance(dataset, LabeledDataset)
|
||||||
clip_annotation = dataset.get_clip_annotation(batch_idx)
|
|
||||||
|
|
||||||
# clip_prediction = postprocess_model_outputs(
|
clip_annotations = [
|
||||||
# outputs,
|
_get_subclip(
|
||||||
# clips=[clip_annotation.clip],
|
dataset.get_clip_annotation(example_id),
|
||||||
# classes=self.class_names,
|
start_time=start_time.item(),
|
||||||
# decoder=self.decoder,
|
end_time=end_time.item(),
|
||||||
# config=self.config.postprocessing,
|
targets=pl_module.targets,
|
||||||
# )[0]
|
)
|
||||||
#
|
for example_id, start_time, end_time in zip(
|
||||||
# matches = match_predictions_and_annotations(
|
batch.idx,
|
||||||
# clip_annotation,
|
batch.start_time,
|
||||||
# clip_prediction,
|
batch.end_time,
|
||||||
# )
|
)
|
||||||
#
|
]
|
||||||
# self.validation_predictions.extend(matches)
|
|
||||||
# return super().on_validation_batch_end(
|
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
|
||||||
# trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
|
|
||||||
# )
|
raw_predictions = pl_module.postprocessor.get_raw_predictions(
|
||||||
|
outputs,
|
||||||
|
clips,
|
||||||
|
)
|
||||||
|
|
||||||
|
for clip_annotation, clip_predictions in zip(
|
||||||
|
clip_annotations, raw_predictions
|
||||||
|
):
|
||||||
|
self.matches.extend(
|
||||||
|
match_sound_events_and_raw_predictions(
|
||||||
|
sound_events=clip_annotation.sound_events,
|
||||||
|
raw_predictions=clip_predictions,
|
||||||
|
targets=pl_module.targets,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_in_subclip(
    sound_event_annotation: data.SoundEventAnnotation,
    targets: TargetProtocol,
    start_time: float,
    end_time: float,
) -> bool:
    """Check whether an annotation's reference time lies in a time window.

    The reference position is obtained from ``targets.get_position``;
    only its time component is used. Both window bounds are inclusive.
    """
    event_time, _frequency = targets.get_position(sound_event_annotation)
    return start_time <= event_time and event_time <= end_time
|
||||||
|
|
||||||
|
|
||||||
|
def _get_subclip(
    clip_annotation: data.ClipAnnotation,
    start_time: float,
    end_time: float,
    targets: TargetProtocol,
) -> data.ClipAnnotation:
    """Build a clip annotation restricted to a time window.

    The returned annotation refers to a new clip on the same recording
    spanning ``start_time`` to ``end_time``, and keeps only the sound
    events for which ``_is_in_subclip`` is true.
    """
    subclip = data.Clip(
        recording=clip_annotation.clip.recording,
        start_time=start_time,
        end_time=end_time,
    )

    events_in_window = []
    for annotation in clip_annotation.sound_events:
        inside = _is_in_subclip(
            annotation,
            targets,
            start_time=start_time,
            end_time=end_time,
        )
        if inside:
            events_in_window.append(annotation)

    return data.ClipAnnotation(clip=subclip, sound_events=events_in_window)
|
||||||
|
@ -42,8 +42,8 @@ class LabeledDataset(Dataset):
|
|||||||
class_heatmap=self.to_tensor(dataset["class"]),
|
class_heatmap=self.to_tensor(dataset["class"]),
|
||||||
size_heatmap=self.to_tensor(dataset["size"]),
|
size_heatmap=self.to_tensor(dataset["size"]),
|
||||||
idx=torch.tensor(idx),
|
idx=torch.tensor(idx),
|
||||||
start_time=start_time,
|
start_time=torch.tensor(start_time),
|
||||||
end_time=end_time,
|
end_time=torch.tensor(end_time),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -23,13 +23,13 @@ parameter specific to this module is the Gaussian smoothing sigma (`sigma`)
|
|||||||
defined in `LabelConfig`.
|
defined in `LabelConfig`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
|
from loguru import logger
|
||||||
from scipy.ndimage import gaussian_filter
|
from scipy.ndimage import gaussian_filter
|
||||||
from soundevent import arrays, data
|
from soundevent import arrays, data
|
||||||
|
|
||||||
@ -52,8 +52,6 @@ __all__ = [
|
|||||||
SIZE_DIMENSION = "dimension"
|
SIZE_DIMENSION = "dimension"
|
||||||
"""Dimension name for the size heatmap."""
|
"""Dimension name for the size heatmap."""
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LabelConfig(BaseConfig):
|
class LabelConfig(BaseConfig):
|
||||||
"""Configuration parameters for heatmap generation.
|
"""Configuration parameters for heatmap generation.
|
||||||
@ -137,12 +135,27 @@ def generate_clip_label(
|
|||||||
A NamedTuple containing the generated 'detection', 'classes', and 'size'
|
A NamedTuple containing the generated 'detection', 'classes', and 'size'
|
||||||
heatmaps for this clip.
|
heatmaps for this clip.
|
||||||
"""
|
"""
|
||||||
|
logger.debug(
|
||||||
|
"Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events",
|
||||||
|
uuid=clip_annotation.uuid,
|
||||||
|
num=len(clip_annotation.sound_events)
|
||||||
|
)
|
||||||
|
|
||||||
|
sound_events = []
|
||||||
|
|
||||||
|
for sound_event_annotation in clip_annotation.sound_events:
|
||||||
|
if not targets.filter(sound_event_annotation):
|
||||||
|
logger.debug(
|
||||||
|
"Sound event {sound_event} did not pass the filter. Tags: {tags}",
|
||||||
|
sound_event=sound_event_annotation,
|
||||||
|
tags=sound_event_annotation.tags,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
sound_events.append(targets.transform(sound_event_annotation))
|
||||||
|
|
||||||
return generate_heatmaps(
|
return generate_heatmaps(
|
||||||
(
|
sound_events,
|
||||||
targets.transform(sound_event_annotation)
|
|
||||||
for sound_event_annotation in clip_annotation.sound_events
|
|
||||||
if targets.filter(sound_event_annotation)
|
|
||||||
),
|
|
||||||
spec=spec,
|
spec=spec,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
target_sigma=config.sigma,
|
target_sigma=config.sigma,
|
||||||
|
@ -58,7 +58,9 @@ class TrainingModule(L.LightningModule):
|
|||||||
|
|
||||||
return losses.total
|
return losses.total
|
||||||
|
|
||||||
def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
|
def validation_step( # type: ignore
|
||||||
|
self, batch: TrainExample, batch_idx: int
|
||||||
|
) -> ModelOutput:
|
||||||
outputs = self.forward(batch.spec)
|
outputs = self.forward(batch.spec)
|
||||||
losses = self.loss(outputs, batch)
|
losses = self.loss(outputs, batch)
|
||||||
|
|
||||||
@ -67,6 +69,8 @@ class TrainingModule(L.LightningModule):
|
|||||||
self.log("val/loss/size", losses.total, logger=True)
|
self.log("val/loss/size", losses.total, logger=True)
|
||||||
self.log("val/loss/classification", losses.total, logger=True)
|
self.log("val/loss/classification", losses.total, logger=True)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = Adam(self.parameters(), lr=self.learning_rate)
|
optimizer = Adam(self.parameters(), lr=self.learning_rate)
|
||||||
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
|
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
|
||||||
|
@ -1,12 +1,16 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from lightning import Trainer
|
from lightning import Trainer
|
||||||
|
from lightning.pytorch.callbacks import Callback
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.models.types import DetectionModel
|
from batdetect2.models.types import DetectionModel
|
||||||
|
from batdetect2.postprocess import build_postprocessor
|
||||||
from batdetect2.postprocess.types import PostprocessorProtocol
|
from batdetect2.postprocess.types import PostprocessorProtocol
|
||||||
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
build_augmentations,
|
build_augmentations,
|
||||||
@ -19,25 +23,40 @@ from batdetect2.train.losses import build_loss
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"train",
|
"train",
|
||||||
|
"build_val_dataset",
|
||||||
|
"build_train_dataset",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
detector: DetectionModel,
|
detector: DetectionModel,
|
||||||
targets: TargetProtocol,
|
|
||||||
preprocessor: PreprocessorProtocol,
|
|
||||||
postprocessor: PostprocessorProtocol,
|
|
||||||
train_examples: List[data.PathLike],
|
train_examples: List[data.PathLike],
|
||||||
|
targets: Optional[TargetProtocol] = None,
|
||||||
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
|
postprocessor: Optional[PostprocessorProtocol] = None,
|
||||||
val_examples: Optional[List[data.PathLike]] = None,
|
val_examples: Optional[List[data.PathLike]] = None,
|
||||||
config: Optional[TrainingConfig] = None,
|
config: Optional[TrainingConfig] = None,
|
||||||
|
callbacks: Optional[List[Callback]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
config = config or TrainingConfig()
|
config = config or TrainingConfig()
|
||||||
|
|
||||||
train_dataset = build_dataset(
|
if preprocessor is None:
|
||||||
|
preprocessor = build_preprocessor()
|
||||||
|
|
||||||
|
if targets is None:
|
||||||
|
targets = build_targets()
|
||||||
|
|
||||||
|
if postprocessor is None:
|
||||||
|
postprocessor = build_postprocessor(
|
||||||
|
targets,
|
||||||
|
min_freq=preprocessor.min_freq,
|
||||||
|
max_freq=preprocessor.max_freq,
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataset = build_train_dataset(
|
||||||
train_examples,
|
train_examples,
|
||||||
preprocessor,
|
preprocessor,
|
||||||
config=config,
|
config=config,
|
||||||
train=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = build_loss(config.loss)
|
loss = build_loss(config.loss)
|
||||||
@ -52,7 +71,13 @@ def train(
|
|||||||
t_max=config.optimizer.t_max,
|
t_max=config.optimizer.t_max,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(**config.trainer.model_dump())
|
trainer = Trainer(
|
||||||
|
**config.trainer.model_dump(exclude_none=True),
|
||||||
|
callbacks=callbacks,
|
||||||
|
num_sanity_val_steps=0,
|
||||||
|
# enable_model_summary=False,
|
||||||
|
# enable_progress_bar=False,
|
||||||
|
)
|
||||||
|
|
||||||
train_dataloader = DataLoader(
|
train_dataloader = DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
@ -62,11 +87,9 @@ def train(
|
|||||||
|
|
||||||
val_dataloader = None
|
val_dataloader = None
|
||||||
if val_examples:
|
if val_examples:
|
||||||
val_dataset = build_dataset(
|
val_dataset = build_val_dataset(
|
||||||
val_examples,
|
val_examples,
|
||||||
preprocessor,
|
|
||||||
config=config,
|
config=config,
|
||||||
train=False,
|
|
||||||
)
|
)
|
||||||
val_dataloader = DataLoader(
|
val_dataloader = DataLoader(
|
||||||
val_dataset,
|
val_dataset,
|
||||||
@ -81,32 +104,38 @@ def train(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_dataset(
|
def build_train_dataset(
|
||||||
examples: List[data.PathLike],
|
examples: List[data.PathLike],
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
config: Optional[TrainingConfig] = None,
|
config: Optional[TrainingConfig] = None,
|
||||||
train: bool = True,
|
) -> LabeledDataset:
|
||||||
):
|
|
||||||
config = config or TrainingConfig()
|
config = config or TrainingConfig()
|
||||||
|
|
||||||
clipper = build_clipper(config.cliping, random=train)
|
clipper = build_clipper(config.cliping, random=True)
|
||||||
|
|
||||||
augmentations = None
|
random_example_source = RandomExampleSource(
|
||||||
|
examples,
|
||||||
|
clipper=clipper,
|
||||||
|
)
|
||||||
|
|
||||||
if train:
|
augmentations = build_augmentations(
|
||||||
random_example_source = RandomExampleSource(
|
preprocessor,
|
||||||
examples,
|
config=config.augmentations,
|
||||||
clipper=clipper,
|
example_source=random_example_source,
|
||||||
)
|
)
|
||||||
|
|
||||||
augmentations = build_augmentations(
|
|
||||||
preprocessor,
|
|
||||||
config=config.augmentations,
|
|
||||||
example_source=random_example_source,
|
|
||||||
)
|
|
||||||
|
|
||||||
return LabeledDataset(
|
return LabeledDataset(
|
||||||
examples,
|
examples,
|
||||||
clipper=clipper,
|
clipper=clipper,
|
||||||
augmentation=augmentations,
|
augmentation=augmentations,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_val_dataset(
|
||||||
|
examples: List[data.PathLike],
|
||||||
|
config: Optional[TrainingConfig] = None,
|
||||||
|
train: bool = True,
|
||||||
|
) -> LabeledDataset:
|
||||||
|
config = config or TrainingConfig()
|
||||||
|
clipper = build_clipper(config.cliping, random=train)
|
||||||
|
return LabeledDataset(examples, clipper=clipper)
|
||||||
|
@ -57,8 +57,8 @@ class TrainExample(NamedTuple):
|
|||||||
class_heatmap: torch.Tensor
|
class_heatmap: torch.Tensor
|
||||||
size_heatmap: torch.Tensor
|
size_heatmap: torch.Tensor
|
||||||
idx: torch.Tensor
|
idx: torch.Tensor
|
||||||
start_time: float
|
start_time: torch.Tensor
|
||||||
end_time: float
|
end_time: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
class Losses(NamedTuple):
|
class Losses(NamedTuple):
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
import xarray as xr
|
||||||
|
|
||||||
|
|
||||||
def extend_width(
|
def extend_width(
|
||||||
@ -59,3 +60,10 @@ def adjust_width(
|
|||||||
for index in range(dims)
|
for index in range(dims)
|
||||||
]
|
]
|
||||||
return array[tuple(slices)]
|
return array[tuple(slices)]
|
||||||
|
|
||||||
|
|
||||||
|
def iterate_over_array(array: xr.DataArray):
|
||||||
|
dim_name = array.dims[0]
|
||||||
|
coords = array.coords[dim_name]
|
||||||
|
for value, coord in zip(array.values, coords.values):
|
||||||
|
yield coord, float(value)
|
||||||
|
Loading…
Reference in New Issue
Block a user