Run formatter

This commit is contained in:
mbsantiago 2025-12-12 21:28:28 +00:00
parent 531ff69974
commit 0adb58e039
64 changed files with 520 additions and 244 deletions

View File

@ -32,11 +32,11 @@ from batdetect2.train import (
) )
from batdetect2.typing import ( from batdetect2.typing import (
AudioLoader, AudioLoader,
BatDetect2Prediction, ClipDetections,
EvaluatorProtocol, EvaluatorProtocol,
PostprocessorProtocol, PostprocessorProtocol,
PreprocessorProtocol, PreprocessorProtocol,
RawPrediction, Detection,
TargetProtocol, TargetProtocol,
) )
@ -110,7 +110,7 @@ class BatDetect2API:
experiment_name: str | None = None, experiment_name: str | None = None,
run_name: str | None = None, run_name: str | None = None,
save_predictions: bool = True, save_predictions: bool = True,
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]: ) -> Tuple[Dict[str, float], List[List[Detection]]]:
return evaluate( return evaluate(
self.model, self.model,
test_annotations, test_annotations,
@ -128,7 +128,7 @@ class BatDetect2API:
def evaluate_predictions( def evaluate_predictions(
self, self,
annotations: Sequence[data.ClipAnnotation], annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[ClipDetections],
output_dir: data.PathLike | None = None, output_dir: data.PathLike | None = None,
): ):
clip_evals = self.evaluator.evaluate( clip_evals = self.evaluator.evaluate(
@ -170,24 +170,24 @@ class BatDetect2API:
tensor = torch.tensor(audio).unsqueeze(0) tensor = torch.tensor(audio).unsqueeze(0)
return self.preprocessor(tensor) return self.preprocessor(tensor)
def process_file(self, audio_file: str) -> BatDetect2Prediction: def process_file(self, audio_file: str) -> ClipDetections:
recording = data.Recording.from_file(audio_file, compute_hash=False) recording = data.Recording.from_file(audio_file, compute_hash=False)
wav = self.audio_loader.load_recording(recording) wav = self.audio_loader.load_recording(recording)
detections = self.process_audio(wav) detections = self.process_audio(wav)
return BatDetect2Prediction( return ClipDetections(
clip=data.Clip( clip=data.Clip(
uuid=recording.uuid, uuid=recording.uuid,
recording=recording, recording=recording,
start_time=0, start_time=0,
end_time=recording.duration, end_time=recording.duration,
), ),
predictions=detections, detections=detections,
) )
def process_audio( def process_audio(
self, self,
audio: np.ndarray, audio: np.ndarray,
) -> List[RawPrediction]: ) -> List[Detection]:
spec = self.generate_spectrogram(audio) spec = self.generate_spectrogram(audio)
return self.process_spectrogram(spec) return self.process_spectrogram(spec)
@ -195,7 +195,7 @@ class BatDetect2API:
self, self,
spec: torch.Tensor, spec: torch.Tensor,
start_time: float = 0, start_time: float = 0,
) -> List[RawPrediction]: ) -> List[Detection]:
if spec.ndim == 4 and spec.shape[0] > 1: if spec.ndim == 4 and spec.shape[0] > 1:
raise ValueError("Batched spectrograms not supported.") raise ValueError("Batched spectrograms not supported.")
@ -214,7 +214,7 @@ class BatDetect2API:
def process_directory( def process_directory(
self, self,
audio_dir: data.PathLike, audio_dir: data.PathLike,
) -> List[BatDetect2Prediction]: ) -> List[ClipDetections]:
files = list(get_audio_files(audio_dir)) files = list(get_audio_files(audio_dir))
return self.process_files(files) return self.process_files(files)
@ -222,7 +222,7 @@ class BatDetect2API:
self, self,
audio_files: Sequence[data.PathLike], audio_files: Sequence[data.PathLike],
num_workers: int | None = None, num_workers: int | None = None,
) -> List[BatDetect2Prediction]: ) -> List[ClipDetections]:
return process_file_list( return process_file_list(
self.model, self.model,
audio_files, audio_files,
@ -238,7 +238,7 @@ class BatDetect2API:
clips: Sequence[data.Clip], clips: Sequence[data.Clip],
batch_size: int | None = None, batch_size: int | None = None,
num_workers: int | None = None, num_workers: int | None = None,
) -> List[BatDetect2Prediction]: ) -> List[ClipDetections]:
return run_batch_inference( return run_batch_inference(
self.model, self.model,
clips, clips,
@ -252,7 +252,7 @@ class BatDetect2API:
def save_predictions( def save_predictions(
self, self,
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[ClipDetections],
path: data.PathLike, path: data.PathLike,
audio_dir: data.PathLike | None = None, audio_dir: data.PathLike | None = None,
format: str | None = None, format: str | None = None,
@ -274,7 +274,7 @@ class BatDetect2API:
def load_predictions( def load_predictions(
self, self,
path: data.PathLike, path: data.PathLike,
) -> List[BatDetect2Prediction]: ) -> List[ClipDetections]:
return self.formatter.load(path) return self.formatter.load(path)
@classmethod @classmethod

View File

@ -1,4 +1,3 @@
import numpy as np import numpy as np
from numpy.typing import DTypeLike from numpy.typing import DTypeLike
from pydantic import Field from pydantic import Field

View File

@ -1,4 +1,3 @@
import numpy as np import numpy as np
import torch import torch
import xarray as xr import xarray as xr

View File

@ -264,7 +264,14 @@ class Not:
SoundEventConditionConfig = Annotated[ SoundEventConditionConfig = Annotated[
HasTagConfig | HasAllTagsConfig | HasAnyTagConfig | DurationConfig | FrequencyConfig | AllOfConfig | AnyOfConfig | NotConfig, HasTagConfig
| HasAllTagsConfig
| HasAnyTagConfig
| DurationConfig
| FrequencyConfig
| AllOfConfig
| AnyOfConfig
| NotConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -13,9 +13,9 @@ from batdetect2.data.predictions.base import (
) )
from batdetect2.targets import terms from batdetect2.targets import terms
from batdetect2.typing import ( from batdetect2.typing import (
BatDetect2Prediction, ClipDetections,
OutputFormatterProtocol, OutputFormatterProtocol,
RawPrediction, Detection,
TargetProtocol, TargetProtocol,
) )
@ -113,7 +113,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
self.annotation_note = annotation_note self.annotation_note = annotation_note
def format( def format(
self, predictions: Sequence[BatDetect2Prediction] self, predictions: Sequence[ClipDetections]
) -> List[FileAnnotation]: ) -> List[FileAnnotation]:
return [ return [
self.format_prediction(prediction) for prediction in predictions self.format_prediction(prediction) for prediction in predictions
@ -164,14 +164,12 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
highest_scoring = max(annotations, key=lambda x: x["class_prob"]) highest_scoring = max(annotations, key=lambda x: x["class_prob"])
return highest_scoring["class"] return highest_scoring["class"]
def format_prediction( def format_prediction(self, prediction: ClipDetections) -> FileAnnotation:
self, prediction: BatDetect2Prediction
) -> FileAnnotation:
recording = prediction.clip.recording recording = prediction.clip.recording
annotations = [ annotations = [
self.format_sound_event_prediction(pred) self.format_sound_event_prediction(pred)
for pred in prediction.predictions for pred in prediction.detections
] ]
return FileAnnotation( return FileAnnotation(
@ -196,7 +194,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
) # type: ignore ) # type: ignore
def format_sound_event_prediction( def format_sound_event_prediction(
self, prediction: RawPrediction self, prediction: Detection
) -> Annotation: ) -> Annotation:
start_time, low_freq, end_time, high_freq = compute_bounds( start_time, low_freq, end_time, high_freq = compute_bounds(
prediction.geometry prediction.geometry

View File

@ -14,9 +14,9 @@ from batdetect2.data.predictions.base import (
prediction_formatters, prediction_formatters,
) )
from batdetect2.typing import ( from batdetect2.typing import (
BatDetect2Prediction, ClipDetections,
OutputFormatterProtocol, OutputFormatterProtocol,
RawPrediction, Detection,
TargetProtocol, TargetProtocol,
) )
@ -29,7 +29,7 @@ class ParquetOutputConfig(BaseConfig):
include_geometry: bool = True include_geometry: bool = True
class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]): class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
def __init__( def __init__(
self, self,
targets: TargetProtocol, targets: TargetProtocol,
@ -44,13 +44,13 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
def format( def format(
self, self,
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[ClipDetections],
) -> List[BatDetect2Prediction]: ) -> List[ClipDetections]:
return list(predictions) return list(predictions)
def save( def save(
self, self,
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[ClipDetections],
path: data.PathLike, path: data.PathLike,
audio_dir: data.PathLike | None = None, audio_dir: data.PathLike | None = None,
) -> None: ) -> None:
@ -61,24 +61,26 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
# Ensure the file has .parquet extension if it's a file path # Ensure the file has .parquet extension if it's a file path
if path.suffix != ".parquet": if path.suffix != ".parquet":
# If it's a directory, we might want to save as a partitioned dataset or a single file inside # If it's a directory, we might want to save as a partitioned dataset or a single file inside
# For now, let's assume the user provides a full file path or a directory where we save 'predictions.parquet' # For now, let's assume the user provides a full file path or a directory where we save 'predictions.parquet'
if path.is_dir() or not path.suffix: if path.is_dir() or not path.suffix:
path = path / "predictions.parquet" path = path / "predictions.parquet"
rows = [] rows = []
for prediction in predictions: for prediction in predictions:
clip = prediction.clip clip = prediction.clip
recording = clip.recording recording = clip.recording
if audio_dir is not None: if audio_dir is not None:
recording = recording.model_copy( recording = recording.model_copy(
update=dict(path=make_path_relative(recording.path, audio_dir)) update=dict(
path=make_path_relative(recording.path, audio_dir)
)
) )
recording_json = recording.model_dump_json(exclude_none=True) recording_json = recording.model_dump_json(exclude_none=True)
for pred in prediction.predictions: for pred in prediction.detections:
row = { row = {
"clip_uuid": str(clip.uuid), "clip_uuid": str(clip.uuid),
"clip_start_time": clip.start_time, "clip_start_time": clip.start_time,
@ -96,18 +98,18 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
row["low_freq"] = low_freq row["low_freq"] = low_freq
row["end_time"] = end_time row["end_time"] = end_time
row["high_freq"] = high_freq row["high_freq"] = high_freq
# Store full geometry as JSON # Store full geometry as JSON
row["geometry"] = pred.geometry.model_dump_json() row["geometry"] = pred.geometry.model_dump_json()
if self.include_class_scores: if self.include_class_scores:
row["class_scores"] = pred.class_scores.tolist() row["class_scores"] = pred.class_scores.tolist()
if self.include_features: if self.include_features:
row["features"] = pred.features.tolist() row["features"] = pred.features.tolist()
rows.append(row) rows.append(row)
if not rows: if not rows:
logger.warning("No predictions to save.") logger.warning("No predictions to save.")
return return
@ -116,16 +118,16 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
logger.info(f"Saving {len(df)} predictions to {path}") logger.info(f"Saving {len(df)} predictions to {path}")
df.to_parquet(path, index=False) df.to_parquet(path, index=False)
def load(self, path: data.PathLike) -> List[BatDetect2Prediction]: def load(self, path: data.PathLike) -> List[ClipDetections]:
path = Path(path) path = Path(path)
if path.is_dir(): if path.is_dir():
# Try to find parquet files # Try to find parquet files
files = list(path.glob("*.parquet")) files = list(path.glob("*.parquet"))
if not files: if not files:
return [] return []
# Read all and concatenate # Read all and concatenate
dfs = [pd.read_parquet(f) for f in files] dfs = [pd.read_parquet(f) for f in files]
df = pd.concat(dfs, ignore_index=True) df = pd.concat(dfs, ignore_index=True)
else: else:
df = pd.read_parquet(path) df = pd.read_parquet(path)
@ -133,20 +135,19 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
for _, row in df.iterrows(): for _, row in df.iterrows():
clip_uuid = row["clip_uuid"] clip_uuid = row["clip_uuid"]
if clip_uuid not in predictions_by_clip: if clip_uuid not in predictions_by_clip:
recording = data.Recording.model_validate_json(row["recording_info"]) recording = data.Recording.model_validate_json(
row["recording_info"]
)
clip = data.Clip( clip = data.Clip(
uuid=UUID(clip_uuid), uuid=UUID(clip_uuid),
recording=recording, recording=recording,
start_time=row["clip_start_time"], start_time=row["clip_start_time"],
end_time=row["clip_end_time"], end_time=row["clip_end_time"],
) )
predictions_by_clip[clip_uuid] = { predictions_by_clip[clip_uuid] = {"clip": clip, "preds": []}
"clip": clip,
"preds": []
}
# Reconstruct geometry # Reconstruct geometry
if "geometry" in row and row["geometry"]: if "geometry" in row and row["geometry"]:
geometry = data.geometry_validate(row["geometry"]) geometry = data.geometry_validate(row["geometry"])
@ -156,14 +157,20 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
row["start_time"], row["start_time"],
row["low_freq"], row["low_freq"],
row["end_time"], row["end_time"],
row["high_freq"] row["high_freq"],
] ]
) )
class_scores = np.array(row["class_scores"]) if "class_scores" in row else np.zeros(len(self.targets.class_names)) class_scores = (
features = np.array(row["features"]) if "features" in row else np.zeros(0) np.array(row["class_scores"])
if "class_scores" in row
else np.zeros(len(self.targets.class_names))
)
features = (
np.array(row["features"]) if "features" in row else np.zeros(0)
)
pred = RawPrediction( pred = Detection(
geometry=geometry, geometry=geometry,
detection_score=row["detection_score"], detection_score=row["detection_score"],
class_scores=class_scores, class_scores=class_scores,
@ -174,12 +181,11 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
results = [] results = []
for clip_data in predictions_by_clip.values(): for clip_data in predictions_by_clip.values():
results.append( results.append(
BatDetect2Prediction( ClipDetections(
clip=clip_data["clip"], clip=clip_data["clip"], detections=clip_data["preds"]
predictions=clip_data["preds"]
) )
) )
return results return results
@prediction_formatters.register(ParquetOutputConfig) @prediction_formatters.register(ParquetOutputConfig)

View File

@ -15,9 +15,9 @@ from batdetect2.data.predictions.base import (
prediction_formatters, prediction_formatters,
) )
from batdetect2.typing import ( from batdetect2.typing import (
BatDetect2Prediction, ClipDetections,
OutputFormatterProtocol, OutputFormatterProtocol,
RawPrediction, Detection,
TargetProtocol, TargetProtocol,
) )
@ -30,7 +30,7 @@ class RawOutputConfig(BaseConfig):
include_geometry: bool = True include_geometry: bool = True
class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]): class RawFormatter(OutputFormatterProtocol[ClipDetections]):
def __init__( def __init__(
self, self,
targets: TargetProtocol, targets: TargetProtocol,
@ -47,13 +47,13 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
def format( def format(
self, self,
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[ClipDetections],
) -> List[BatDetect2Prediction]: ) -> List[ClipDetections]:
return list(predictions) return list(predictions)
def save( def save(
self, self,
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[ClipDetections],
path: data.PathLike, path: data.PathLike,
audio_dir: data.PathLike | None = None, audio_dir: data.PathLike | None = None,
) -> None: ) -> None:
@ -68,10 +68,10 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
dataset = self.pred_to_xr(prediction, audio_dir) dataset = self.pred_to_xr(prediction, audio_dir)
dataset.to_netcdf(path / f"{clip.uuid}.nc") dataset.to_netcdf(path / f"{clip.uuid}.nc")
def load(self, path: data.PathLike) -> List[BatDetect2Prediction]: def load(self, path: data.PathLike) -> List[ClipDetections]:
path = Path(path) path = Path(path)
files = list(path.glob("*.nc")) files = list(path.glob("*.nc"))
predictions: List[BatDetect2Prediction] = [] predictions: List[ClipDetections] = []
for filepath in files: for filepath in files:
logger.debug(f"Loading clip predictions {filepath}") logger.debug(f"Loading clip predictions {filepath}")
@ -83,7 +83,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
def pred_to_xr( def pred_to_xr(
self, self,
prediction: BatDetect2Prediction, prediction: ClipDetections,
audio_dir: data.PathLike | None = None, audio_dir: data.PathLike | None = None,
) -> xr.Dataset: ) -> xr.Dataset:
clip = prediction.clip clip = prediction.clip
@ -97,7 +97,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
data = defaultdict(list) data = defaultdict(list)
for pred in prediction.predictions: for pred in prediction.detections:
detection_id = str(uuid4()) detection_id = str(uuid4())
data["detection_id"].append(detection_id) data["detection_id"].append(detection_id)
@ -167,7 +167,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
}, },
) )
def pred_from_xr(self, dataset: xr.Dataset) -> BatDetect2Prediction: def pred_from_xr(self, dataset: xr.Dataset) -> ClipDetections:
clip_data = dataset clip_data = dataset
clip_id = clip_data.clip_id.item() clip_id = clip_data.clip_id.item()
@ -219,7 +219,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
features = np.zeros(0) features = np.zeros(0)
sound_events.append( sound_events.append(
RawPrediction( Detection(
geometry=geometry, geometry=geometry,
detection_score=score, detection_score=score,
class_scores=class_scores, class_scores=class_scores,
@ -227,9 +227,9 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
) )
) )
return BatDetect2Prediction( return ClipDetections(
clip=clip, clip=clip,
predictions=sound_events, detections=sound_events,
) )
@prediction_formatters.register(RawOutputConfig) @prediction_formatters.register(RawOutputConfig)

View File

@ -9,9 +9,9 @@ from batdetect2.data.predictions.base import (
prediction_formatters, prediction_formatters,
) )
from batdetect2.typing import ( from batdetect2.typing import (
BatDetect2Prediction, ClipDetections,
OutputFormatterProtocol, OutputFormatterProtocol,
RawPrediction, Detection,
TargetProtocol, TargetProtocol,
) )
@ -35,7 +35,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
def format( def format(
self, self,
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[ClipDetections],
) -> List[data.ClipPrediction]: ) -> List[data.ClipPrediction]:
return [ return [
self.format_prediction(prediction) for prediction in predictions self.format_prediction(prediction) for prediction in predictions
@ -63,20 +63,20 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
def format_prediction( def format_prediction(
self, self,
prediction: BatDetect2Prediction, prediction: ClipDetections,
) -> data.ClipPrediction: ) -> data.ClipPrediction:
recording = prediction.clip.recording recording = prediction.clip.recording
return data.ClipPrediction( return data.ClipPrediction(
clip=prediction.clip, clip=prediction.clip,
sound_events=[ sound_events=[
self.format_sound_event_prediction(pred, recording) self.format_sound_event_prediction(pred, recording)
for pred in prediction.predictions for pred in prediction.detections
], ],
) )
def format_sound_event_prediction( def format_sound_event_prediction(
self, self,
prediction: RawPrediction, prediction: Detection,
recording: data.Recording, recording: data.Recording,
) -> data.SoundEventPrediction: ) -> data.SoundEventPrediction:
return data.SoundEventPrediction( return data.SoundEventPrediction(
@ -89,7 +89,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
) )
def get_sound_event_tags( def get_sound_event_tags(
self, prediction: RawPrediction self, prediction: Detection
) -> List[data.PredictedTag]: ) -> List[data.PredictedTag]:
sorted_indices = np.argsort(prediction.class_scores)[::-1] sorted_indices = np.argsort(prediction.class_scores)[::-1]

View File

@ -221,7 +221,11 @@ class ApplyAll:
SoundEventTransformConfig = Annotated[ SoundEventTransformConfig = Annotated[
SetFrequencyBoundConfig | ReplaceTagConfig | MapTagValueConfig | ApplyIfConfig | ApplyAllConfig, SetFrequencyBoundConfig
| ReplaceTagConfig
| MapTagValueConfig
| ApplyIfConfig
| ApplyAllConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -11,7 +11,7 @@ from soundevent.geometry import (
) )
from batdetect2.core import BaseConfig, Registry from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import AffinityFunction, RawPrediction from batdetect2.typing import AffinityFunction, Detection
affinity_functions: Registry[AffinityFunction, []] = Registry( affinity_functions: Registry[AffinityFunction, []] = Registry(
"affinity_function" "affinity_function"
@ -42,7 +42,7 @@ class TimeAffinity(AffinityFunction):
def __call__( def __call__(
self, self,
detection: RawPrediction, detection: Detection,
ground_truth: data.SoundEventAnnotation, ground_truth: data.SoundEventAnnotation,
) -> float: ) -> float:
target_geometry = ground_truth.sound_event.geometry target_geometry = ground_truth.sound_event.geometry
@ -77,7 +77,7 @@ class IntervalIOU(AffinityFunction):
def __call__( def __call__(
self, self,
detection: RawPrediction, detection: Detection,
ground_truth: data.SoundEventAnnotation, ground_truth: data.SoundEventAnnotation,
) -> float: ) -> float:
target_geometry = ground_truth.sound_event.geometry target_geometry = ground_truth.sound_event.geometry
@ -120,7 +120,7 @@ class BBoxIOU(AffinityFunction):
def __call__( def __call__(
self, self,
prediction: RawPrediction, prediction: Detection,
gt: data.SoundEventAnnotation, gt: data.SoundEventAnnotation,
): ):
target_geometry = gt.sound_event.geometry target_geometry = gt.sound_event.geometry
@ -168,7 +168,7 @@ class GeometricIOU(AffinityFunction):
def __call__( def __call__(
self, self,
prediction: RawPrediction, prediction: Detection,
gt: data.SoundEventAnnotation, gt: data.SoundEventAnnotation,
): ):
target_geometry = gt.sound_event.geometry target_geometry = gt.sound_event.geometry

View File

@ -12,7 +12,7 @@ from batdetect2.logging import build_logger
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets from batdetect2.targets import build_targets
from batdetect2.typing.postprocess import RawPrediction from batdetect2.typing import Detection
if TYPE_CHECKING: if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config
@ -38,7 +38,7 @@ def evaluate(
output_dir: data.PathLike = DEFAULT_EVAL_DIR, output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: str | None = None, experiment_name: str | None = None,
run_name: str | None = None, run_name: str | None = None,
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]: ) -> Tuple[Dict[str, float], List[List[Detection]]]:
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config
config = config or BatDetect2Config() config = config or BatDetect2Config()

View File

@ -7,7 +7,7 @@ from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.tasks import build_task from batdetect2.evaluate.tasks import build_task
from batdetect2.targets import build_targets from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, TargetProtocol from batdetect2.typing import EvaluatorProtocol, TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction from batdetect2.typing.postprocess import ClipDetections
__all__ = [ __all__ = [
"Evaluator", "Evaluator",
@ -27,7 +27,7 @@ class Evaluator:
def evaluate( def evaluate(
self, self,
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[ClipDetections],
) -> List[Any]: ) -> List[Any]:
return [ return [
task.evaluate(clip_annotations, predictions) for task in self.tasks task.evaluate(clip_annotations, predictions) for task in self.tasks

View File

@ -9,7 +9,7 @@ from batdetect2.logging import get_image_logger
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import EvaluatorProtocol from batdetect2.typing import EvaluatorProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction from batdetect2.typing.postprocess import ClipDetections
class EvaluationModule(LightningModule): class EvaluationModule(LightningModule):
@ -24,7 +24,7 @@ class EvaluationModule(LightningModule):
self.evaluator = evaluator self.evaluator = evaluator
self.clip_annotations: List[data.ClipAnnotation] = [] self.clip_annotations: List[data.ClipAnnotation] = []
self.predictions: List[BatDetect2Prediction] = [] self.predictions: List[ClipDetections] = []
def test_step(self, batch: TestExample, batch_idx: int): def test_step(self, batch: TestExample, batch_idx: int):
dataset = self.get_dataset() dataset = self.get_dataset()
@ -39,9 +39,9 @@ class EvaluationModule(LightningModule):
start_times=[ca.clip.start_time for ca in clip_annotations], start_times=[ca.clip.start_time for ca in clip_annotations],
) )
predictions = [ predictions = [
BatDetect2Prediction( ClipDetections(
clip=clip_annotation.clip, clip=clip_annotation.clip,
predictions=to_raw_predictions( detections=to_raw_predictions(
clip_dets.numpy(), clip_dets.numpy(),
targets=self.evaluator.targets, targets=self.evaluator.targets,
), ),

View File

@ -21,7 +21,7 @@ from batdetect2.evaluate.metrics.common import (
average_precision, average_precision,
compute_precision_recall, compute_precision_recall,
) )
from batdetect2.typing import RawPrediction, TargetProtocol from batdetect2.typing import Detection, TargetProtocol
__all__ = [ __all__ = [
"ClassificationMetric", "ClassificationMetric",
@ -35,7 +35,7 @@ __all__ = [
class MatchEval: class MatchEval:
clip: data.Clip clip: data.Clip
gt: data.SoundEventAnnotation | None gt: data.SoundEventAnnotation | None
pred: RawPrediction | None pred: Detection | None
is_prediction: bool is_prediction: bool
is_ground_truth: bool is_ground_truth: bool

View File

@ -159,7 +159,10 @@ class ClipDetectionPrecision:
ClipDetectionMetricConfig = Annotated[ ClipDetectionMetricConfig = Annotated[
ClipDetectionAveragePrecisionConfig | ClipDetectionROCAUCConfig | ClipDetectionRecallConfig | ClipDetectionPrecisionConfig, ClipDetectionAveragePrecisionConfig
| ClipDetectionROCAUCConfig
| ClipDetectionRecallConfig
| ClipDetectionPrecisionConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -15,7 +15,7 @@ from soundevent import data
from batdetect2.core import BaseConfig, Registry from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction from batdetect2.typing import Detection
__all__ = [ __all__ = [
"DetectionMetricConfig", "DetectionMetricConfig",
@ -27,7 +27,7 @@ __all__ = [
@dataclass @dataclass
class MatchEval: class MatchEval:
gt: data.SoundEventAnnotation | None gt: data.SoundEventAnnotation | None
pred: RawPrediction | None pred: Detection | None
is_prediction: bool is_prediction: bool
is_ground_truth: bool is_ground_truth: bool

View File

@ -15,7 +15,7 @@ from soundevent import data
from batdetect2.core import BaseConfig, Registry from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction from batdetect2.typing import Detection
from batdetect2.typing.targets import TargetProtocol from batdetect2.typing.targets import TargetProtocol
__all__ = [ __all__ = [
@ -29,7 +29,7 @@ __all__ = [
class MatchEval: class MatchEval:
clip: data.Clip clip: data.Clip
gt: data.SoundEventAnnotation | None gt: data.SoundEventAnnotation | None
pred: RawPrediction | None pred: Detection | None
is_ground_truth: bool is_ground_truth: bool
is_generic: bool is_generic: bool
@ -298,7 +298,11 @@ class BalancedAccuracy:
TopClassMetricConfig = Annotated[ TopClassMetricConfig = Annotated[
TopClassAveragePrecisionConfig | TopClassROCAUCConfig | TopClassRecallConfig | TopClassPrecisionConfig | BalancedAccuracyConfig, TopClassAveragePrecisionConfig
| TopClassROCAUCConfig
| TopClassRecallConfig
| TopClassPrecisionConfig
| BalancedAccuracyConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -1,4 +1,3 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.figure import Figure from matplotlib.figure import Figure

View File

@ -322,7 +322,10 @@ class ROCCurve(BasePlot):
ClassificationPlotConfig = Annotated[ ClassificationPlotConfig = Annotated[
PRCurveConfig | ROCCurveConfig | ThresholdPrecisionCurveConfig | ThresholdRecallCurveConfig, PRCurveConfig
| ROCCurveConfig
| ThresholdPrecisionCurveConfig
| ThresholdRecallCurveConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -290,7 +290,10 @@ class ExampleDetectionPlot(BasePlot):
DetectionPlotConfig = Annotated[ DetectionPlotConfig = Annotated[
PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig | ExampleDetectionPlotConfig, PRCurveConfig
| ROCCurveConfig
| ScoreDistributionPlotConfig
| ExampleDetectionPlotConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -346,7 +346,10 @@ class ExampleClassificationPlot(BasePlot):
TopClassPlotConfig = Annotated[ TopClassPlotConfig = Annotated[
PRCurveConfig | ROCCurveConfig | ConfusionMatrixConfig | ExampleClassificationPlotConfig, PRCurveConfig
| ROCCurveConfig
| ConfusionMatrixConfig
| ExampleClassificationPlotConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]
@ -403,7 +406,8 @@ def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
return matches return matches
indices, pred_scores = zip( indices, pred_scores = zip(
*[(index, match.score) for index, match in enumerate(matches)], strict=False *[(index, match.score) for index, match in enumerate(matches)],
strict=False,
) )
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop") bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")

View File

@ -13,7 +13,7 @@ from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
from batdetect2.targets import build_targets from batdetect2.targets import build_targets
from batdetect2.typing import ( from batdetect2.typing import (
BatDetect2Prediction, ClipDetections,
EvaluatorProtocol, EvaluatorProtocol,
TargetProtocol, TargetProtocol,
) )
@ -45,7 +45,7 @@ def build_task(
def evaluate_task( def evaluate_task(
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[ClipDetections],
task: Optional["str"] = None, task: Optional["str"] = None,
targets: TargetProtocol | None = None, targets: TargetProtocol | None = None,
config: TaskConfig | dict | None = None, config: TaskConfig | dict | None = None,

View File

@ -22,9 +22,9 @@ from batdetect2.evaluate.affinity import (
) )
from batdetect2.typing import ( from batdetect2.typing import (
AffinityFunction, AffinityFunction,
BatDetect2Prediction, ClipDetections,
EvaluatorProtocol, EvaluatorProtocol,
RawPrediction, Detection,
TargetProtocol, TargetProtocol,
) )
@ -96,7 +96,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
def evaluate( def evaluate(
self, self,
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[ClipDetections],
) -> List[T_Output]: ) -> List[T_Output]:
return [ return [
self.evaluate_clip(clip_annotation, preds) self.evaluate_clip(clip_annotation, preds)
@ -108,7 +108,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
def evaluate_clip( def evaluate_clip(
self, self,
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction, prediction: ClipDetections,
) -> T_Output: ... ) -> T_Output: ...
def include_sound_event_annotation( def include_sound_event_annotation(
@ -128,7 +128,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
def include_prediction( def include_prediction(
self, self,
prediction: RawPrediction, prediction: Detection,
clip: data.Clip, clip: data.Clip,
) -> bool: ) -> bool:
return is_in_bounds( return is_in_bounds(

View File

@ -22,8 +22,8 @@ from batdetect2.evaluate.tasks.base import (
tasks_registry, tasks_registry,
) )
from batdetect2.typing import ( from batdetect2.typing import (
BatDetect2Prediction, ClipDetections,
RawPrediction, Detection,
TargetProtocol, TargetProtocol,
) )
@ -51,13 +51,13 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
def evaluate_clip( def evaluate_clip(
self, self,
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction, prediction: ClipDetections,
) -> ClipEval: ) -> ClipEval:
clip = clip_annotation.clip clip = clip_annotation.clip
preds = [ preds = [
pred pred
for pred in prediction.predictions for pred in prediction.detections
if self.include_prediction(pred, clip) if self.include_prediction(pred, clip)
] ]
@ -140,7 +140,7 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
) )
def get_class_score(pred: RawPrediction, class_idx: int) -> float: def get_class_score(pred: Detection, class_idx: int) -> float:
return pred.class_scores[class_idx] return pred.class_scores[class_idx]

View File

@ -19,7 +19,7 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig, BaseTaskConfig,
tasks_registry, tasks_registry,
) )
from batdetect2.typing import BatDetect2Prediction, TargetProtocol from batdetect2.typing import ClipDetections, TargetProtocol
class ClipClassificationTaskConfig(BaseTaskConfig): class ClipClassificationTaskConfig(BaseTaskConfig):
@ -37,7 +37,7 @@ class ClipClassificationTask(BaseTask[ClipEval]):
def evaluate_clip( def evaluate_clip(
self, self,
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction, prediction: ClipDetections,
) -> ClipEval: ) -> ClipEval:
clip = clip_annotation.clip clip = clip_annotation.clip
@ -54,7 +54,7 @@ class ClipClassificationTask(BaseTask[ClipEval]):
gt_classes.add(class_name) gt_classes.add(class_name)
pred_scores = defaultdict(float) pred_scores = defaultdict(float)
for pred in prediction.predictions: for pred in prediction.detections:
if not self.include_prediction(pred, clip): if not self.include_prediction(pred, clip):
continue continue

View File

@ -18,7 +18,7 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig, BaseTaskConfig,
tasks_registry, tasks_registry,
) )
from batdetect2.typing import BatDetect2Prediction, TargetProtocol from batdetect2.typing import ClipDetections, TargetProtocol
class ClipDetectionTaskConfig(BaseTaskConfig): class ClipDetectionTaskConfig(BaseTaskConfig):
@ -36,7 +36,7 @@ class ClipDetectionTask(BaseTask[ClipEval]):
def evaluate_clip( def evaluate_clip(
self, self,
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction, prediction: ClipDetections,
) -> ClipEval: ) -> ClipEval:
clip = clip_annotation.clip clip = clip_annotation.clip
@ -46,7 +46,7 @@ class ClipDetectionTask(BaseTask[ClipEval]):
) )
pred_score = 0 pred_score = 0
for pred in prediction.predictions: for pred in prediction.detections:
if not self.include_prediction(pred, clip): if not self.include_prediction(pred, clip):
continue continue

View File

@ -21,7 +21,7 @@ from batdetect2.evaluate.tasks.base import (
tasks_registry, tasks_registry,
) )
from batdetect2.typing import TargetProtocol from batdetect2.typing import TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction from batdetect2.typing.postprocess import ClipDetections
class DetectionTaskConfig(BaseSEDTaskConfig): class DetectionTaskConfig(BaseSEDTaskConfig):
@ -37,7 +37,7 @@ class DetectionTask(BaseSEDTask[ClipEval]):
def evaluate_clip( def evaluate_clip(
self, self,
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction, prediction: ClipDetections,
) -> ClipEval: ) -> ClipEval:
clip = clip_annotation.clip clip = clip_annotation.clip
@ -48,7 +48,7 @@ class DetectionTask(BaseSEDTask[ClipEval]):
] ]
preds = [ preds = [
pred pred
for pred in prediction.predictions for pred in prediction.detections
if self.include_prediction(pred, clip) if self.include_prediction(pred, clip)
] ]

View File

@ -20,7 +20,7 @@ from batdetect2.evaluate.tasks.base import (
BaseSEDTaskConfig, BaseSEDTaskConfig,
tasks_registry, tasks_registry,
) )
from batdetect2.typing import BatDetect2Prediction, TargetProtocol from batdetect2.typing import ClipDetections, TargetProtocol
class TopClassDetectionTaskConfig(BaseSEDTaskConfig): class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
@ -36,7 +36,7 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
def evaluate_clip( def evaluate_clip(
self, self,
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction, prediction: ClipDetections,
) -> ClipEval: ) -> ClipEval:
clip = clip_annotation.clip clip = clip_annotation.clip
@ -47,7 +47,7 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
] ]
preds = [ preds = [
pred pred
for pred in prediction.predictions for pred in prediction.detections
if self.include_prediction(pred, clip) if self.include_prediction(pred, clip)
] ]

View File

@ -88,7 +88,8 @@ def select_device(warn=True) -> str:
if warn: if warn:
warnings.warn( warnings.warn(
"No GPU available, using the CPU instead. Please consider using a GPU " "No GPU available, using the CPU instead. Please consider using a GPU "
"to speed up training.", stacklevel=2 "to speed up training.",
stacklevel=2,
) )
return "cpu" return "cpu"

View File

@ -10,7 +10,7 @@ from batdetect2.inference.lightning import InferenceModule
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.preprocess.preprocessor import build_preprocessor from batdetect2.preprocess.preprocessor import build_preprocessor
from batdetect2.targets.targets import build_targets from batdetect2.targets.targets import build_targets
from batdetect2.typing.postprocess import BatDetect2Prediction from batdetect2.typing.postprocess import ClipDetections
if TYPE_CHECKING: if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config
@ -30,7 +30,7 @@ def run_batch_inference(
config: Optional["BatDetect2Config"] = None, config: Optional["BatDetect2Config"] = None,
num_workers: int | None = None, num_workers: int | None = None,
batch_size: int | None = None, batch_size: int | None = None,
) -> List[BatDetect2Prediction]: ) -> List[ClipDetections]:
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config
config = config or BatDetect2Config() config = config or BatDetect2Config()
@ -70,7 +70,7 @@ def process_file_list(
audio_loader: Optional["AudioLoader"] = None, audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None, preprocessor: Optional["PreprocessorProtocol"] = None,
num_workers: int | None = None, num_workers: int | None = None,
) -> List[BatDetect2Prediction]: ) -> List[ClipDetections]:
clip_config = config.inference.clipping clip_config = config.inference.clipping
clips = get_clips_from_files( clips = get_clips_from_files(
paths, paths,

View File

@ -6,7 +6,7 @@ from torch.utils.data import DataLoader
from batdetect2.inference.dataset import DatasetItem, InferenceDataset from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing.postprocess import BatDetect2Prediction from batdetect2.typing.postprocess import ClipDetections
class InferenceModule(LightningModule): class InferenceModule(LightningModule):
@ -19,7 +19,7 @@ class InferenceModule(LightningModule):
batch: DatasetItem, batch: DatasetItem,
batch_idx: int, batch_idx: int,
dataloader_idx: int = 0, dataloader_idx: int = 0,
) -> Sequence[BatDetect2Prediction]: ) -> Sequence[ClipDetections]:
dataset = self.get_dataset() dataset = self.get_dataset()
clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx] clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx]
@ -32,9 +32,9 @@ class InferenceModule(LightningModule):
) )
predictions = [ predictions = [
BatDetect2Prediction( ClipDetections(
clip=clip, clip=clip,
predictions=to_raw_predictions( detections=to_raw_predictions(
clip_dets.numpy(), clip_dets.numpy(),
targets=self.model.targets, targets=self.model.targets,
), ),

View File

@ -76,7 +76,10 @@ class MLFlowLoggerConfig(BaseLoggerConfig):
LoggerConfig = Annotated[ LoggerConfig = Annotated[
DVCLiveConfig | CSVLoggerConfig | TensorBoardLoggerConfig | MLFlowLoggerConfig, DVCLiveConfig
| CSVLoggerConfig
| TensorBoardLoggerConfig
| MLFlowLoggerConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -1,4 +1,3 @@
from soundevent import data from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config

View File

@ -41,7 +41,10 @@ __all__ = [
] ]
DecoderLayerConfig = Annotated[ DecoderLayerConfig = Annotated[
ConvConfig | FreqCoordConvUpConfig | StandardConvUpConfig | LayerGroupConfig, ConvConfig
| FreqCoordConvUpConfig
| StandardConvUpConfig
| LayerGroupConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]
"""Type alias for the discriminated union of block configs usable in Decoder.""" """Type alias for the discriminated union of block configs usable in Decoder."""

View File

@ -14,7 +14,6 @@ logic for preprocessing inputs and postprocessing/decoding outputs resides in
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively. the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
""" """
import torch import torch
from loguru import logger from loguru import logger

View File

@ -43,7 +43,10 @@ __all__ = [
] ]
EncoderLayerConfig = Annotated[ EncoderLayerConfig = Annotated[
ConvConfig | FreqCoordConvDownConfig | StandardConvDownConfig | LayerGroupConfig, ConvConfig
| FreqCoordConvDownConfig
| StandardConvDownConfig
| LayerGroupConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]
"""Type alias for the discriminated union of block configs usable in Encoder.""" """Type alias for the discriminated union of block configs usable in Encoder."""

View File

@ -1,4 +1,3 @@
from matplotlib import axes, patches from matplotlib import axes, patches
from soundevent.plot import plot_geometry from soundevent.plot import plot_geometry

View File

@ -36,7 +36,9 @@ def plot_match_gallery(
sharey="row", sharey="row",
) )
for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples], strict=False): for tp_ax, tp_match in zip(
axes[0], true_positives[:n_examples], strict=False
):
try: try:
plot_true_positive_match( plot_true_positive_match(
tp_match, tp_match,
@ -53,7 +55,9 @@ def plot_match_gallery(
): ):
continue continue
for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples], strict=False): for fp_ax, fp_match in zip(
axes[1], false_positives[:n_examples], strict=False
):
try: try:
plot_false_positive_match( plot_false_positive_match(
fp_match, fp_match,
@ -70,7 +74,9 @@ def plot_match_gallery(
): ):
continue continue
for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples], strict=False): for fn_ax, fn_match in zip(
axes[2], false_negatives[:n_examples], strict=False
):
try: try:
plot_false_negative_match( plot_false_negative_match(
fn_match, fn_match,
@ -87,7 +93,9 @@ def plot_match_gallery(
): ):
continue continue
for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples], strict=False): for ct_ax, ct_match in zip(
axes[3], cross_triggers[:n_examples], strict=False
):
try: try:
plot_cross_trigger_match( plot_cross_trigger_match(
ct_match, ct_match,

View File

@ -8,7 +8,7 @@ from batdetect2.plotting.clips import plot_clip
from batdetect2.typing import ( from batdetect2.typing import (
AudioLoader, AudioLoader,
PreprocessorProtocol, PreprocessorProtocol,
RawPrediction, Detection,
) )
__all__ = [ __all__ = [
@ -22,7 +22,7 @@ __all__ = [
class MatchProtocol(Protocol): class MatchProtocol(Protocol):
clip: data.Clip clip: data.Clip
gt: data.SoundEventAnnotation | None gt: data.SoundEventAnnotation | None
pred: RawPrediction | None pred: Detection | None
score: float score: float
true_class: str | None true_class: str | None
@ -341,4 +341,3 @@ def plot_cross_trigger_match(
ax.set_title("Cross Trigger") ax.set_title("Cross Trigger")
return ax return ax

View File

@ -0,0 +1,6 @@
from batdetect2.typing import ClipDetections
class ClipTransform:
def __init__(self, clip: ClipDetections):
pass

View File

@ -1,4 +1,3 @@
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data

View File

@ -7,7 +7,7 @@ from soundevent import data
from batdetect2.typing.postprocess import ( from batdetect2.typing.postprocess import (
ClipDetectionsArray, ClipDetectionsArray,
RawPrediction, Detection,
) )
from batdetect2.typing.targets import TargetProtocol from batdetect2.typing.targets import TargetProtocol
@ -30,7 +30,7 @@ decoding.
def to_raw_predictions( def to_raw_predictions(
detections: ClipDetectionsArray, detections: ClipDetectionsArray,
targets: TargetProtocol, targets: TargetProtocol,
) -> List[RawPrediction]: ) -> List[Detection]:
predictions = [] predictions = []
for score, class_scores, time, freq, dims, feats in zip( for score, class_scores, time, freq, dims, feats in zip(
@ -39,7 +39,8 @@ def to_raw_predictions(
detections.times, detections.times,
detections.frequencies, detections.frequencies,
detections.sizes, detections.sizes,
detections.features, strict=False, detections.features,
strict=False,
): ):
highest_scoring_class = targets.class_names[class_scores.argmax()] highest_scoring_class = targets.class_names[class_scores.argmax()]
@ -50,7 +51,7 @@ def to_raw_predictions(
) )
predictions.append( predictions.append(
RawPrediction( Detection(
detection_score=score, detection_score=score,
geometry=geom, geometry=geom,
class_scores=class_scores, class_scores=class_scores,
@ -62,7 +63,7 @@ def to_raw_predictions(
def convert_raw_predictions_to_clip_prediction( def convert_raw_predictions_to_clip_prediction(
raw_predictions: List[RawPrediction], raw_predictions: List[Detection],
clip: data.Clip, clip: data.Clip,
targets: TargetProtocol, targets: TargetProtocol,
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD, classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
@ -85,7 +86,7 @@ def convert_raw_predictions_to_clip_prediction(
def convert_raw_prediction_to_sound_event_prediction( def convert_raw_prediction_to_sound_event_prediction(
raw_prediction: RawPrediction, raw_prediction: Detection,
recording: data.Recording, recording: data.Recording,
targets: TargetProtocol, targets: TargetProtocol,
classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD, classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,

View File

@ -1,4 +1,3 @@
import torch import torch
from loguru import logger from loguru import logger

View File

@ -377,7 +377,10 @@ class PeakNormalize(torch.nn.Module):
SpectrogramTransform = Annotated[ SpectrogramTransform = Annotated[
PcenConfig | ScaleAmplitudeConfig | SpectralMeanSubstractionConfig | PeakNormalizeConfig, PcenConfig
| ScaleAmplitudeConfig
| SpectralMeanSubstractionConfig
| PeakNormalizeConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -394,7 +394,8 @@ class MaskTime(torch.nn.Module):
size=num_masks, size=num_masks,
) )
masks = [ masks = [
(start, start + size) for start, size in zip(mask_start, mask_size, strict=False) (start, start + size)
for start, size in zip(mask_start, mask_size, strict=False)
] ]
return mask_time(spec, masks), clip_annotation return mask_time(spec, masks), clip_annotation
@ -460,7 +461,8 @@ class MaskFrequency(torch.nn.Module):
size=num_masks, size=num_masks,
) )
masks = [ masks = [
(start, start + size) for start, size in zip(mask_start, mask_size, strict=False) (start, start + size)
for start, size in zip(mask_start, mask_size, strict=False)
] ]
return mask_frequency(spec, masks), clip_annotation return mask_frequency(spec, masks), clip_annotation
@ -498,7 +500,12 @@ SpectrogramAugmentationConfig = Annotated[
] ]
AugmentationConfig = Annotated[ AugmentationConfig = Annotated[
MixAudioConfig | AddEchoConfig | ScaleVolumeConfig | WarpConfig | MaskFrequencyConfig | MaskTimeConfig, MixAudioConfig
| AddEchoConfig
| ScaleVolumeConfig
| WarpConfig
| MaskFrequencyConfig
| MaskTimeConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]
"""Type alias for the discriminated union of individual augmentation config.""" """Type alias for the discriminated union of individual augmentation config."""

View File

@ -10,7 +10,7 @@ from batdetect2.postprocess import to_raw_predictions
from batdetect2.train.dataset import ValidationDataset from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
from batdetect2.typing import ( from batdetect2.typing import (
BatDetect2Prediction, ClipDetections,
EvaluatorProtocol, EvaluatorProtocol,
ModelOutput, ModelOutput,
TrainExample, TrainExample,
@ -24,7 +24,7 @@ class ValidationMetrics(Callback):
self.evaluator = evaluator self.evaluator = evaluator
self._clip_annotations: List[data.ClipAnnotation] = [] self._clip_annotations: List[data.ClipAnnotation] = []
self._predictions: List[BatDetect2Prediction] = [] self._predictions: List[ClipDetections] = []
def get_dataset(self, trainer: Trainer) -> ValidationDataset: def get_dataset(self, trainer: Trainer) -> ValidationDataset:
dataloaders = trainer.val_dataloaders dataloaders = trainer.val_dataloaders
@ -100,9 +100,9 @@ class ValidationMetrics(Callback):
start_times=[ca.clip.start_time for ca in clip_annotations], start_times=[ca.clip.start_time for ca in clip_annotations],
) )
predictions = [ predictions = [
BatDetect2Prediction( ClipDetections(
clip=clip_annotation.clip, clip=clip_annotation.clip,
predictions=to_raw_predictions( detections=to_raw_predictions(
clip_dets.numpy(), targets=model.targets clip_dets.numpy(), targets=model.targets
), ),
) )

View File

@ -1,4 +1,3 @@
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data

View File

@ -18,7 +18,6 @@ The primary entry points are:
- `LossConfig`: The Pydantic model for configuring loss weights and parameters. - `LossConfig`: The Pydantic model for configuring loss weights and parameters.
""" """
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F

View File

@ -10,11 +10,11 @@ from batdetect2.typing.evaluate import (
) )
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
from batdetect2.typing.postprocess import ( from batdetect2.typing.postprocess import (
BatDetect2Prediction, ClipDetections,
ClipDetectionsTensor, ClipDetectionsTensor,
GeometryDecoder, GeometryDecoder,
PostprocessorProtocol, PostprocessorProtocol,
RawPrediction, Detection,
) )
from batdetect2.typing.preprocess import ( from batdetect2.typing.preprocess import (
AudioLoader, AudioLoader,
@ -44,7 +44,7 @@ __all__ = [
"AudioLoader", "AudioLoader",
"Augmentation", "Augmentation",
"BackboneModel", "BackboneModel",
"BatDetect2Prediction", "ClipDetections",
"ClipDetectionsTensor", "ClipDetectionsTensor",
"ClipLabeller", "ClipLabeller",
"ClipMatches", "ClipMatches",
@ -65,7 +65,7 @@ __all__ = [
"PostprocessorProtocol", "PostprocessorProtocol",
"PreprocessorProtocol", "PreprocessorProtocol",
"ROITargetMapper", "ROITargetMapper",
"RawPrediction", "Detection",
"Size", "Size",
"SoundEventDecoder", "SoundEventDecoder",
"SoundEventEncoder", "SoundEventEncoder",

View File

@ -2,7 +2,7 @@ from typing import Generic, List, Protocol, Sequence, TypeVar
from soundevent.data import PathLike from soundevent.data import PathLike
from batdetect2.typing.postprocess import BatDetect2Prediction from batdetect2.typing.postprocess import ClipDetections
__all__ = [ __all__ = [
"OutputFormatterProtocol", "OutputFormatterProtocol",
@ -12,9 +12,7 @@ T = TypeVar("T")
class OutputFormatterProtocol(Protocol, Generic[T]): class OutputFormatterProtocol(Protocol, Generic[T]):
def format( def format(self, predictions: Sequence[ClipDetections]) -> List[T]: ...
self, predictions: Sequence[BatDetect2Prediction]
) -> List[T]: ...
def save( def save(
self, self,

View File

@ -13,7 +13,7 @@ from typing import (
from matplotlib.figure import Figure from matplotlib.figure import Figure
from soundevent import data from soundevent import data
from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction from batdetect2.typing.postprocess import ClipDetections, Detection
from batdetect2.typing.targets import TargetProtocol from batdetect2.typing.targets import TargetProtocol
__all__ = [ __all__ = [
@ -84,7 +84,7 @@ Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
class AffinityFunction(Protocol): class AffinityFunction(Protocol):
def __call__( def __call__(
self, self,
detection: RawPrediction, detection: Detection,
ground_truth: data.SoundEventAnnotation, ground_truth: data.SoundEventAnnotation,
) -> float: ... ) -> float: ...
@ -93,7 +93,7 @@ class MetricsProtocol(Protocol):
def __call__( def __call__(
self, self,
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]], predictions: Sequence[Sequence[Detection]],
) -> Dict[str, float]: ... ) -> Dict[str, float]: ...
@ -101,7 +101,7 @@ class PlotterProtocol(Protocol):
def __call__( def __call__(
self, self,
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]], predictions: Sequence[Sequence[Detection]],
) -> Iterable[Tuple[str, Figure]]: ... ) -> Iterable[Tuple[str, Figure]]: ...
@ -114,7 +114,7 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
def evaluate( def evaluate(
self, self,
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[ClipDetections],
) -> EvaluationOutput: ... ) -> EvaluationOutput: ...
def compute_metrics( def compute_metrics(

View File

@ -22,7 +22,7 @@ from batdetect2.typing.models import ModelOutput
from batdetect2.typing.targets import Position, Size from batdetect2.typing.targets import Position, Size
__all__ = [ __all__ = [
"RawPrediction", "Detection",
"PostprocessorProtocol", "PostprocessorProtocol",
"GeometryDecoder", "GeometryDecoder",
] ]
@ -46,7 +46,8 @@ class GeometryDecoder(Protocol):
) -> data.Geometry: ... ) -> data.Geometry: ...
class RawPrediction(NamedTuple): @dataclass
class Detection:
geometry: data.Geometry geometry: data.Geometry
detection_score: float detection_score: float
class_scores: np.ndarray class_scores: np.ndarray
@ -82,9 +83,16 @@ class ClipDetectionsTensor(NamedTuple):
@dataclass @dataclass
class BatDetect2Prediction: class ClipDetections:
clip: data.Clip clip: data.Clip
predictions: List[RawPrediction] detections: List[Detection]
@dataclass
class ClipPrediction:
clip: data.Clip
detection_score: float
class_scores: np.ndarray
class PostprocessorProtocol(Protocol): class PostprocessorProtocol(Protocol):

View File

@ -220,7 +220,8 @@ def get_annotations_from_preds(
predictions["high_freqs"], predictions["high_freqs"],
class_ind_best, class_ind_best,
class_prob_best, class_prob_best,
predictions["det_probs"], strict=False, predictions["det_probs"],
strict=False,
) )
] ]
return annotations return annotations

View File

@ -87,9 +87,7 @@ def save_ann_spec(
y_extent = [0, duration, min_freq, max_freq] y_extent = [0, duration, min_freq, max_freq]
plt.close("all") plt.close("all")
plt.figure( plt.figure(0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100)
0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100
)
plt.imshow( plt.imshow(
spec, spec,
aspect="auto", aspect="auto",

View File

@ -1,4 +1,3 @@
import numpy as np import numpy as np
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F

View File

@ -10,8 +10,8 @@ from batdetect2.data.predictions import (
build_output_formatter, build_output_formatter,
) )
from batdetect2.typing import ( from batdetect2.typing import (
BatDetect2Prediction, ClipDetections,
RawPrediction, Detection,
TargetProtocol, TargetProtocol,
) )
@ -31,7 +31,7 @@ def test_roundtrip(
tmp_path: Path, tmp_path: Path,
): ):
detections = [ detections = [
RawPrediction( Detection(
geometry=data.BoundingBox( geometry=data.BoundingBox(
coordinates=list(np.random.uniform(size=[4])) coordinates=list(np.random.uniform(size=[4]))
), ),
@ -44,7 +44,7 @@ def test_roundtrip(
for _ in range(10) for _ in range(10)
] ]
prediction = BatDetect2Prediction(clip=clip, predictions=detections) prediction = ClipDetections(clip=clip, detections=detections)
path = tmp_path / "predictions.parquet" path = tmp_path / "predictions.parquet"
@ -58,7 +58,7 @@ def test_roundtrip(
assert recovered[0].clip == prediction.clip assert recovered[0].clip == prediction.clip
for recovered_prediction, detection in zip( for recovered_prediction, detection in zip(
recovered[0].predictions, detections recovered[0].detections, detections, strict=True
): ):
assert ( assert (
recovered_prediction.detection_score == detection.detection_score recovered_prediction.detection_score == detection.detection_score
@ -79,9 +79,9 @@ def test_multiple_clips(
): ):
# Create a second clip # Create a second clip
clip2 = clip.model_copy(update={"uuid": uuid4()}) clip2 = clip.model_copy(update={"uuid": uuid4()})
detections1 = [ detections1 = [
RawPrediction( Detection(
geometry=data.BoundingBox( geometry=data.BoundingBox(
coordinates=list(np.random.uniform(size=[4])) coordinates=list(np.random.uniform(size=[4]))
), ),
@ -92,9 +92,9 @@ def test_multiple_clips(
features=np.random.uniform(size=32), features=np.random.uniform(size=32),
) )
] ]
detections2 = [ detections2 = [
RawPrediction( Detection(
geometry=data.BoundingBox( geometry=data.BoundingBox(
coordinates=list(np.random.uniform(size=[4])) coordinates=list(np.random.uniform(size=[4]))
), ),
@ -107,19 +107,19 @@ def test_multiple_clips(
] ]
predictions = [ predictions = [
BatDetect2Prediction(clip=clip, predictions=detections1), ClipDetections(clip=clip, detections=detections1),
BatDetect2Prediction(clip=clip2, predictions=detections2), ClipDetections(clip=clip2, detections=detections2),
] ]
path = tmp_path / "multi_predictions.parquet" path = tmp_path / "multi_predictions.parquet"
sample_formatter.save(predictions=predictions, path=path) sample_formatter.save(predictions=predictions, path=path)
recovered = sample_formatter.load(path=path) recovered = sample_formatter.load(path=path)
assert len(recovered) == 2 assert len(recovered) == 2
# Order might not be preserved if we don't sort, but implementation appends so it should be # Order might not be preserved if we don't sort, but implementation appends so it should be
# However, let's sort by clip uuid to be safe if needed, or just check existence # However, let's sort by clip uuid to be safe if needed, or just check existence
recovered_uuids = {p.clip.uuid for p in recovered} recovered_uuids = {p.clip.uuid for p in recovered}
expected_uuids = {clip.uuid, clip2.uuid} expected_uuids = {clip.uuid, clip2.uuid}
assert recovered_uuids == expected_uuids assert recovered_uuids == expected_uuids
@ -133,16 +133,18 @@ def test_complex_geometry(
): ):
# Create a polygon geometry # Create a polygon geometry
polygon = data.Polygon( polygon = data.Polygon(
coordinates=[[ coordinates=[
[0.0, 10000.0], [
[0.1, 20000.0], [0.0, 10000.0],
[0.2, 10000.0], [0.1, 20000.0],
[0.0, 10000.0], [0.2, 10000.0],
]] [0.0, 10000.0],
]
]
) )
detections = [ detections = [
RawPrediction( Detection(
geometry=polygon, geometry=polygon,
detection_score=0.95, detection_score=0.95,
class_scores=np.random.uniform( class_scores=np.random.uniform(
@ -152,18 +154,18 @@ def test_complex_geometry(
) )
] ]
prediction = BatDetect2Prediction(clip=clip, predictions=detections) prediction = ClipDetections(clip=clip, detections=detections)
path = tmp_path / "complex_geometry.parquet" path = tmp_path / "complex_geometry.parquet"
sample_formatter.save(predictions=[prediction], path=path) sample_formatter.save(predictions=[prediction], path=path)
recovered = sample_formatter.load(path=path) recovered = sample_formatter.load(path=path)
assert len(recovered) == 1 assert len(recovered) == 1
assert len(recovered[0].predictions) == 1 assert len(recovered[0].detections) == 1
recovered_pred = recovered[0].predictions[0] recovered_pred = recovered[0].detections[0]
# Check if geometry is recovered correctly as a Polygon # Check if geometry is recovered correctly as a Polygon
assert isinstance(recovered_pred.geometry, data.Polygon) assert isinstance(recovered_pred.geometry, data.Polygon)
assert recovered_pred.geometry == polygon assert recovered_pred.geometry == polygon

View File

@ -6,8 +6,8 @@ from soundevent import data
from batdetect2.data.predictions import RawOutputConfig, build_output_formatter from batdetect2.data.predictions import RawOutputConfig, build_output_formatter
from batdetect2.typing import ( from batdetect2.typing import (
BatDetect2Prediction, ClipDetections,
RawPrediction, Detection,
TargetProtocol, TargetProtocol,
) )
@ -27,7 +27,7 @@ def test_roundtrip(
tmp_path: Path, tmp_path: Path,
): ):
detections = [ detections = [
RawPrediction( Detection(
geometry=data.BoundingBox( geometry=data.BoundingBox(
coordinates=list(np.random.uniform(size=[4])) coordinates=list(np.random.uniform(size=[4]))
), ),
@ -40,7 +40,7 @@ def test_roundtrip(
for _ in range(10) for _ in range(10)
] ]
prediction = BatDetect2Prediction(clip=clip, predictions=detections) prediction = ClipDetections(clip=clip, detections=detections)
path = tmp_path / "predictions" path = tmp_path / "predictions"
@ -52,7 +52,9 @@ def test_roundtrip(
assert recovered[0].clip == prediction.clip assert recovered[0].clip == prediction.clip
for recovered_prediction, detection in zip( for recovered_prediction, detection in zip(
recovered[0].predictions, detections recovered[0].detections,
detections,
strict=True,
): ):
assert ( assert (
recovered_prediction.detection_score == detection.detection_score recovered_prediction.detection_score == detection.detection_score

View File

View File

@ -0,0 +1,86 @@
from typing import Literal
import numpy as np
import pytest
from soundevent import data
from batdetect2.typing import Detection
@pytest.fixture
def clip(recording: data.Recording) -> data.Clip:
    """Provide a clip covering the first 100 seconds of the test recording."""
    return data.Clip(
        recording=recording,
        start_time=0,
        end_time=100,
    )
@pytest.fixture
def create_detection():
    """Return a factory that builds `Detection` objects with sensible defaults.

    The factory produces a detection whose bounding box starts at
    ``start_time``/``low_freq`` and spans ``duration``/``bandwidth``. The two
    entries of ``class_scores`` are, in order, the pippip and myomyo scores.
    """

    def factory(
        detection_score: float = 0.5,
        start_time: float = 0.1,
        duration: float = 0.01,
        low_freq: float = 40_000,
        bandwidth: float = 30_000,
        pip_score: float = 0,
        myo_score: float = 0,
    ):
        # Derive the box corners once instead of inlining the arithmetic.
        end_time = start_time + duration
        high_freq = low_freq + bandwidth
        box = data.BoundingBox(
            coordinates=[start_time, low_freq, end_time, high_freq],
        )
        return Detection(
            detection_score=detection_score,
            class_scores=np.array([pip_score, myo_score]),
            features=np.zeros([32]),
            geometry=box,
        )

    return factory
@pytest.fixture
def create_annotation(
    clip: data.Clip,
    bat_tag: data.Tag,
    myomyo_tag: data.Tag,
    pippip_tag: data.Tag,
):
    """Return a factory that builds ground-truth sound event annotations.

    Target events (``is_target=True``) carry the generic bat tag; an optional
    ``class_name`` additionally attaches the matching species tag.
    """

    def factory(
        start_time: float = 0.1,
        duration: float = 0.01,
        low_freq: float = 40_000,
        bandwidth: float = 30_000,
        is_target: bool = True,
        class_name: Literal["pippip", "myomyo"] | None = None,
    ):
        # Non-target events get no tags; the class tag (if any) is appended
        # regardless of target status, matching the original behavior.
        tags = [bat_tag] if is_target else []
        if class_name == "pippip":
            tags.append(pippip_tag)
        elif class_name == "myomyo":
            tags.append(myomyo_tag)

        geometry = data.BoundingBox(
            coordinates=[
                start_time,
                low_freq,
                start_time + duration,
                low_freq + bandwidth,
            ],
        )
        sound_event = data.SoundEvent(
            geometry=geometry,
            recording=clip.recording,
        )
        return data.SoundEventAnnotation(
            sound_event=sound_event,
            tags=tags,
        )

    return factory

View File

@ -0,0 +1,77 @@
import numpy as np
import pytest
from soundevent import data
from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
from batdetect2.typing import ClipDetections
from batdetect2.typing.targets import TargetProtocol
def test_classification(
    clip: data.Clip,
    sample_targets: TargetProtocol,
    create_detection,
    create_annotation,
):
    """Check per-class and mean average precision for the classification task."""
    config = ClassificationTaskConfig.model_validate(
        {
            "name": "sound_event_classification",
            "metrics": [{"name": "average_precision"}],
        }
    )
    evaluator = build_task(config, targets=sample_targets)

    pip_scores = np.linspace(0, 1, 100)
    myo_scores = np.linspace(1, 0, 100)

    # Predictions: pippip detections on a 0.1 s grid starting at t=1, with
    # myomyo detections offset by 0.05 s so the two series interleave.
    detections = [
        create_detection(start_time=1 + 0.1 * index, pip_score=score)
        for index, score in enumerate(pip_scores)
    ]
    detections += [
        create_detection(start_time=1.05 + 0.1 * index, myo_score=score)
        for index, score in enumerate(myo_scores)
    ]
    prediction = ClipDetections(clip=clip, detections=detections)

    # Ground truth: every 2nd pippip event and every 3rd myomyo event is a
    # real target, which fixes the expected average precision values below.
    sound_events = [
        create_annotation(
            start_time=1 + 0.1 * index,
            is_target=index % 2 == 0,
            class_name="pippip",
        )
        for index in range(100)
    ]
    sound_events += [
        create_annotation(
            start_time=1.05 + 0.1 * index,
            is_target=index % 3 == 0,
            class_name="myomyo",
        )
        for index in range(100)
    ]
    gt = data.ClipAnnotation(clip=clip, sound_events=sound_events)

    evals = evaluator.evaluate([gt], [prediction])
    metrics = evaluator.compute_metrics(evals)

    assert metrics["classification/average_precision/pippip"] == pytest.approx(
        0.5, abs=0.005
    )
    assert metrics["classification/average_precision/myomyo"] == pytest.approx(
        0.371, abs=0.005
    )
    assert metrics["classification/mean_average_precision"] == pytest.approx(
        (0.5 + 0.371) / 2, abs=0.005
    )

View File

@ -0,0 +1,50 @@
import numpy as np
import pytest
from soundevent import data
from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.typing import ClipDetections
from batdetect2.typing.targets import TargetProtocol
def test_detection(
    clip: data.Clip,
    sample_targets: TargetProtocol,
    create_detection,
    create_annotation,
):
    """Check average precision for the plain detection task.

    Half the annotations are targets, so with scores spread uniformly the
    expected average precision is 0.5.
    """
    config = DetectionTaskConfig.model_validate(
        {
            "name": "sound_event_detection",
            "metrics": [{"name": "average_precision"}],
        }
    )
    evaluator = build_task(config, targets=sample_targets)

    # Predictions: one detection every 0.1 s with linearly increasing score.
    scores = np.linspace(0, 1, 100)
    prediction = ClipDetections(
        clip=clip,
        detections=[
            create_detection(start_time=1 + 0.1 * index, detection_score=score)
            for index, score in enumerate(scores)
        ],
    )

    # Ground truth aligned with the predictions; only every other event is a
    # real target.
    gt = data.ClipAnnotation(
        clip=clip,
        sound_events=[
            create_annotation(
                start_time=1 + 0.1 * index,
                is_target=index % 2 == 0,
            )
            for index in range(100)
        ],
    )

    evals = evaluator.evaluate([gt], [prediction])
    metrics = evaluator.compute_metrics(evals)
    assert metrics["detection/average_precision"] == pytest.approx(0.5)

View File

@ -14,7 +14,7 @@ from batdetect2.postprocess.decoding import (
get_generic_tags, get_generic_tags,
get_prediction_features, get_prediction_features,
) )
from batdetect2.typing import RawPrediction, TargetProtocol from batdetect2.typing import Detection, TargetProtocol
@pytest.fixture @pytest.fixture
@ -209,7 +209,7 @@ def empty_detection_dataset() -> xr.Dataset:
@pytest.fixture @pytest.fixture
def sample_raw_predictions() -> List[RawPrediction]: def sample_raw_predictions() -> List[Detection]:
"""Manually crafted RawPrediction objects using the actual type.""" """Manually crafted RawPrediction objects using the actual type."""
pred1_classes = xr.DataArray( pred1_classes = xr.DataArray(
@ -220,7 +220,7 @@ def sample_raw_predictions() -> List[RawPrediction]:
coords={"feature": ["f0", "f1", "f2", "f3"]}, coords={"feature": ["f0", "f1", "f2", "f3"]},
dims=["feature"], dims=["feature"],
) )
pred1 = RawPrediction( pred1 = Detection(
detection_score=0.9, detection_score=0.9,
geometry=data.BoundingBox( geometry=data.BoundingBox(
coordinates=[ coordinates=[
@ -242,7 +242,7 @@ def sample_raw_predictions() -> List[RawPrediction]:
coords={"feature": ["f0", "f1", "f2", "f3"]}, coords={"feature": ["f0", "f1", "f2", "f3"]},
dims=["feature"], dims=["feature"],
) )
pred2 = RawPrediction( pred2 = Detection(
detection_score=0.8, detection_score=0.8,
geometry=data.BoundingBox( geometry=data.BoundingBox(
coordinates=[ coordinates=[
@ -264,7 +264,7 @@ def sample_raw_predictions() -> List[RawPrediction]:
coords={"feature": ["f0", "f1", "f2", "f3"]}, coords={"feature": ["f0", "f1", "f2", "f3"]},
dims=["feature"], dims=["feature"],
) )
pred3 = RawPrediction( pred3 = Detection(
detection_score=0.15, detection_score=0.15,
geometry=data.BoundingBox( geometry=data.BoundingBox(
coordinates=[ coordinates=[
@ -281,7 +281,7 @@ def sample_raw_predictions() -> List[RawPrediction]:
def test_convert_raw_to_sound_event_basic( def test_convert_raw_to_sound_event_basic(
sample_raw_predictions: List[RawPrediction], sample_raw_predictions: List[Detection],
sample_recording: data.Recording, sample_recording: data.Recording,
dummy_targets: TargetProtocol, dummy_targets: TargetProtocol,
): ):
@ -324,7 +324,7 @@ def test_convert_raw_to_sound_event_basic(
def test_convert_raw_to_sound_event_thresholding( def test_convert_raw_to_sound_event_thresholding(
sample_raw_predictions: List[RawPrediction], sample_raw_predictions: List[Detection],
sample_recording: data.Recording, sample_recording: data.Recording,
dummy_targets: TargetProtocol, dummy_targets: TargetProtocol,
): ):
@ -352,7 +352,7 @@ def test_convert_raw_to_sound_event_thresholding(
def test_convert_raw_to_sound_event_no_threshold( def test_convert_raw_to_sound_event_no_threshold(
sample_raw_predictions: List[RawPrediction], sample_raw_predictions: List[Detection],
sample_recording: data.Recording, sample_recording: data.Recording,
dummy_targets: TargetProtocol, dummy_targets: TargetProtocol,
): ):
@ -380,7 +380,7 @@ def test_convert_raw_to_sound_event_no_threshold(
def test_convert_raw_to_sound_event_top_class( def test_convert_raw_to_sound_event_top_class(
sample_raw_predictions: List[RawPrediction], sample_raw_predictions: List[Detection],
sample_recording: data.Recording, sample_recording: data.Recording,
dummy_targets: TargetProtocol, dummy_targets: TargetProtocol,
): ):
@ -407,7 +407,7 @@ def test_convert_raw_to_sound_event_top_class(
def test_convert_raw_to_sound_event_all_below_threshold( def test_convert_raw_to_sound_event_all_below_threshold(
sample_raw_predictions: List[RawPrediction], sample_raw_predictions: List[Detection],
sample_recording: data.Recording, sample_recording: data.Recording,
dummy_targets: TargetProtocol, dummy_targets: TargetProtocol,
): ):
@ -433,7 +433,7 @@ def test_convert_raw_to_sound_event_all_below_threshold(
def test_convert_raw_list_to_clip_basic( def test_convert_raw_list_to_clip_basic(
sample_raw_predictions: List[RawPrediction], sample_raw_predictions: List[Detection],
sample_clip: data.Clip, sample_clip: data.Clip,
dummy_targets: TargetProtocol, dummy_targets: TargetProtocol,
): ):
@ -485,7 +485,7 @@ def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets):
def test_convert_raw_list_to_clip_passes_args( def test_convert_raw_list_to_clip_passes_args(
sample_raw_predictions: List[RawPrediction], sample_raw_predictions: List[Detection],
sample_clip: data.Clip, sample_clip: data.Clip,
dummy_targets: TargetProtocol, dummy_targets: TargetProtocol,
): ):

View File

@ -1,4 +1,3 @@
import numpy as np import numpy as np
import pytest import pytest
import xarray as xr import xarray as xr