mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Run formatter
This commit is contained in:
parent
531ff69974
commit
0adb58e039
@ -32,11 +32,11 @@ from batdetect2.train import (
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
BatDetect2Prediction,
|
||||
ClipDetections,
|
||||
EvaluatorProtocol,
|
||||
PostprocessorProtocol,
|
||||
PreprocessorProtocol,
|
||||
RawPrediction,
|
||||
Detection,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
@ -110,7 +110,7 @@ class BatDetect2API:
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
save_predictions: bool = True,
|
||||
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
|
||||
) -> Tuple[Dict[str, float], List[List[Detection]]]:
|
||||
return evaluate(
|
||||
self.model,
|
||||
test_annotations,
|
||||
@ -128,7 +128,7 @@ class BatDetect2API:
|
||||
def evaluate_predictions(
|
||||
self,
|
||||
annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
predictions: Sequence[ClipDetections],
|
||||
output_dir: data.PathLike | None = None,
|
||||
):
|
||||
clip_evals = self.evaluator.evaluate(
|
||||
@ -170,24 +170,24 @@ class BatDetect2API:
|
||||
tensor = torch.tensor(audio).unsqueeze(0)
|
||||
return self.preprocessor(tensor)
|
||||
|
||||
def process_file(self, audio_file: str) -> BatDetect2Prediction:
|
||||
def process_file(self, audio_file: str) -> ClipDetections:
|
||||
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
||||
wav = self.audio_loader.load_recording(recording)
|
||||
detections = self.process_audio(wav)
|
||||
return BatDetect2Prediction(
|
||||
return ClipDetections(
|
||||
clip=data.Clip(
|
||||
uuid=recording.uuid,
|
||||
recording=recording,
|
||||
start_time=0,
|
||||
end_time=recording.duration,
|
||||
),
|
||||
predictions=detections,
|
||||
detections=detections,
|
||||
)
|
||||
|
||||
def process_audio(
|
||||
self,
|
||||
audio: np.ndarray,
|
||||
) -> List[RawPrediction]:
|
||||
) -> List[Detection]:
|
||||
spec = self.generate_spectrogram(audio)
|
||||
return self.process_spectrogram(spec)
|
||||
|
||||
@ -195,7 +195,7 @@ class BatDetect2API:
|
||||
self,
|
||||
spec: torch.Tensor,
|
||||
start_time: float = 0,
|
||||
) -> List[RawPrediction]:
|
||||
) -> List[Detection]:
|
||||
if spec.ndim == 4 and spec.shape[0] > 1:
|
||||
raise ValueError("Batched spectrograms not supported.")
|
||||
|
||||
@ -214,7 +214,7 @@ class BatDetect2API:
|
||||
def process_directory(
|
||||
self,
|
||||
audio_dir: data.PathLike,
|
||||
) -> List[BatDetect2Prediction]:
|
||||
) -> List[ClipDetections]:
|
||||
files = list(get_audio_files(audio_dir))
|
||||
return self.process_files(files)
|
||||
|
||||
@ -222,7 +222,7 @@ class BatDetect2API:
|
||||
self,
|
||||
audio_files: Sequence[data.PathLike],
|
||||
num_workers: int | None = None,
|
||||
) -> List[BatDetect2Prediction]:
|
||||
) -> List[ClipDetections]:
|
||||
return process_file_list(
|
||||
self.model,
|
||||
audio_files,
|
||||
@ -238,7 +238,7 @@ class BatDetect2API:
|
||||
clips: Sequence[data.Clip],
|
||||
batch_size: int | None = None,
|
||||
num_workers: int | None = None,
|
||||
) -> List[BatDetect2Prediction]:
|
||||
) -> List[ClipDetections]:
|
||||
return run_batch_inference(
|
||||
self.model,
|
||||
clips,
|
||||
@ -252,7 +252,7 @@ class BatDetect2API:
|
||||
|
||||
def save_predictions(
|
||||
self,
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
predictions: Sequence[ClipDetections],
|
||||
path: data.PathLike,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
format: str | None = None,
|
||||
@ -274,7 +274,7 @@ class BatDetect2API:
|
||||
def load_predictions(
|
||||
self,
|
||||
path: data.PathLike,
|
||||
) -> List[BatDetect2Prediction]:
|
||||
) -> List[ClipDetections]:
|
||||
return self.formatter.load(path)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import DTypeLike
|
||||
from pydantic import Field
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import xarray as xr
|
||||
|
||||
@ -264,7 +264,14 @@ class Not:
|
||||
|
||||
|
||||
SoundEventConditionConfig = Annotated[
|
||||
HasTagConfig | HasAllTagsConfig | HasAnyTagConfig | DurationConfig | FrequencyConfig | AllOfConfig | AnyOfConfig | NotConfig,
|
||||
HasTagConfig
|
||||
| HasAllTagsConfig
|
||||
| HasAnyTagConfig
|
||||
| DurationConfig
|
||||
| FrequencyConfig
|
||||
| AllOfConfig
|
||||
| AnyOfConfig
|
||||
| NotConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -13,9 +13,9 @@ from batdetect2.data.predictions.base import (
|
||||
)
|
||||
from batdetect2.targets import terms
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
ClipDetections,
|
||||
OutputFormatterProtocol,
|
||||
RawPrediction,
|
||||
Detection,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
@ -113,7 +113,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
||||
self.annotation_note = annotation_note
|
||||
|
||||
def format(
|
||||
self, predictions: Sequence[BatDetect2Prediction]
|
||||
self, predictions: Sequence[ClipDetections]
|
||||
) -> List[FileAnnotation]:
|
||||
return [
|
||||
self.format_prediction(prediction) for prediction in predictions
|
||||
@ -164,14 +164,12 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
||||
highest_scoring = max(annotations, key=lambda x: x["class_prob"])
|
||||
return highest_scoring["class"]
|
||||
|
||||
def format_prediction(
|
||||
self, prediction: BatDetect2Prediction
|
||||
) -> FileAnnotation:
|
||||
def format_prediction(self, prediction: ClipDetections) -> FileAnnotation:
|
||||
recording = prediction.clip.recording
|
||||
|
||||
annotations = [
|
||||
self.format_sound_event_prediction(pred)
|
||||
for pred in prediction.predictions
|
||||
for pred in prediction.detections
|
||||
]
|
||||
|
||||
return FileAnnotation(
|
||||
@ -196,7 +194,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
||||
) # type: ignore
|
||||
|
||||
def format_sound_event_prediction(
|
||||
self, prediction: RawPrediction
|
||||
self, prediction: Detection
|
||||
) -> Annotation:
|
||||
start_time, low_freq, end_time, high_freq = compute_bounds(
|
||||
prediction.geometry
|
||||
|
||||
@ -14,9 +14,9 @@ from batdetect2.data.predictions.base import (
|
||||
prediction_formatters,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
ClipDetections,
|
||||
OutputFormatterProtocol,
|
||||
RawPrediction,
|
||||
Detection,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
@ -29,7 +29,7 @@ class ParquetOutputConfig(BaseConfig):
|
||||
include_geometry: bool = True
|
||||
|
||||
|
||||
class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
|
||||
def __init__(
|
||||
self,
|
||||
targets: TargetProtocol,
|
||||
@ -44,13 +44,13 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
|
||||
def format(
|
||||
self,
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
) -> List[BatDetect2Prediction]:
|
||||
predictions: Sequence[ClipDetections],
|
||||
) -> List[ClipDetections]:
|
||||
return list(predictions)
|
||||
|
||||
def save(
|
||||
self,
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
predictions: Sequence[ClipDetections],
|
||||
path: data.PathLike,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> None:
|
||||
@ -61,10 +61,10 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
|
||||
# Ensure the file has .parquet extension if it's a file path
|
||||
if path.suffix != ".parquet":
|
||||
# If it's a directory, we might want to save as a partitioned dataset or a single file inside
|
||||
# For now, let's assume the user provides a full file path or a directory where we save 'predictions.parquet'
|
||||
if path.is_dir() or not path.suffix:
|
||||
path = path / "predictions.parquet"
|
||||
# If it's a directory, we might want to save as a partitioned dataset or a single file inside
|
||||
# For now, let's assume the user provides a full file path or a directory where we save 'predictions.parquet'
|
||||
if path.is_dir() or not path.suffix:
|
||||
path = path / "predictions.parquet"
|
||||
|
||||
rows = []
|
||||
for prediction in predictions:
|
||||
@ -73,12 +73,14 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
|
||||
if audio_dir is not None:
|
||||
recording = recording.model_copy(
|
||||
update=dict(path=make_path_relative(recording.path, audio_dir))
|
||||
update=dict(
|
||||
path=make_path_relative(recording.path, audio_dir)
|
||||
)
|
||||
)
|
||||
|
||||
recording_json = recording.model_dump_json(exclude_none=True)
|
||||
|
||||
for pred in prediction.predictions:
|
||||
for pred in prediction.detections:
|
||||
row = {
|
||||
"clip_uuid": str(clip.uuid),
|
||||
"clip_start_time": clip.start_time,
|
||||
@ -116,16 +118,16 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
logger.info(f"Saving {len(df)} predictions to {path}")
|
||||
df.to_parquet(path, index=False)
|
||||
|
||||
def load(self, path: data.PathLike) -> List[BatDetect2Prediction]:
|
||||
def load(self, path: data.PathLike) -> List[ClipDetections]:
|
||||
path = Path(path)
|
||||
if path.is_dir():
|
||||
# Try to find parquet files
|
||||
files = list(path.glob("*.parquet"))
|
||||
if not files:
|
||||
return []
|
||||
# Read all and concatenate
|
||||
dfs = [pd.read_parquet(f) for f in files]
|
||||
df = pd.concat(dfs, ignore_index=True)
|
||||
# Try to find parquet files
|
||||
files = list(path.glob("*.parquet"))
|
||||
if not files:
|
||||
return []
|
||||
# Read all and concatenate
|
||||
dfs = [pd.read_parquet(f) for f in files]
|
||||
df = pd.concat(dfs, ignore_index=True)
|
||||
else:
|
||||
df = pd.read_parquet(path)
|
||||
|
||||
@ -135,17 +137,16 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
clip_uuid = row["clip_uuid"]
|
||||
|
||||
if clip_uuid not in predictions_by_clip:
|
||||
recording = data.Recording.model_validate_json(row["recording_info"])
|
||||
recording = data.Recording.model_validate_json(
|
||||
row["recording_info"]
|
||||
)
|
||||
clip = data.Clip(
|
||||
uuid=UUID(clip_uuid),
|
||||
recording=recording,
|
||||
start_time=row["clip_start_time"],
|
||||
end_time=row["clip_end_time"],
|
||||
)
|
||||
predictions_by_clip[clip_uuid] = {
|
||||
"clip": clip,
|
||||
"preds": []
|
||||
}
|
||||
predictions_by_clip[clip_uuid] = {"clip": clip, "preds": []}
|
||||
|
||||
# Reconstruct geometry
|
||||
if "geometry" in row and row["geometry"]:
|
||||
@ -156,14 +157,20 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
row["start_time"],
|
||||
row["low_freq"],
|
||||
row["end_time"],
|
||||
row["high_freq"]
|
||||
row["high_freq"],
|
||||
]
|
||||
)
|
||||
|
||||
class_scores = np.array(row["class_scores"]) if "class_scores" in row else np.zeros(len(self.targets.class_names))
|
||||
features = np.array(row["features"]) if "features" in row else np.zeros(0)
|
||||
class_scores = (
|
||||
np.array(row["class_scores"])
|
||||
if "class_scores" in row
|
||||
else np.zeros(len(self.targets.class_names))
|
||||
)
|
||||
features = (
|
||||
np.array(row["features"]) if "features" in row else np.zeros(0)
|
||||
)
|
||||
|
||||
pred = RawPrediction(
|
||||
pred = Detection(
|
||||
geometry=geometry,
|
||||
detection_score=row["detection_score"],
|
||||
class_scores=class_scores,
|
||||
@ -174,9 +181,8 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
results = []
|
||||
for clip_data in predictions_by_clip.values():
|
||||
results.append(
|
||||
BatDetect2Prediction(
|
||||
clip=clip_data["clip"],
|
||||
predictions=clip_data["preds"]
|
||||
ClipDetections(
|
||||
clip=clip_data["clip"], detections=clip_data["preds"]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -15,9 +15,9 @@ from batdetect2.data.predictions.base import (
|
||||
prediction_formatters,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
ClipDetections,
|
||||
OutputFormatterProtocol,
|
||||
RawPrediction,
|
||||
Detection,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
@ -30,7 +30,7 @@ class RawOutputConfig(BaseConfig):
|
||||
include_geometry: bool = True
|
||||
|
||||
|
||||
class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
class RawFormatter(OutputFormatterProtocol[ClipDetections]):
|
||||
def __init__(
|
||||
self,
|
||||
targets: TargetProtocol,
|
||||
@ -47,13 +47,13 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
|
||||
def format(
|
||||
self,
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
) -> List[BatDetect2Prediction]:
|
||||
predictions: Sequence[ClipDetections],
|
||||
) -> List[ClipDetections]:
|
||||
return list(predictions)
|
||||
|
||||
def save(
|
||||
self,
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
predictions: Sequence[ClipDetections],
|
||||
path: data.PathLike,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> None:
|
||||
@ -68,10 +68,10 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
dataset = self.pred_to_xr(prediction, audio_dir)
|
||||
dataset.to_netcdf(path / f"{clip.uuid}.nc")
|
||||
|
||||
def load(self, path: data.PathLike) -> List[BatDetect2Prediction]:
|
||||
def load(self, path: data.PathLike) -> List[ClipDetections]:
|
||||
path = Path(path)
|
||||
files = list(path.glob("*.nc"))
|
||||
predictions: List[BatDetect2Prediction] = []
|
||||
predictions: List[ClipDetections] = []
|
||||
|
||||
for filepath in files:
|
||||
logger.debug(f"Loading clip predictions {filepath}")
|
||||
@ -83,7 +83,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
|
||||
def pred_to_xr(
|
||||
self,
|
||||
prediction: BatDetect2Prediction,
|
||||
prediction: ClipDetections,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> xr.Dataset:
|
||||
clip = prediction.clip
|
||||
@ -97,7 +97,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
|
||||
data = defaultdict(list)
|
||||
|
||||
for pred in prediction.predictions:
|
||||
for pred in prediction.detections:
|
||||
detection_id = str(uuid4())
|
||||
|
||||
data["detection_id"].append(detection_id)
|
||||
@ -167,7 +167,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
},
|
||||
)
|
||||
|
||||
def pred_from_xr(self, dataset: xr.Dataset) -> BatDetect2Prediction:
|
||||
def pred_from_xr(self, dataset: xr.Dataset) -> ClipDetections:
|
||||
clip_data = dataset
|
||||
clip_id = clip_data.clip_id.item()
|
||||
|
||||
@ -219,7 +219,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
features = np.zeros(0)
|
||||
|
||||
sound_events.append(
|
||||
RawPrediction(
|
||||
Detection(
|
||||
geometry=geometry,
|
||||
detection_score=score,
|
||||
class_scores=class_scores,
|
||||
@ -227,9 +227,9 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
)
|
||||
)
|
||||
|
||||
return BatDetect2Prediction(
|
||||
return ClipDetections(
|
||||
clip=clip,
|
||||
predictions=sound_events,
|
||||
detections=sound_events,
|
||||
)
|
||||
|
||||
@prediction_formatters.register(RawOutputConfig)
|
||||
|
||||
@ -9,9 +9,9 @@ from batdetect2.data.predictions.base import (
|
||||
prediction_formatters,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
ClipDetections,
|
||||
OutputFormatterProtocol,
|
||||
RawPrediction,
|
||||
Detection,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
@ -35,7 +35,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
|
||||
|
||||
def format(
|
||||
self,
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
predictions: Sequence[ClipDetections],
|
||||
) -> List[data.ClipPrediction]:
|
||||
return [
|
||||
self.format_prediction(prediction) for prediction in predictions
|
||||
@ -63,20 +63,20 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
|
||||
|
||||
def format_prediction(
|
||||
self,
|
||||
prediction: BatDetect2Prediction,
|
||||
prediction: ClipDetections,
|
||||
) -> data.ClipPrediction:
|
||||
recording = prediction.clip.recording
|
||||
return data.ClipPrediction(
|
||||
clip=prediction.clip,
|
||||
sound_events=[
|
||||
self.format_sound_event_prediction(pred, recording)
|
||||
for pred in prediction.predictions
|
||||
for pred in prediction.detections
|
||||
],
|
||||
)
|
||||
|
||||
def format_sound_event_prediction(
|
||||
self,
|
||||
prediction: RawPrediction,
|
||||
prediction: Detection,
|
||||
recording: data.Recording,
|
||||
) -> data.SoundEventPrediction:
|
||||
return data.SoundEventPrediction(
|
||||
@ -89,7 +89,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
|
||||
)
|
||||
|
||||
def get_sound_event_tags(
|
||||
self, prediction: RawPrediction
|
||||
self, prediction: Detection
|
||||
) -> List[data.PredictedTag]:
|
||||
sorted_indices = np.argsort(prediction.class_scores)[::-1]
|
||||
|
||||
|
||||
@ -221,7 +221,11 @@ class ApplyAll:
|
||||
|
||||
|
||||
SoundEventTransformConfig = Annotated[
|
||||
SetFrequencyBoundConfig | ReplaceTagConfig | MapTagValueConfig | ApplyIfConfig | ApplyAllConfig,
|
||||
SetFrequencyBoundConfig
|
||||
| ReplaceTagConfig
|
||||
| MapTagValueConfig
|
||||
| ApplyIfConfig
|
||||
| ApplyAllConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ from soundevent.geometry import (
|
||||
)
|
||||
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.typing import AffinityFunction, RawPrediction
|
||||
from batdetect2.typing import AffinityFunction, Detection
|
||||
|
||||
affinity_functions: Registry[AffinityFunction, []] = Registry(
|
||||
"affinity_function"
|
||||
@ -42,7 +42,7 @@ class TimeAffinity(AffinityFunction):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
detection: RawPrediction,
|
||||
detection: Detection,
|
||||
ground_truth: data.SoundEventAnnotation,
|
||||
) -> float:
|
||||
target_geometry = ground_truth.sound_event.geometry
|
||||
@ -77,7 +77,7 @@ class IntervalIOU(AffinityFunction):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
detection: RawPrediction,
|
||||
detection: Detection,
|
||||
ground_truth: data.SoundEventAnnotation,
|
||||
) -> float:
|
||||
target_geometry = ground_truth.sound_event.geometry
|
||||
@ -120,7 +120,7 @@ class BBoxIOU(AffinityFunction):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prediction: RawPrediction,
|
||||
prediction: Detection,
|
||||
gt: data.SoundEventAnnotation,
|
||||
):
|
||||
target_geometry = gt.sound_event.geometry
|
||||
@ -168,7 +168,7 @@ class GeometricIOU(AffinityFunction):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prediction: RawPrediction,
|
||||
prediction: Detection,
|
||||
gt: data.SoundEventAnnotation,
|
||||
):
|
||||
target_geometry = gt.sound_event.geometry
|
||||
|
||||
@ -12,7 +12,7 @@ from batdetect2.logging import build_logger
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.typing.postprocess import RawPrediction
|
||||
from batdetect2.typing import Detection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from batdetect2.config import BatDetect2Config
|
||||
@ -38,7 +38,7 @@ def evaluate(
|
||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
|
||||
) -> Tuple[Dict[str, float], List[List[Detection]]]:
|
||||
from batdetect2.config import BatDetect2Config
|
||||
|
||||
config = config or BatDetect2Config()
|
||||
|
||||
@ -7,7 +7,7 @@ from batdetect2.evaluate.config import EvaluationConfig
|
||||
from batdetect2.evaluate.tasks import build_task
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
from batdetect2.typing.postprocess import ClipDetections
|
||||
|
||||
__all__ = [
|
||||
"Evaluator",
|
||||
@ -27,7 +27,7 @@ class Evaluator:
|
||||
def evaluate(
|
||||
self,
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
predictions: Sequence[ClipDetections],
|
||||
) -> List[Any]:
|
||||
return [
|
||||
task.evaluate(clip_annotations, predictions) for task in self.tasks
|
||||
|
||||
@ -9,7 +9,7 @@ from batdetect2.logging import get_image_logger
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.postprocess import to_raw_predictions
|
||||
from batdetect2.typing import EvaluatorProtocol
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
from batdetect2.typing.postprocess import ClipDetections
|
||||
|
||||
|
||||
class EvaluationModule(LightningModule):
|
||||
@ -24,7 +24,7 @@ class EvaluationModule(LightningModule):
|
||||
self.evaluator = evaluator
|
||||
|
||||
self.clip_annotations: List[data.ClipAnnotation] = []
|
||||
self.predictions: List[BatDetect2Prediction] = []
|
||||
self.predictions: List[ClipDetections] = []
|
||||
|
||||
def test_step(self, batch: TestExample, batch_idx: int):
|
||||
dataset = self.get_dataset()
|
||||
@ -39,9 +39,9 @@ class EvaluationModule(LightningModule):
|
||||
start_times=[ca.clip.start_time for ca in clip_annotations],
|
||||
)
|
||||
predictions = [
|
||||
BatDetect2Prediction(
|
||||
ClipDetections(
|
||||
clip=clip_annotation.clip,
|
||||
predictions=to_raw_predictions(
|
||||
detections=to_raw_predictions(
|
||||
clip_dets.numpy(),
|
||||
targets=self.evaluator.targets,
|
||||
),
|
||||
|
||||
@ -21,7 +21,7 @@ from batdetect2.evaluate.metrics.common import (
|
||||
average_precision,
|
||||
compute_precision_recall,
|
||||
)
|
||||
from batdetect2.typing import RawPrediction, TargetProtocol
|
||||
from batdetect2.typing import Detection, TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"ClassificationMetric",
|
||||
@ -35,7 +35,7 @@ __all__ = [
|
||||
class MatchEval:
|
||||
clip: data.Clip
|
||||
gt: data.SoundEventAnnotation | None
|
||||
pred: RawPrediction | None
|
||||
pred: Detection | None
|
||||
|
||||
is_prediction: bool
|
||||
is_ground_truth: bool
|
||||
|
||||
@ -159,7 +159,10 @@ class ClipDetectionPrecision:
|
||||
|
||||
|
||||
ClipDetectionMetricConfig = Annotated[
|
||||
ClipDetectionAveragePrecisionConfig | ClipDetectionROCAUCConfig | ClipDetectionRecallConfig | ClipDetectionPrecisionConfig,
|
||||
ClipDetectionAveragePrecisionConfig
|
||||
| ClipDetectionROCAUCConfig
|
||||
| ClipDetectionRecallConfig
|
||||
| ClipDetectionPrecisionConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@ from soundevent import data
|
||||
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
from batdetect2.typing import RawPrediction
|
||||
from batdetect2.typing import Detection
|
||||
|
||||
__all__ = [
|
||||
"DetectionMetricConfig",
|
||||
@ -27,7 +27,7 @@ __all__ = [
|
||||
@dataclass
|
||||
class MatchEval:
|
||||
gt: data.SoundEventAnnotation | None
|
||||
pred: RawPrediction | None
|
||||
pred: Detection | None
|
||||
|
||||
is_prediction: bool
|
||||
is_ground_truth: bool
|
||||
|
||||
@ -15,7 +15,7 @@ from soundevent import data
|
||||
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
from batdetect2.typing import RawPrediction
|
||||
from batdetect2.typing import Detection
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
@ -29,7 +29,7 @@ __all__ = [
|
||||
class MatchEval:
|
||||
clip: data.Clip
|
||||
gt: data.SoundEventAnnotation | None
|
||||
pred: RawPrediction | None
|
||||
pred: Detection | None
|
||||
|
||||
is_ground_truth: bool
|
||||
is_generic: bool
|
||||
@ -298,7 +298,11 @@ class BalancedAccuracy:
|
||||
|
||||
|
||||
TopClassMetricConfig = Annotated[
|
||||
TopClassAveragePrecisionConfig | TopClassROCAUCConfig | TopClassRecallConfig | TopClassPrecisionConfig | BalancedAccuracyConfig,
|
||||
TopClassAveragePrecisionConfig
|
||||
| TopClassROCAUCConfig
|
||||
| TopClassRecallConfig
|
||||
| TopClassPrecisionConfig
|
||||
| BalancedAccuracyConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
|
||||
@ -322,7 +322,10 @@ class ROCCurve(BasePlot):
|
||||
|
||||
|
||||
ClassificationPlotConfig = Annotated[
|
||||
PRCurveConfig | ROCCurveConfig | ThresholdPrecisionCurveConfig | ThresholdRecallCurveConfig,
|
||||
PRCurveConfig
|
||||
| ROCCurveConfig
|
||||
| ThresholdPrecisionCurveConfig
|
||||
| ThresholdRecallCurveConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -290,7 +290,10 @@ class ExampleDetectionPlot(BasePlot):
|
||||
|
||||
|
||||
DetectionPlotConfig = Annotated[
|
||||
PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig | ExampleDetectionPlotConfig,
|
||||
PRCurveConfig
|
||||
| ROCCurveConfig
|
||||
| ScoreDistributionPlotConfig
|
||||
| ExampleDetectionPlotConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -346,7 +346,10 @@ class ExampleClassificationPlot(BasePlot):
|
||||
|
||||
|
||||
TopClassPlotConfig = Annotated[
|
||||
PRCurveConfig | ROCCurveConfig | ConfusionMatrixConfig | ExampleClassificationPlotConfig,
|
||||
PRCurveConfig
|
||||
| ROCCurveConfig
|
||||
| ConfusionMatrixConfig
|
||||
| ExampleClassificationPlotConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
@ -403,7 +406,8 @@ def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
|
||||
return matches
|
||||
|
||||
indices, pred_scores = zip(
|
||||
*[(index, match.score) for index, match in enumerate(matches)], strict=False
|
||||
*[(index, match.score) for index, match in enumerate(matches)],
|
||||
strict=False,
|
||||
)
|
||||
|
||||
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
|
||||
|
||||
@ -13,7 +13,7 @@ from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
||||
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
ClipDetections,
|
||||
EvaluatorProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
@ -45,7 +45,7 @@ def build_task(
|
||||
|
||||
def evaluate_task(
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
predictions: Sequence[ClipDetections],
|
||||
task: Optional["str"] = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
config: TaskConfig | dict | None = None,
|
||||
|
||||
@ -22,9 +22,9 @@ from batdetect2.evaluate.affinity import (
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
AffinityFunction,
|
||||
BatDetect2Prediction,
|
||||
ClipDetections,
|
||||
EvaluatorProtocol,
|
||||
RawPrediction,
|
||||
Detection,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
@ -96,7 +96,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||
def evaluate(
|
||||
self,
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
predictions: Sequence[ClipDetections],
|
||||
) -> List[T_Output]:
|
||||
return [
|
||||
self.evaluate_clip(clip_annotation, preds)
|
||||
@ -108,7 +108,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
prediction: BatDetect2Prediction,
|
||||
prediction: ClipDetections,
|
||||
) -> T_Output: ...
|
||||
|
||||
def include_sound_event_annotation(
|
||||
@ -128,7 +128,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||
|
||||
def include_prediction(
|
||||
self,
|
||||
prediction: RawPrediction,
|
||||
prediction: Detection,
|
||||
clip: data.Clip,
|
||||
) -> bool:
|
||||
return is_in_bounds(
|
||||
|
||||
@ -22,8 +22,8 @@ from batdetect2.evaluate.tasks.base import (
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
RawPrediction,
|
||||
ClipDetections,
|
||||
Detection,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
@ -51,13 +51,13 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
prediction: BatDetect2Prediction,
|
||||
prediction: ClipDetections,
|
||||
) -> ClipEval:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
preds = [
|
||||
pred
|
||||
for pred in prediction.predictions
|
||||
for pred in prediction.detections
|
||||
if self.include_prediction(pred, clip)
|
||||
]
|
||||
|
||||
@ -140,7 +140,7 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
|
||||
)
|
||||
|
||||
|
||||
def get_class_score(pred: RawPrediction, class_idx: int) -> float:
|
||||
def get_class_score(pred: Detection, class_idx: int) -> float:
|
||||
return pred.class_scores[class_idx]
|
||||
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ from batdetect2.evaluate.tasks.base import (
|
||||
BaseTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
|
||||
from batdetect2.typing import ClipDetections, TargetProtocol
|
||||
|
||||
|
||||
class ClipClassificationTaskConfig(BaseTaskConfig):
|
||||
@ -37,7 +37,7 @@ class ClipClassificationTask(BaseTask[ClipEval]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
prediction: BatDetect2Prediction,
|
||||
prediction: ClipDetections,
|
||||
) -> ClipEval:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
@ -54,7 +54,7 @@ class ClipClassificationTask(BaseTask[ClipEval]):
|
||||
gt_classes.add(class_name)
|
||||
|
||||
pred_scores = defaultdict(float)
|
||||
for pred in prediction.predictions:
|
||||
for pred in prediction.detections:
|
||||
if not self.include_prediction(pred, clip):
|
||||
continue
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ from batdetect2.evaluate.tasks.base import (
|
||||
BaseTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
|
||||
from batdetect2.typing import ClipDetections, TargetProtocol
|
||||
|
||||
|
||||
class ClipDetectionTaskConfig(BaseTaskConfig):
|
||||
@ -36,7 +36,7 @@ class ClipDetectionTask(BaseTask[ClipEval]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
prediction: BatDetect2Prediction,
|
||||
prediction: ClipDetections,
|
||||
) -> ClipEval:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
@ -46,7 +46,7 @@ class ClipDetectionTask(BaseTask[ClipEval]):
|
||||
)
|
||||
|
||||
pred_score = 0
|
||||
for pred in prediction.predictions:
|
||||
for pred in prediction.detections:
|
||||
if not self.include_prediction(pred, clip):
|
||||
continue
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ from batdetect2.evaluate.tasks.base import (
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
from batdetect2.typing.postprocess import ClipDetections
|
||||
|
||||
|
||||
class DetectionTaskConfig(BaseSEDTaskConfig):
|
||||
@ -37,7 +37,7 @@ class DetectionTask(BaseSEDTask[ClipEval]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
prediction: BatDetect2Prediction,
|
||||
prediction: ClipDetections,
|
||||
) -> ClipEval:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
@ -48,7 +48,7 @@ class DetectionTask(BaseSEDTask[ClipEval]):
|
||||
]
|
||||
preds = [
|
||||
pred
|
||||
for pred in prediction.predictions
|
||||
for pred in prediction.detections
|
||||
if self.include_prediction(pred, clip)
|
||||
]
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ from batdetect2.evaluate.tasks.base import (
|
||||
BaseSEDTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
|
||||
from batdetect2.typing import ClipDetections, TargetProtocol
|
||||
|
||||
|
||||
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
|
||||
@ -36,7 +36,7 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
prediction: BatDetect2Prediction,
|
||||
prediction: ClipDetections,
|
||||
) -> ClipEval:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
@ -47,7 +47,7 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
|
||||
]
|
||||
preds = [
|
||||
pred
|
||||
for pred in prediction.predictions
|
||||
for pred in prediction.detections
|
||||
if self.include_prediction(pred, clip)
|
||||
]
|
||||
|
||||
|
||||
@ -88,7 +88,8 @@ def select_device(warn=True) -> str:
|
||||
if warn:
|
||||
warnings.warn(
|
||||
"No GPU available, using the CPU instead. Please consider using a GPU "
|
||||
"to speed up training.", stacklevel=2
|
||||
"to speed up training.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return "cpu"
|
||||
|
||||
@ -10,7 +10,7 @@ from batdetect2.inference.lightning import InferenceModule
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.preprocess.preprocessor import build_preprocessor
|
||||
from batdetect2.targets.targets import build_targets
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
from batdetect2.typing.postprocess import ClipDetections
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from batdetect2.config import BatDetect2Config
|
||||
@ -30,7 +30,7 @@ def run_batch_inference(
|
||||
config: Optional["BatDetect2Config"] = None,
|
||||
num_workers: int | None = None,
|
||||
batch_size: int | None = None,
|
||||
) -> List[BatDetect2Prediction]:
|
||||
) -> List[ClipDetections]:
|
||||
from batdetect2.config import BatDetect2Config
|
||||
|
||||
config = config or BatDetect2Config()
|
||||
@ -70,7 +70,7 @@ def process_file_list(
|
||||
audio_loader: Optional["AudioLoader"] = None,
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
num_workers: int | None = None,
|
||||
) -> List[BatDetect2Prediction]:
|
||||
) -> List[ClipDetections]:
|
||||
clip_config = config.inference.clipping
|
||||
clips = get_clips_from_files(
|
||||
paths,
|
||||
|
||||
@ -6,7 +6,7 @@ from torch.utils.data import DataLoader
|
||||
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.postprocess import to_raw_predictions
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
from batdetect2.typing.postprocess import ClipDetections
|
||||
|
||||
|
||||
class InferenceModule(LightningModule):
|
||||
@ -19,7 +19,7 @@ class InferenceModule(LightningModule):
|
||||
batch: DatasetItem,
|
||||
batch_idx: int,
|
||||
dataloader_idx: int = 0,
|
||||
) -> Sequence[BatDetect2Prediction]:
|
||||
) -> Sequence[ClipDetections]:
|
||||
dataset = self.get_dataset()
|
||||
|
||||
clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx]
|
||||
@ -32,9 +32,9 @@ class InferenceModule(LightningModule):
|
||||
)
|
||||
|
||||
predictions = [
|
||||
BatDetect2Prediction(
|
||||
ClipDetections(
|
||||
clip=clip,
|
||||
predictions=to_raw_predictions(
|
||||
detections=to_raw_predictions(
|
||||
clip_dets.numpy(),
|
||||
targets=self.model.targets,
|
||||
),
|
||||
|
||||
@ -76,7 +76,10 @@ class MLFlowLoggerConfig(BaseLoggerConfig):
|
||||
|
||||
|
||||
LoggerConfig = Annotated[
|
||||
DVCLiveConfig | CSVLoggerConfig | TensorBoardLoggerConfig | MLFlowLoggerConfig,
|
||||
DVCLiveConfig
|
||||
| CSVLoggerConfig
|
||||
| TensorBoardLoggerConfig
|
||||
| MLFlowLoggerConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
|
||||
@ -41,7 +41,10 @@ __all__ = [
|
||||
]
|
||||
|
||||
DecoderLayerConfig = Annotated[
|
||||
ConvConfig | FreqCoordConvUpConfig | StandardConvUpConfig | LayerGroupConfig,
|
||||
ConvConfig
|
||||
| FreqCoordConvUpConfig
|
||||
| StandardConvUpConfig
|
||||
| LayerGroupConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
||||
|
||||
@ -14,7 +14,6 @@ logic for preprocessing inputs and postprocessing/decoding outputs resides in
|
||||
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@ -43,7 +43,10 @@ __all__ = [
|
||||
]
|
||||
|
||||
EncoderLayerConfig = Annotated[
|
||||
ConvConfig | FreqCoordConvDownConfig | StandardConvDownConfig | LayerGroupConfig,
|
||||
ConvConfig
|
||||
| FreqCoordConvDownConfig
|
||||
| StandardConvDownConfig
|
||||
| LayerGroupConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configs usable in Encoder."""
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
from matplotlib import axes, patches
|
||||
from soundevent.plot import plot_geometry
|
||||
|
||||
|
||||
@ -36,7 +36,9 @@ def plot_match_gallery(
|
||||
sharey="row",
|
||||
)
|
||||
|
||||
for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples], strict=False):
|
||||
for tp_ax, tp_match in zip(
|
||||
axes[0], true_positives[:n_examples], strict=False
|
||||
):
|
||||
try:
|
||||
plot_true_positive_match(
|
||||
tp_match,
|
||||
@ -53,7 +55,9 @@ def plot_match_gallery(
|
||||
):
|
||||
continue
|
||||
|
||||
for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples], strict=False):
|
||||
for fp_ax, fp_match in zip(
|
||||
axes[1], false_positives[:n_examples], strict=False
|
||||
):
|
||||
try:
|
||||
plot_false_positive_match(
|
||||
fp_match,
|
||||
@ -70,7 +74,9 @@ def plot_match_gallery(
|
||||
):
|
||||
continue
|
||||
|
||||
for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples], strict=False):
|
||||
for fn_ax, fn_match in zip(
|
||||
axes[2], false_negatives[:n_examples], strict=False
|
||||
):
|
||||
try:
|
||||
plot_false_negative_match(
|
||||
fn_match,
|
||||
@ -87,7 +93,9 @@ def plot_match_gallery(
|
||||
):
|
||||
continue
|
||||
|
||||
for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples], strict=False):
|
||||
for ct_ax, ct_match in zip(
|
||||
axes[3], cross_triggers[:n_examples], strict=False
|
||||
):
|
||||
try:
|
||||
plot_cross_trigger_match(
|
||||
ct_match,
|
||||
|
||||
@ -8,7 +8,7 @@ from batdetect2.plotting.clips import plot_clip
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
PreprocessorProtocol,
|
||||
RawPrediction,
|
||||
Detection,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -22,7 +22,7 @@ __all__ = [
|
||||
class MatchProtocol(Protocol):
|
||||
clip: data.Clip
|
||||
gt: data.SoundEventAnnotation | None
|
||||
pred: RawPrediction | None
|
||||
pred: Detection | None
|
||||
score: float
|
||||
true_class: str | None
|
||||
|
||||
@ -341,4 +341,3 @@ def plot_cross_trigger_match(
|
||||
ax.set_title("Cross Trigger")
|
||||
|
||||
return ax
|
||||
|
||||
|
||||
6
src/batdetect2/postprocess/clips.py
Normal file
6
src/batdetect2/postprocess/clips.py
Normal file
@ -0,0 +1,6 @@
|
||||
from batdetect2.typing import ClipDetections
|
||||
|
||||
|
||||
class ClipTransform:
    """Stub for a transform built from the detections of one clip.

    The constructor currently accepts a ``ClipDetections`` object and
    does nothing with it: no attributes are set on the instance.
    """

    def __init__(self, clip: ClipDetections):
        """Accept ``clip`` without storing or inspecting it."""
        return None
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ from soundevent import data
|
||||
|
||||
from batdetect2.typing.postprocess import (
|
||||
ClipDetectionsArray,
|
||||
RawPrediction,
|
||||
Detection,
|
||||
)
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
@ -30,7 +30,7 @@ decoding.
|
||||
def to_raw_predictions(
|
||||
detections: ClipDetectionsArray,
|
||||
targets: TargetProtocol,
|
||||
) -> List[RawPrediction]:
|
||||
) -> List[Detection]:
|
||||
predictions = []
|
||||
|
||||
for score, class_scores, time, freq, dims, feats in zip(
|
||||
@ -39,7 +39,8 @@ def to_raw_predictions(
|
||||
detections.times,
|
||||
detections.frequencies,
|
||||
detections.sizes,
|
||||
detections.features, strict=False,
|
||||
detections.features,
|
||||
strict=False,
|
||||
):
|
||||
highest_scoring_class = targets.class_names[class_scores.argmax()]
|
||||
|
||||
@ -50,7 +51,7 @@ def to_raw_predictions(
|
||||
)
|
||||
|
||||
predictions.append(
|
||||
RawPrediction(
|
||||
Detection(
|
||||
detection_score=score,
|
||||
geometry=geom,
|
||||
class_scores=class_scores,
|
||||
@ -62,7 +63,7 @@ def to_raw_predictions(
|
||||
|
||||
|
||||
def convert_raw_predictions_to_clip_prediction(
|
||||
raw_predictions: List[RawPrediction],
|
||||
raw_predictions: List[Detection],
|
||||
clip: data.Clip,
|
||||
targets: TargetProtocol,
|
||||
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
@ -85,7 +86,7 @@ def convert_raw_predictions_to_clip_prediction(
|
||||
|
||||
|
||||
def convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction: RawPrediction,
|
||||
raw_prediction: Detection,
|
||||
recording: data.Recording,
|
||||
targets: TargetProtocol,
|
||||
classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@ -377,7 +377,10 @@ class PeakNormalize(torch.nn.Module):
|
||||
|
||||
|
||||
SpectrogramTransform = Annotated[
|
||||
PcenConfig | ScaleAmplitudeConfig | SpectralMeanSubstractionConfig | PeakNormalizeConfig,
|
||||
PcenConfig
|
||||
| ScaleAmplitudeConfig
|
||||
| SpectralMeanSubstractionConfig
|
||||
| PeakNormalizeConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -394,7 +394,8 @@ class MaskTime(torch.nn.Module):
|
||||
size=num_masks,
|
||||
)
|
||||
masks = [
|
||||
(start, start + size) for start, size in zip(mask_start, mask_size, strict=False)
|
||||
(start, start + size)
|
||||
for start, size in zip(mask_start, mask_size, strict=False)
|
||||
]
|
||||
return mask_time(spec, masks), clip_annotation
|
||||
|
||||
@ -460,7 +461,8 @@ class MaskFrequency(torch.nn.Module):
|
||||
size=num_masks,
|
||||
)
|
||||
masks = [
|
||||
(start, start + size) for start, size in zip(mask_start, mask_size, strict=False)
|
||||
(start, start + size)
|
||||
for start, size in zip(mask_start, mask_size, strict=False)
|
||||
]
|
||||
return mask_frequency(spec, masks), clip_annotation
|
||||
|
||||
@ -498,7 +500,12 @@ SpectrogramAugmentationConfig = Annotated[
|
||||
]
|
||||
|
||||
AugmentationConfig = Annotated[
|
||||
MixAudioConfig | AddEchoConfig | ScaleVolumeConfig | WarpConfig | MaskFrequencyConfig | MaskTimeConfig,
|
||||
MixAudioConfig
|
||||
| AddEchoConfig
|
||||
| ScaleVolumeConfig
|
||||
| WarpConfig
|
||||
| MaskFrequencyConfig
|
||||
| MaskTimeConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""Type alias for the discriminated union of individual augmentation config."""
|
||||
|
||||
@ -10,7 +10,7 @@ from batdetect2.postprocess import to_raw_predictions
|
||||
from batdetect2.train.dataset import ValidationDataset
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
ClipDetections,
|
||||
EvaluatorProtocol,
|
||||
ModelOutput,
|
||||
TrainExample,
|
||||
@ -24,7 +24,7 @@ class ValidationMetrics(Callback):
|
||||
self.evaluator = evaluator
|
||||
|
||||
self._clip_annotations: List[data.ClipAnnotation] = []
|
||||
self._predictions: List[BatDetect2Prediction] = []
|
||||
self._predictions: List[ClipDetections] = []
|
||||
|
||||
def get_dataset(self, trainer: Trainer) -> ValidationDataset:
|
||||
dataloaders = trainer.val_dataloaders
|
||||
@ -100,9 +100,9 @@ class ValidationMetrics(Callback):
|
||||
start_times=[ca.clip.start_time for ca in clip_annotations],
|
||||
)
|
||||
predictions = [
|
||||
BatDetect2Prediction(
|
||||
ClipDetections(
|
||||
clip=clip_annotation.clip,
|
||||
predictions=to_raw_predictions(
|
||||
detections=to_raw_predictions(
|
||||
clip_dets.numpy(), targets=model.targets
|
||||
),
|
||||
)
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
|
||||
@ -18,7 +18,6 @@ The primary entry points are:
|
||||
- `LossConfig`: The Pydantic model for configuring loss weights and parameters.
|
||||
"""
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
@ -10,11 +10,11 @@ from batdetect2.typing.evaluate import (
|
||||
)
|
||||
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
||||
from batdetect2.typing.postprocess import (
|
||||
BatDetect2Prediction,
|
||||
ClipDetections,
|
||||
ClipDetectionsTensor,
|
||||
GeometryDecoder,
|
||||
PostprocessorProtocol,
|
||||
RawPrediction,
|
||||
Detection,
|
||||
)
|
||||
from batdetect2.typing.preprocess import (
|
||||
AudioLoader,
|
||||
@ -44,7 +44,7 @@ __all__ = [
|
||||
"AudioLoader",
|
||||
"Augmentation",
|
||||
"BackboneModel",
|
||||
"BatDetect2Prediction",
|
||||
"ClipDetections",
|
||||
"ClipDetectionsTensor",
|
||||
"ClipLabeller",
|
||||
"ClipMatches",
|
||||
@ -65,7 +65,7 @@ __all__ = [
|
||||
"PostprocessorProtocol",
|
||||
"PreprocessorProtocol",
|
||||
"ROITargetMapper",
|
||||
"RawPrediction",
|
||||
"Detection",
|
||||
"Size",
|
||||
"SoundEventDecoder",
|
||||
"SoundEventEncoder",
|
||||
|
||||
@ -2,7 +2,7 @@ from typing import Generic, List, Protocol, Sequence, TypeVar
|
||||
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
from batdetect2.typing.postprocess import ClipDetections
|
||||
|
||||
__all__ = [
|
||||
"OutputFormatterProtocol",
|
||||
@ -12,9 +12,7 @@ T = TypeVar("T")
|
||||
|
||||
|
||||
class OutputFormatterProtocol(Protocol, Generic[T]):
|
||||
def format(
|
||||
self, predictions: Sequence[BatDetect2Prediction]
|
||||
) -> List[T]: ...
|
||||
def format(self, predictions: Sequence[ClipDetections]) -> List[T]: ...
|
||||
|
||||
def save(
|
||||
self,
|
||||
|
||||
@ -13,7 +13,7 @@ from typing import (
|
||||
from matplotlib.figure import Figure
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction
|
||||
from batdetect2.typing.postprocess import ClipDetections, Detection
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
@ -84,7 +84,7 @@ Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
|
||||
class AffinityFunction(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
detection: RawPrediction,
|
||||
detection: Detection,
|
||||
ground_truth: data.SoundEventAnnotation,
|
||||
) -> float: ...
|
||||
|
||||
@ -93,7 +93,7 @@ class MetricsProtocol(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[Sequence[RawPrediction]],
|
||||
predictions: Sequence[Sequence[Detection]],
|
||||
) -> Dict[str, float]: ...
|
||||
|
||||
|
||||
@ -101,7 +101,7 @@ class PlotterProtocol(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[Sequence[RawPrediction]],
|
||||
predictions: Sequence[Sequence[Detection]],
|
||||
) -> Iterable[Tuple[str, Figure]]: ...
|
||||
|
||||
|
||||
@ -114,7 +114,7 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
|
||||
def evaluate(
|
||||
self,
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
predictions: Sequence[ClipDetections],
|
||||
) -> EvaluationOutput: ...
|
||||
|
||||
def compute_metrics(
|
||||
|
||||
@ -22,7 +22,7 @@ from batdetect2.typing.models import ModelOutput
|
||||
from batdetect2.typing.targets import Position, Size
|
||||
|
||||
__all__ = [
|
||||
"RawPrediction",
|
||||
"Detection",
|
||||
"PostprocessorProtocol",
|
||||
"GeometryDecoder",
|
||||
]
|
||||
@ -46,7 +46,8 @@ class GeometryDecoder(Protocol):
|
||||
) -> data.Geometry: ...
|
||||
|
||||
|
||||
class RawPrediction(NamedTuple):
|
||||
@dataclass
|
||||
class Detection:
|
||||
geometry: data.Geometry
|
||||
detection_score: float
|
||||
class_scores: np.ndarray
|
||||
@ -82,9 +83,16 @@ class ClipDetectionsTensor(NamedTuple):
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatDetect2Prediction:
|
||||
class ClipDetections:
|
||||
clip: data.Clip
|
||||
predictions: List[RawPrediction]
|
||||
detections: List[Detection]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClipPrediction:
|
||||
clip: data.Clip
|
||||
detection_score: float
|
||||
class_scores: np.ndarray
|
||||
|
||||
|
||||
class PostprocessorProtocol(Protocol):
|
||||
|
||||
@ -220,7 +220,8 @@ def get_annotations_from_preds(
|
||||
predictions["high_freqs"],
|
||||
class_ind_best,
|
||||
class_prob_best,
|
||||
predictions["det_probs"], strict=False,
|
||||
predictions["det_probs"],
|
||||
strict=False,
|
||||
)
|
||||
]
|
||||
return annotations
|
||||
|
||||
@ -87,9 +87,7 @@ def save_ann_spec(
|
||||
y_extent = [0, duration, min_freq, max_freq]
|
||||
|
||||
plt.close("all")
|
||||
plt.figure(
|
||||
0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100
|
||||
)
|
||||
plt.figure(0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100)
|
||||
plt.imshow(
|
||||
spec,
|
||||
aspect="auto",
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
@ -10,8 +10,8 @@ from batdetect2.data.predictions import (
|
||||
build_output_formatter,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
RawPrediction,
|
||||
ClipDetections,
|
||||
Detection,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
@ -31,7 +31,7 @@ def test_roundtrip(
|
||||
tmp_path: Path,
|
||||
):
|
||||
detections = [
|
||||
RawPrediction(
|
||||
Detection(
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=list(np.random.uniform(size=[4]))
|
||||
),
|
||||
@ -44,7 +44,7 @@ def test_roundtrip(
|
||||
for _ in range(10)
|
||||
]
|
||||
|
||||
prediction = BatDetect2Prediction(clip=clip, predictions=detections)
|
||||
prediction = ClipDetections(clip=clip, detections=detections)
|
||||
|
||||
path = tmp_path / "predictions.parquet"
|
||||
|
||||
@ -58,7 +58,7 @@ def test_roundtrip(
|
||||
assert recovered[0].clip == prediction.clip
|
||||
|
||||
for recovered_prediction, detection in zip(
|
||||
recovered[0].predictions, detections
|
||||
recovered[0].detections, detections, strict=True
|
||||
):
|
||||
assert (
|
||||
recovered_prediction.detection_score == detection.detection_score
|
||||
@ -81,7 +81,7 @@ def test_multiple_clips(
|
||||
clip2 = clip.model_copy(update={"uuid": uuid4()})
|
||||
|
||||
detections1 = [
|
||||
RawPrediction(
|
||||
Detection(
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=list(np.random.uniform(size=[4]))
|
||||
),
|
||||
@ -94,7 +94,7 @@ def test_multiple_clips(
|
||||
]
|
||||
|
||||
detections2 = [
|
||||
RawPrediction(
|
||||
Detection(
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=list(np.random.uniform(size=[4]))
|
||||
),
|
||||
@ -107,8 +107,8 @@ def test_multiple_clips(
|
||||
]
|
||||
|
||||
predictions = [
|
||||
BatDetect2Prediction(clip=clip, predictions=detections1),
|
||||
BatDetect2Prediction(clip=clip2, predictions=detections2),
|
||||
ClipDetections(clip=clip, detections=detections1),
|
||||
ClipDetections(clip=clip2, detections=detections2),
|
||||
]
|
||||
|
||||
path = tmp_path / "multi_predictions.parquet"
|
||||
@ -133,16 +133,18 @@ def test_complex_geometry(
|
||||
):
|
||||
# Create a polygon geometry
|
||||
polygon = data.Polygon(
|
||||
coordinates=[[
|
||||
[0.0, 10000.0],
|
||||
[0.1, 20000.0],
|
||||
[0.2, 10000.0],
|
||||
[0.0, 10000.0],
|
||||
]]
|
||||
coordinates=[
|
||||
[
|
||||
[0.0, 10000.0],
|
||||
[0.1, 20000.0],
|
||||
[0.2, 10000.0],
|
||||
[0.0, 10000.0],
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
detections = [
|
||||
RawPrediction(
|
||||
Detection(
|
||||
geometry=polygon,
|
||||
detection_score=0.95,
|
||||
class_scores=np.random.uniform(
|
||||
@ -152,7 +154,7 @@ def test_complex_geometry(
|
||||
)
|
||||
]
|
||||
|
||||
prediction = BatDetect2Prediction(clip=clip, predictions=detections)
|
||||
prediction = ClipDetections(clip=clip, detections=detections)
|
||||
|
||||
path = tmp_path / "complex_geometry.parquet"
|
||||
sample_formatter.save(predictions=[prediction], path=path)
|
||||
@ -160,9 +162,9 @@ def test_complex_geometry(
|
||||
recovered = sample_formatter.load(path=path)
|
||||
|
||||
assert len(recovered) == 1
|
||||
assert len(recovered[0].predictions) == 1
|
||||
assert len(recovered[0].detections) == 1
|
||||
|
||||
recovered_pred = recovered[0].predictions[0]
|
||||
recovered_pred = recovered[0].detections[0]
|
||||
|
||||
# Check if geometry is recovered correctly as a Polygon
|
||||
assert isinstance(recovered_pred.geometry, data.Polygon)
|
||||
|
||||
@ -6,8 +6,8 @@ from soundevent import data
|
||||
|
||||
from batdetect2.data.predictions import RawOutputConfig, build_output_formatter
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
RawPrediction,
|
||||
ClipDetections,
|
||||
Detection,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
@ -27,7 +27,7 @@ def test_roundtrip(
|
||||
tmp_path: Path,
|
||||
):
|
||||
detections = [
|
||||
RawPrediction(
|
||||
Detection(
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=list(np.random.uniform(size=[4]))
|
||||
),
|
||||
@ -40,7 +40,7 @@ def test_roundtrip(
|
||||
for _ in range(10)
|
||||
]
|
||||
|
||||
prediction = BatDetect2Prediction(clip=clip, predictions=detections)
|
||||
prediction = ClipDetections(clip=clip, detections=detections)
|
||||
|
||||
path = tmp_path / "predictions"
|
||||
|
||||
@ -52,7 +52,9 @@ def test_roundtrip(
|
||||
assert recovered[0].clip == prediction.clip
|
||||
|
||||
for recovered_prediction, detection in zip(
|
||||
recovered[0].predictions, detections
|
||||
recovered[0].detections,
|
||||
detections,
|
||||
strict=True,
|
||||
):
|
||||
assert (
|
||||
recovered_prediction.detection_score == detection.detection_score
|
||||
|
||||
0
tests/test_evaluate/__init__.py
Normal file
0
tests/test_evaluate/__init__.py
Normal file
0
tests/test_evaluate/test_tasks/__init__.py
Normal file
0
tests/test_evaluate/test_tasks/__init__.py
Normal file
86
tests/test_evaluate/test_tasks/conftest.py
Normal file
86
tests/test_evaluate/test_tasks/conftest.py
Normal file
@ -0,0 +1,86 @@
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.typing import Detection
|
||||
|
||||
|
||||
@pytest.fixture
def clip(recording: data.Recording) -> data.Clip:
    """Return a clip of the ``recording`` fixture from time 0 to 100."""
    clip_bounds = {"start_time": 0, "end_time": 100}
    return data.Clip(recording=recording, **clip_bounds)
|
||||
|
||||
|
||||
@pytest.fixture
def create_detection():
    """Provide a factory that builds ``Detection`` objects for tests.

    The returned callable accepts keyword overrides for the detection
    score, the bounding box (given as start time, duration, low frequency
    and bandwidth) and the two class scores.
    """

    def factory(
        detection_score: float = 0.5,
        start_time: float = 0.1,
        duration: float = 0.01,
        low_freq: float = 40_000,
        bandwidth: float = 30_000,
        pip_score: float = 0,
        myo_score: float = 0,
    ):
        """Build one ``Detection`` from the given box and scores."""
        # The box is stored as [start, low, end, high]; the far corner is
        # derived from the duration and bandwidth.
        end_time = start_time + duration
        high_freq = low_freq + bandwidth
        box = data.BoundingBox(
            coordinates=[start_time, low_freq, end_time, high_freq]
        )

        # Class scores are ordered as [pip_score, myo_score].
        scores = np.array([pip_score, myo_score])

        return Detection(
            detection_score=detection_score,
            class_scores=scores,
            features=np.zeros([32]),
            geometry=box,
        )

    return factory
|
||||
|
||||
|
||||
@pytest.fixture
def create_annotation(
    clip: data.Clip,
    bat_tag: data.Tag,
    myomyo_tag: data.Tag,
    pippip_tag: data.Tag,
):
    """Provide a factory that builds ``SoundEventAnnotation`` objects.

    Annotations are attached to the recording of the ``clip`` fixture and
    tagged from the ``bat_tag``, ``pippip_tag`` and ``myomyo_tag``
    fixtures.
    """

    def factory(
        start_time: float = 0.1,
        duration: float = 0.01,
        low_freq: float = 40_000,
        bandwidth: float = 30_000,
        is_target: bool = True,
        class_name: Literal["pippip", "myomyo"] | None = None,
    ):
        """Build one annotation with a bounding box and matching tags."""
        # ``bat_tag`` is only attached when the annotation is a target.
        tags = []
        if is_target:
            tags.append(bat_tag)

        # A ``None`` class name matches neither branch, so no class tag is
        # added in that case.
        if class_name == "pippip":
            tags.append(pippip_tag)
        elif class_name == "myomyo":
            tags.append(myomyo_tag)

        # The box is stored as [start, low, end, high].
        box = data.BoundingBox(
            coordinates=[
                start_time,
                low_freq,
                start_time + duration,
                low_freq + bandwidth,
            ]
        )
        event = data.SoundEvent(geometry=box, recording=clip.recording)
        return data.SoundEventAnnotation(sound_event=event, tags=tags)

    return factory
|
||||
77
tests/test_evaluate/test_tasks/test_classification.py
Normal file
77
tests/test_evaluate/test_tasks/test_classification.py
Normal file
@ -0,0 +1,77 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.evaluate.tasks import build_task
|
||||
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
|
||||
from batdetect2.typing import ClipDetections
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
|
||||
def test_classification(
    clip: data.Clip,
    sample_targets: TargetProtocol,
    create_detection,
    create_annotation,
):
    """Check average precision values of the classification task.

    Builds 100 detections with a rising ``pip_score`` and 100 with a
    falling ``myo_score``, pairs them with annotations at the same start
    times, and compares the per-class and mean average precision with
    fixed expected values.
    """
    task_config = ClassificationTaskConfig.model_validate(
        {
            "name": "sound_event_classification",
            "metrics": [{"name": "average_precision"}],
        }
    )
    task = build_task(task_config, targets=sample_targets)

    # Create a dummy prediction: pippip scores rise from 0 to 1, myomyo
    # scores fall from 1 to 0, with the myomyo boxes offset by 0.05.
    pip_detections = [
        create_detection(start_time=1 + 0.1 * index, pip_score=score)
        for index, score in enumerate(np.linspace(0, 1, 100))
    ]
    myo_detections = [
        create_detection(start_time=1.05 + 0.1 * index, myo_score=score)
        for index, score in enumerate(np.linspace(1, 0, 100))
    ]
    prediction = ClipDetections(
        clip=clip,
        detections=pip_detections + myo_detections,
    )

    # Create a dummy annotation: every second pippip annotation and every
    # third myomyo annotation is built with ``is_target=True``.
    pip_annotations = [
        create_annotation(
            start_time=1 + 0.1 * index,
            is_target=index % 2 == 0,
            class_name="pippip",
        )
        for index in range(100)
    ]
    myo_annotations = [
        create_annotation(
            start_time=1.05 + 0.1 * index,
            is_target=index % 3 == 0,
            class_name="myomyo",
        )
        for index in range(100)
    ]
    gt = data.ClipAnnotation(
        clip=clip,
        sound_events=pip_annotations + myo_annotations,
    )

    clip_evals = task.evaluate([gt], [prediction])
    metrics = task.compute_metrics(clip_evals)

    pip_ap = metrics["classification/average_precision/pippip"]
    myo_ap = metrics["classification/average_precision/myomyo"]
    mean_ap = metrics["classification/mean_average_precision"]

    assert pip_ap == pytest.approx(0.5, abs=0.005)
    assert myo_ap == pytest.approx(0.371, abs=0.005)
    assert mean_ap == pytest.approx((0.5 + 0.371) / 2, abs=0.005)
|
||||
50
tests/test_evaluate/test_tasks/test_detection.py
Normal file
50
tests/test_evaluate/test_tasks/test_detection.py
Normal file
@ -0,0 +1,50 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.evaluate.tasks import build_task
|
||||
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
||||
from batdetect2.typing import ClipDetections
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
|
||||
def test_detection(
    clip: data.Clip,
    sample_targets: TargetProtocol,
    create_detection,
    create_annotation,
):
    """Check the average precision reported by the detection task.

    Builds 100 detections whose scores rise from 0 to 1 and 100
    annotations at the same start times, of which every second one is
    built with ``is_target=True``, then expects an average precision of
    approximately 0.5.
    """
    task_config = DetectionTaskConfig.model_validate(
        {
            "name": "sound_event_detection",
            "metrics": [{"name": "average_precision"}],
        }
    )
    task = build_task(task_config, targets=sample_targets)

    # Create a dummy prediction
    scored_detections = [
        create_detection(start_time=1 + 0.1 * index, detection_score=score)
        for index, score in enumerate(np.linspace(0, 1, 100))
    ]
    prediction = ClipDetections(clip=clip, detections=scored_detections)

    # Create a dummy annotation.
    # Only half of the annotations are targets
    annotations = [
        create_annotation(
            start_time=1 + 0.1 * index,
            is_target=index % 2 == 0,
        )
        for index in range(100)
    ]
    gt = data.ClipAnnotation(clip=clip, sound_events=annotations)

    # Run the task
    clip_evals = task.evaluate([gt], [prediction])
    metrics = task.compute_metrics(clip_evals)

    assert metrics["detection/average_precision"] == pytest.approx(0.5)
|
||||
@ -14,7 +14,7 @@ from batdetect2.postprocess.decoding import (
|
||||
get_generic_tags,
|
||||
get_prediction_features,
|
||||
)
|
||||
from batdetect2.typing import RawPrediction, TargetProtocol
|
||||
from batdetect2.typing import Detection, TargetProtocol
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -209,7 +209,7 @@ def empty_detection_dataset() -> xr.Dataset:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_raw_predictions() -> List[RawPrediction]:
|
||||
def sample_raw_predictions() -> List[Detection]:
|
||||
"""Manually crafted RawPrediction objects using the actual type."""
|
||||
|
||||
pred1_classes = xr.DataArray(
|
||||
@ -220,7 +220,7 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
||||
coords={"feature": ["f0", "f1", "f2", "f3"]},
|
||||
dims=["feature"],
|
||||
)
|
||||
pred1 = RawPrediction(
|
||||
pred1 = Detection(
|
||||
detection_score=0.9,
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=[
|
||||
@ -242,7 +242,7 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
||||
coords={"feature": ["f0", "f1", "f2", "f3"]},
|
||||
dims=["feature"],
|
||||
)
|
||||
pred2 = RawPrediction(
|
||||
pred2 = Detection(
|
||||
detection_score=0.8,
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=[
|
||||
@ -264,7 +264,7 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
||||
coords={"feature": ["f0", "f1", "f2", "f3"]},
|
||||
dims=["feature"],
|
||||
)
|
||||
pred3 = RawPrediction(
|
||||
pred3 = Detection(
|
||||
detection_score=0.15,
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=[
|
||||
@ -281,7 +281,7 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_basic(
|
||||
sample_raw_predictions: List[RawPrediction],
|
||||
sample_raw_predictions: List[Detection],
|
||||
sample_recording: data.Recording,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
@ -324,7 +324,7 @@ def test_convert_raw_to_sound_event_basic(
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_thresholding(
|
||||
sample_raw_predictions: List[RawPrediction],
|
||||
sample_raw_predictions: List[Detection],
|
||||
sample_recording: data.Recording,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
@ -352,7 +352,7 @@ def test_convert_raw_to_sound_event_thresholding(
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_no_threshold(
|
||||
sample_raw_predictions: List[RawPrediction],
|
||||
sample_raw_predictions: List[Detection],
|
||||
sample_recording: data.Recording,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
@ -380,7 +380,7 @@ def test_convert_raw_to_sound_event_no_threshold(
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_top_class(
|
||||
sample_raw_predictions: List[RawPrediction],
|
||||
sample_raw_predictions: List[Detection],
|
||||
sample_recording: data.Recording,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
@ -407,7 +407,7 @@ def test_convert_raw_to_sound_event_top_class(
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_all_below_threshold(
|
||||
sample_raw_predictions: List[RawPrediction],
|
||||
sample_raw_predictions: List[Detection],
|
||||
sample_recording: data.Recording,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
@ -433,7 +433,7 @@ def test_convert_raw_to_sound_event_all_below_threshold(
|
||||
|
||||
|
||||
def test_convert_raw_list_to_clip_basic(
|
||||
sample_raw_predictions: List[RawPrediction],
|
||||
sample_raw_predictions: List[Detection],
|
||||
sample_clip: data.Clip,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
@ -485,7 +485,7 @@ def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets):
|
||||
|
||||
|
||||
def test_convert_raw_list_to_clip_passes_args(
|
||||
sample_raw_predictions: List[RawPrediction],
|
||||
sample_raw_predictions: List[Detection],
|
||||
sample_clip: data.Clip,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import xarray as xr
|
||||
|
||||
Loading…
Reference in New Issue
Block a user