Run formatter

This commit is contained in:
mbsantiago 2025-12-12 21:28:28 +00:00
parent 531ff69974
commit 0adb58e039
64 changed files with 520 additions and 244 deletions

View File

@ -32,11 +32,11 @@ from batdetect2.train import (
)
from batdetect2.typing import (
AudioLoader,
BatDetect2Prediction,
ClipDetections,
EvaluatorProtocol,
PostprocessorProtocol,
PreprocessorProtocol,
RawPrediction,
Detection,
TargetProtocol,
)
@ -110,7 +110,7 @@ class BatDetect2API:
experiment_name: str | None = None,
run_name: str | None = None,
save_predictions: bool = True,
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
) -> Tuple[Dict[str, float], List[List[Detection]]]:
return evaluate(
self.model,
test_annotations,
@ -128,7 +128,7 @@ class BatDetect2API:
def evaluate_predictions(
self,
annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction],
predictions: Sequence[ClipDetections],
output_dir: data.PathLike | None = None,
):
clip_evals = self.evaluator.evaluate(
@ -170,24 +170,24 @@ class BatDetect2API:
tensor = torch.tensor(audio).unsqueeze(0)
return self.preprocessor(tensor)
def process_file(self, audio_file: str) -> BatDetect2Prediction:
def process_file(self, audio_file: str) -> ClipDetections:
recording = data.Recording.from_file(audio_file, compute_hash=False)
wav = self.audio_loader.load_recording(recording)
detections = self.process_audio(wav)
return BatDetect2Prediction(
return ClipDetections(
clip=data.Clip(
uuid=recording.uuid,
recording=recording,
start_time=0,
end_time=recording.duration,
),
predictions=detections,
detections=detections,
)
def process_audio(
self,
audio: np.ndarray,
) -> List[RawPrediction]:
) -> List[Detection]:
spec = self.generate_spectrogram(audio)
return self.process_spectrogram(spec)
@ -195,7 +195,7 @@ class BatDetect2API:
self,
spec: torch.Tensor,
start_time: float = 0,
) -> List[RawPrediction]:
) -> List[Detection]:
if spec.ndim == 4 and spec.shape[0] > 1:
raise ValueError("Batched spectrograms not supported.")
@ -214,7 +214,7 @@ class BatDetect2API:
def process_directory(
self,
audio_dir: data.PathLike,
) -> List[BatDetect2Prediction]:
) -> List[ClipDetections]:
files = list(get_audio_files(audio_dir))
return self.process_files(files)
@ -222,7 +222,7 @@ class BatDetect2API:
self,
audio_files: Sequence[data.PathLike],
num_workers: int | None = None,
) -> List[BatDetect2Prediction]:
) -> List[ClipDetections]:
return process_file_list(
self.model,
audio_files,
@ -238,7 +238,7 @@ class BatDetect2API:
clips: Sequence[data.Clip],
batch_size: int | None = None,
num_workers: int | None = None,
) -> List[BatDetect2Prediction]:
) -> List[ClipDetections]:
return run_batch_inference(
self.model,
clips,
@ -252,7 +252,7 @@ class BatDetect2API:
def save_predictions(
self,
predictions: Sequence[BatDetect2Prediction],
predictions: Sequence[ClipDetections],
path: data.PathLike,
audio_dir: data.PathLike | None = None,
format: str | None = None,
@ -274,7 +274,7 @@ class BatDetect2API:
def load_predictions(
self,
path: data.PathLike,
) -> List[BatDetect2Prediction]:
) -> List[ClipDetections]:
return self.formatter.load(path)
@classmethod

View File

@ -1,4 +1,3 @@
import numpy as np
from numpy.typing import DTypeLike
from pydantic import Field

View File

@ -1,4 +1,3 @@
import numpy as np
import torch
import xarray as xr

View File

@ -264,7 +264,14 @@ class Not:
SoundEventConditionConfig = Annotated[
HasTagConfig | HasAllTagsConfig | HasAnyTagConfig | DurationConfig | FrequencyConfig | AllOfConfig | AnyOfConfig | NotConfig,
HasTagConfig
| HasAllTagsConfig
| HasAnyTagConfig
| DurationConfig
| FrequencyConfig
| AllOfConfig
| AnyOfConfig
| NotConfig,
Field(discriminator="name"),
]

View File

@ -13,9 +13,9 @@ from batdetect2.data.predictions.base import (
)
from batdetect2.targets import terms
from batdetect2.typing import (
BatDetect2Prediction,
ClipDetections,
OutputFormatterProtocol,
RawPrediction,
Detection,
TargetProtocol,
)
@ -113,7 +113,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
self.annotation_note = annotation_note
def format(
self, predictions: Sequence[BatDetect2Prediction]
self, predictions: Sequence[ClipDetections]
) -> List[FileAnnotation]:
return [
self.format_prediction(prediction) for prediction in predictions
@ -164,14 +164,12 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
highest_scoring = max(annotations, key=lambda x: x["class_prob"])
return highest_scoring["class"]
def format_prediction(
self, prediction: BatDetect2Prediction
) -> FileAnnotation:
def format_prediction(self, prediction: ClipDetections) -> FileAnnotation:
recording = prediction.clip.recording
annotations = [
self.format_sound_event_prediction(pred)
for pred in prediction.predictions
for pred in prediction.detections
]
return FileAnnotation(
@ -196,7 +194,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
) # type: ignore
def format_sound_event_prediction(
self, prediction: RawPrediction
self, prediction: Detection
) -> Annotation:
start_time, low_freq, end_time, high_freq = compute_bounds(
prediction.geometry

View File

@ -14,9 +14,9 @@ from batdetect2.data.predictions.base import (
prediction_formatters,
)
from batdetect2.typing import (
BatDetect2Prediction,
ClipDetections,
OutputFormatterProtocol,
RawPrediction,
Detection,
TargetProtocol,
)
@ -29,7 +29,7 @@ class ParquetOutputConfig(BaseConfig):
include_geometry: bool = True
class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
def __init__(
self,
targets: TargetProtocol,
@ -44,13 +44,13 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
def format(
self,
predictions: Sequence[BatDetect2Prediction],
) -> List[BatDetect2Prediction]:
predictions: Sequence[ClipDetections],
) -> List[ClipDetections]:
return list(predictions)
def save(
self,
predictions: Sequence[BatDetect2Prediction],
predictions: Sequence[ClipDetections],
path: data.PathLike,
audio_dir: data.PathLike | None = None,
) -> None:
@ -61,10 +61,10 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
# Ensure the file has .parquet extension if it's a file path
if path.suffix != ".parquet":
# If it's a directory, we might want to save as a partitioned dataset or a single file inside
# For now, let's assume the user provides a full file path or a directory where we save 'predictions.parquet'
if path.is_dir() or not path.suffix:
path = path / "predictions.parquet"
# If it's a directory, we might want to save as a partitioned dataset or a single file inside
# For now, let's assume the user provides a full file path or a directory where we save 'predictions.parquet'
if path.is_dir() or not path.suffix:
path = path / "predictions.parquet"
rows = []
for prediction in predictions:
@ -73,12 +73,14 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
if audio_dir is not None:
recording = recording.model_copy(
update=dict(path=make_path_relative(recording.path, audio_dir))
update=dict(
path=make_path_relative(recording.path, audio_dir)
)
)
recording_json = recording.model_dump_json(exclude_none=True)
for pred in prediction.predictions:
for pred in prediction.detections:
row = {
"clip_uuid": str(clip.uuid),
"clip_start_time": clip.start_time,
@ -116,16 +118,16 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
logger.info(f"Saving {len(df)} predictions to {path}")
df.to_parquet(path, index=False)
def load(self, path: data.PathLike) -> List[BatDetect2Prediction]:
def load(self, path: data.PathLike) -> List[ClipDetections]:
path = Path(path)
if path.is_dir():
# Try to find parquet files
files = list(path.glob("*.parquet"))
if not files:
return []
# Read all and concatenate
dfs = [pd.read_parquet(f) for f in files]
df = pd.concat(dfs, ignore_index=True)
# Try to find parquet files
files = list(path.glob("*.parquet"))
if not files:
return []
# Read all and concatenate
dfs = [pd.read_parquet(f) for f in files]
df = pd.concat(dfs, ignore_index=True)
else:
df = pd.read_parquet(path)
@ -135,17 +137,16 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
clip_uuid = row["clip_uuid"]
if clip_uuid not in predictions_by_clip:
recording = data.Recording.model_validate_json(row["recording_info"])
recording = data.Recording.model_validate_json(
row["recording_info"]
)
clip = data.Clip(
uuid=UUID(clip_uuid),
recording=recording,
start_time=row["clip_start_time"],
end_time=row["clip_end_time"],
)
predictions_by_clip[clip_uuid] = {
"clip": clip,
"preds": []
}
predictions_by_clip[clip_uuid] = {"clip": clip, "preds": []}
# Reconstruct geometry
if "geometry" in row and row["geometry"]:
@ -156,14 +157,20 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
row["start_time"],
row["low_freq"],
row["end_time"],
row["high_freq"]
row["high_freq"],
]
)
class_scores = np.array(row["class_scores"]) if "class_scores" in row else np.zeros(len(self.targets.class_names))
features = np.array(row["features"]) if "features" in row else np.zeros(0)
class_scores = (
np.array(row["class_scores"])
if "class_scores" in row
else np.zeros(len(self.targets.class_names))
)
features = (
np.array(row["features"]) if "features" in row else np.zeros(0)
)
pred = RawPrediction(
pred = Detection(
geometry=geometry,
detection_score=row["detection_score"],
class_scores=class_scores,
@ -174,9 +181,8 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
results = []
for clip_data in predictions_by_clip.values():
results.append(
BatDetect2Prediction(
clip=clip_data["clip"],
predictions=clip_data["preds"]
ClipDetections(
clip=clip_data["clip"], detections=clip_data["preds"]
)
)

View File

@ -15,9 +15,9 @@ from batdetect2.data.predictions.base import (
prediction_formatters,
)
from batdetect2.typing import (
BatDetect2Prediction,
ClipDetections,
OutputFormatterProtocol,
RawPrediction,
Detection,
TargetProtocol,
)
@ -30,7 +30,7 @@ class RawOutputConfig(BaseConfig):
include_geometry: bool = True
class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
class RawFormatter(OutputFormatterProtocol[ClipDetections]):
def __init__(
self,
targets: TargetProtocol,
@ -47,13 +47,13 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
def format(
self,
predictions: Sequence[BatDetect2Prediction],
) -> List[BatDetect2Prediction]:
predictions: Sequence[ClipDetections],
) -> List[ClipDetections]:
return list(predictions)
def save(
self,
predictions: Sequence[BatDetect2Prediction],
predictions: Sequence[ClipDetections],
path: data.PathLike,
audio_dir: data.PathLike | None = None,
) -> None:
@ -68,10 +68,10 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
dataset = self.pred_to_xr(prediction, audio_dir)
dataset.to_netcdf(path / f"{clip.uuid}.nc")
def load(self, path: data.PathLike) -> List[BatDetect2Prediction]:
def load(self, path: data.PathLike) -> List[ClipDetections]:
path = Path(path)
files = list(path.glob("*.nc"))
predictions: List[BatDetect2Prediction] = []
predictions: List[ClipDetections] = []
for filepath in files:
logger.debug(f"Loading clip predictions {filepath}")
@ -83,7 +83,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
def pred_to_xr(
self,
prediction: BatDetect2Prediction,
prediction: ClipDetections,
audio_dir: data.PathLike | None = None,
) -> xr.Dataset:
clip = prediction.clip
@ -97,7 +97,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
data = defaultdict(list)
for pred in prediction.predictions:
for pred in prediction.detections:
detection_id = str(uuid4())
data["detection_id"].append(detection_id)
@ -167,7 +167,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
},
)
def pred_from_xr(self, dataset: xr.Dataset) -> BatDetect2Prediction:
def pred_from_xr(self, dataset: xr.Dataset) -> ClipDetections:
clip_data = dataset
clip_id = clip_data.clip_id.item()
@ -219,7 +219,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
features = np.zeros(0)
sound_events.append(
RawPrediction(
Detection(
geometry=geometry,
detection_score=score,
class_scores=class_scores,
@ -227,9 +227,9 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
)
)
return BatDetect2Prediction(
return ClipDetections(
clip=clip,
predictions=sound_events,
detections=sound_events,
)
@prediction_formatters.register(RawOutputConfig)

View File

@ -9,9 +9,9 @@ from batdetect2.data.predictions.base import (
prediction_formatters,
)
from batdetect2.typing import (
BatDetect2Prediction,
ClipDetections,
OutputFormatterProtocol,
RawPrediction,
Detection,
TargetProtocol,
)
@ -35,7 +35,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
def format(
self,
predictions: Sequence[BatDetect2Prediction],
predictions: Sequence[ClipDetections],
) -> List[data.ClipPrediction]:
return [
self.format_prediction(prediction) for prediction in predictions
@ -63,20 +63,20 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
def format_prediction(
self,
prediction: BatDetect2Prediction,
prediction: ClipDetections,
) -> data.ClipPrediction:
recording = prediction.clip.recording
return data.ClipPrediction(
clip=prediction.clip,
sound_events=[
self.format_sound_event_prediction(pred, recording)
for pred in prediction.predictions
for pred in prediction.detections
],
)
def format_sound_event_prediction(
self,
prediction: RawPrediction,
prediction: Detection,
recording: data.Recording,
) -> data.SoundEventPrediction:
return data.SoundEventPrediction(
@ -89,7 +89,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
)
def get_sound_event_tags(
self, prediction: RawPrediction
self, prediction: Detection
) -> List[data.PredictedTag]:
sorted_indices = np.argsort(prediction.class_scores)[::-1]

View File

@ -221,7 +221,11 @@ class ApplyAll:
SoundEventTransformConfig = Annotated[
SetFrequencyBoundConfig | ReplaceTagConfig | MapTagValueConfig | ApplyIfConfig | ApplyAllConfig,
SetFrequencyBoundConfig
| ReplaceTagConfig
| MapTagValueConfig
| ApplyIfConfig
| ApplyAllConfig,
Field(discriminator="name"),
]

View File

@ -11,7 +11,7 @@ from soundevent.geometry import (
)
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import AffinityFunction, RawPrediction
from batdetect2.typing import AffinityFunction, Detection
affinity_functions: Registry[AffinityFunction, []] = Registry(
"affinity_function"
@ -42,7 +42,7 @@ class TimeAffinity(AffinityFunction):
def __call__(
self,
detection: RawPrediction,
detection: Detection,
ground_truth: data.SoundEventAnnotation,
) -> float:
target_geometry = ground_truth.sound_event.geometry
@ -77,7 +77,7 @@ class IntervalIOU(AffinityFunction):
def __call__(
self,
detection: RawPrediction,
detection: Detection,
ground_truth: data.SoundEventAnnotation,
) -> float:
target_geometry = ground_truth.sound_event.geometry
@ -120,7 +120,7 @@ class BBoxIOU(AffinityFunction):
def __call__(
self,
prediction: RawPrediction,
prediction: Detection,
gt: data.SoundEventAnnotation,
):
target_geometry = gt.sound_event.geometry
@ -168,7 +168,7 @@ class GeometricIOU(AffinityFunction):
def __call__(
self,
prediction: RawPrediction,
prediction: Detection,
gt: data.SoundEventAnnotation,
):
target_geometry = gt.sound_event.geometry

View File

@ -12,7 +12,7 @@ from batdetect2.logging import build_logger
from batdetect2.models import Model
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing import Detection
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
@ -38,7 +38,7 @@ def evaluate(
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: str | None = None,
run_name: str | None = None,
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
) -> Tuple[Dict[str, float], List[List[Detection]]]:
from batdetect2.config import BatDetect2Config
config = config or BatDetect2Config()

View File

@ -7,7 +7,7 @@ from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.tasks import build_task
from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.typing.postprocess import ClipDetections
__all__ = [
"Evaluator",
@ -27,7 +27,7 @@ class Evaluator:
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction],
predictions: Sequence[ClipDetections],
) -> List[Any]:
return [
task.evaluate(clip_annotations, predictions) for task in self.tasks

View File

@ -9,7 +9,7 @@ from batdetect2.logging import get_image_logger
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import EvaluatorProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.typing.postprocess import ClipDetections
class EvaluationModule(LightningModule):
@ -24,7 +24,7 @@ class EvaluationModule(LightningModule):
self.evaluator = evaluator
self.clip_annotations: List[data.ClipAnnotation] = []
self.predictions: List[BatDetect2Prediction] = []
self.predictions: List[ClipDetections] = []
def test_step(self, batch: TestExample, batch_idx: int):
dataset = self.get_dataset()
@ -39,9 +39,9 @@ class EvaluationModule(LightningModule):
start_times=[ca.clip.start_time for ca in clip_annotations],
)
predictions = [
BatDetect2Prediction(
ClipDetections(
clip=clip_annotation.clip,
predictions=to_raw_predictions(
detections=to_raw_predictions(
clip_dets.numpy(),
targets=self.evaluator.targets,
),

View File

@ -21,7 +21,7 @@ from batdetect2.evaluate.metrics.common import (
average_precision,
compute_precision_recall,
)
from batdetect2.typing import RawPrediction, TargetProtocol
from batdetect2.typing import Detection, TargetProtocol
__all__ = [
"ClassificationMetric",
@ -35,7 +35,7 @@ __all__ = [
class MatchEval:
clip: data.Clip
gt: data.SoundEventAnnotation | None
pred: RawPrediction | None
pred: Detection | None
is_prediction: bool
is_ground_truth: bool

View File

@ -159,7 +159,10 @@ class ClipDetectionPrecision:
ClipDetectionMetricConfig = Annotated[
ClipDetectionAveragePrecisionConfig | ClipDetectionROCAUCConfig | ClipDetectionRecallConfig | ClipDetectionPrecisionConfig,
ClipDetectionAveragePrecisionConfig
| ClipDetectionROCAUCConfig
| ClipDetectionRecallConfig
| ClipDetectionPrecisionConfig,
Field(discriminator="name"),
]

View File

@ -15,7 +15,7 @@ from soundevent import data
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction
from batdetect2.typing import Detection
__all__ = [
"DetectionMetricConfig",
@ -27,7 +27,7 @@ __all__ = [
@dataclass
class MatchEval:
gt: data.SoundEventAnnotation | None
pred: RawPrediction | None
pred: Detection | None
is_prediction: bool
is_ground_truth: bool

View File

@ -15,7 +15,7 @@ from soundevent import data
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction
from batdetect2.typing import Detection
from batdetect2.typing.targets import TargetProtocol
__all__ = [
@ -29,7 +29,7 @@ __all__ = [
class MatchEval:
clip: data.Clip
gt: data.SoundEventAnnotation | None
pred: RawPrediction | None
pred: Detection | None
is_ground_truth: bool
is_generic: bool
@ -298,7 +298,11 @@ class BalancedAccuracy:
TopClassMetricConfig = Annotated[
TopClassAveragePrecisionConfig | TopClassROCAUCConfig | TopClassRecallConfig | TopClassPrecisionConfig | BalancedAccuracyConfig,
TopClassAveragePrecisionConfig
| TopClassROCAUCConfig
| TopClassRecallConfig
| TopClassPrecisionConfig
| BalancedAccuracyConfig,
Field(discriminator="name"),
]

View File

@ -1,4 +1,3 @@
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

View File

@ -322,7 +322,10 @@ class ROCCurve(BasePlot):
ClassificationPlotConfig = Annotated[
PRCurveConfig | ROCCurveConfig | ThresholdPrecisionCurveConfig | ThresholdRecallCurveConfig,
PRCurveConfig
| ROCCurveConfig
| ThresholdPrecisionCurveConfig
| ThresholdRecallCurveConfig,
Field(discriminator="name"),
]

View File

@ -290,7 +290,10 @@ class ExampleDetectionPlot(BasePlot):
DetectionPlotConfig = Annotated[
PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig | ExampleDetectionPlotConfig,
PRCurveConfig
| ROCCurveConfig
| ScoreDistributionPlotConfig
| ExampleDetectionPlotConfig,
Field(discriminator="name"),
]

View File

@ -346,7 +346,10 @@ class ExampleClassificationPlot(BasePlot):
TopClassPlotConfig = Annotated[
PRCurveConfig | ROCCurveConfig | ConfusionMatrixConfig | ExampleClassificationPlotConfig,
PRCurveConfig
| ROCCurveConfig
| ConfusionMatrixConfig
| ExampleClassificationPlotConfig,
Field(discriminator="name"),
]
@ -403,7 +406,8 @@ def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
return matches
indices, pred_scores = zip(
*[(index, match.score) for index, match in enumerate(matches)], strict=False
*[(index, match.score) for index, match in enumerate(matches)],
strict=False,
)
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")

View File

@ -13,7 +13,7 @@ from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
from batdetect2.targets import build_targets
from batdetect2.typing import (
BatDetect2Prediction,
ClipDetections,
EvaluatorProtocol,
TargetProtocol,
)
@ -45,7 +45,7 @@ def build_task(
def evaluate_task(
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction],
predictions: Sequence[ClipDetections],
task: Optional["str"] = None,
targets: TargetProtocol | None = None,
config: TaskConfig | dict | None = None,

View File

@ -22,9 +22,9 @@ from batdetect2.evaluate.affinity import (
)
from batdetect2.typing import (
AffinityFunction,
BatDetect2Prediction,
ClipDetections,
EvaluatorProtocol,
RawPrediction,
Detection,
TargetProtocol,
)
@ -96,7 +96,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction],
predictions: Sequence[ClipDetections],
) -> List[T_Output]:
return [
self.evaluate_clip(clip_annotation, preds)
@ -108,7 +108,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction,
prediction: ClipDetections,
) -> T_Output: ...
def include_sound_event_annotation(
@ -128,7 +128,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
def include_prediction(
self,
prediction: RawPrediction,
prediction: Detection,
clip: data.Clip,
) -> bool:
return is_in_bounds(

View File

@ -22,8 +22,8 @@ from batdetect2.evaluate.tasks.base import (
tasks_registry,
)
from batdetect2.typing import (
BatDetect2Prediction,
RawPrediction,
ClipDetections,
Detection,
TargetProtocol,
)
@ -51,13 +51,13 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction,
prediction: ClipDetections,
) -> ClipEval:
clip = clip_annotation.clip
preds = [
pred
for pred in prediction.predictions
for pred in prediction.detections
if self.include_prediction(pred, clip)
]
@ -140,7 +140,7 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
)
def get_class_score(pred: RawPrediction, class_idx: int) -> float:
def get_class_score(pred: Detection, class_idx: int) -> float:
return pred.class_scores[class_idx]

View File

@ -19,7 +19,7 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
from batdetect2.typing import ClipDetections, TargetProtocol
class ClipClassificationTaskConfig(BaseTaskConfig):
@ -37,7 +37,7 @@ class ClipClassificationTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction,
prediction: ClipDetections,
) -> ClipEval:
clip = clip_annotation.clip
@ -54,7 +54,7 @@ class ClipClassificationTask(BaseTask[ClipEval]):
gt_classes.add(class_name)
pred_scores = defaultdict(float)
for pred in prediction.predictions:
for pred in prediction.detections:
if not self.include_prediction(pred, clip):
continue

View File

@ -18,7 +18,7 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
from batdetect2.typing import ClipDetections, TargetProtocol
class ClipDetectionTaskConfig(BaseTaskConfig):
@ -36,7 +36,7 @@ class ClipDetectionTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction,
prediction: ClipDetections,
) -> ClipEval:
clip = clip_annotation.clip
@ -46,7 +46,7 @@ class ClipDetectionTask(BaseTask[ClipEval]):
)
pred_score = 0
for pred in prediction.predictions:
for pred in prediction.detections:
if not self.include_prediction(pred, clip):
continue

View File

@ -21,7 +21,7 @@ from batdetect2.evaluate.tasks.base import (
tasks_registry,
)
from batdetect2.typing import TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.typing.postprocess import ClipDetections
class DetectionTaskConfig(BaseSEDTaskConfig):
@ -37,7 +37,7 @@ class DetectionTask(BaseSEDTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction,
prediction: ClipDetections,
) -> ClipEval:
clip = clip_annotation.clip
@ -48,7 +48,7 @@ class DetectionTask(BaseSEDTask[ClipEval]):
]
preds = [
pred
for pred in prediction.predictions
for pred in prediction.detections
if self.include_prediction(pred, clip)
]

View File

@ -20,7 +20,7 @@ from batdetect2.evaluate.tasks.base import (
BaseSEDTaskConfig,
tasks_registry,
)
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
from batdetect2.typing import ClipDetections, TargetProtocol
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
@ -36,7 +36,7 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction,
prediction: ClipDetections,
) -> ClipEval:
clip = clip_annotation.clip
@ -47,7 +47,7 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
]
preds = [
pred
for pred in prediction.predictions
for pred in prediction.detections
if self.include_prediction(pred, clip)
]

View File

@ -88,7 +88,8 @@ def select_device(warn=True) -> str:
if warn:
warnings.warn(
"No GPU available, using the CPU instead. Please consider using a GPU "
"to speed up training.", stacklevel=2
"to speed up training.",
stacklevel=2,
)
return "cpu"

View File

@ -10,7 +10,7 @@ from batdetect2.inference.lightning import InferenceModule
from batdetect2.models import Model
from batdetect2.preprocess.preprocessor import build_preprocessor
from batdetect2.targets.targets import build_targets
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.typing.postprocess import ClipDetections
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
@ -30,7 +30,7 @@ def run_batch_inference(
config: Optional["BatDetect2Config"] = None,
num_workers: int | None = None,
batch_size: int | None = None,
) -> List[BatDetect2Prediction]:
) -> List[ClipDetections]:
from batdetect2.config import BatDetect2Config
config = config or BatDetect2Config()
@ -70,7 +70,7 @@ def process_file_list(
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
num_workers: int | None = None,
) -> List[BatDetect2Prediction]:
) -> List[ClipDetections]:
clip_config = config.inference.clipping
clips = get_clips_from_files(
paths,

View File

@ -6,7 +6,7 @@ from torch.utils.data import DataLoader
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.typing.postprocess import ClipDetections
class InferenceModule(LightningModule):
@ -19,7 +19,7 @@ class InferenceModule(LightningModule):
batch: DatasetItem,
batch_idx: int,
dataloader_idx: int = 0,
) -> Sequence[BatDetect2Prediction]:
) -> Sequence[ClipDetections]:
dataset = self.get_dataset()
clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx]
@ -32,9 +32,9 @@ class InferenceModule(LightningModule):
)
predictions = [
BatDetect2Prediction(
ClipDetections(
clip=clip,
predictions=to_raw_predictions(
detections=to_raw_predictions(
clip_dets.numpy(),
targets=self.model.targets,
),

View File

@ -76,7 +76,10 @@ class MLFlowLoggerConfig(BaseLoggerConfig):
LoggerConfig = Annotated[
DVCLiveConfig | CSVLoggerConfig | TensorBoardLoggerConfig | MLFlowLoggerConfig,
DVCLiveConfig
| CSVLoggerConfig
| TensorBoardLoggerConfig
| MLFlowLoggerConfig,
Field(discriminator="name"),
]

View File

@ -1,4 +1,3 @@
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config

View File

@ -41,7 +41,10 @@ __all__ = [
]
DecoderLayerConfig = Annotated[
ConvConfig | FreqCoordConvUpConfig | StandardConvUpConfig | LayerGroupConfig,
ConvConfig
| FreqCoordConvUpConfig
| StandardConvUpConfig
| LayerGroupConfig,
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""

View File

@ -14,7 +14,6 @@ logic for preprocessing inputs and postprocessing/decoding outputs resides in
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
"""
import torch
from loguru import logger

View File

@ -43,7 +43,10 @@ __all__ = [
]
EncoderLayerConfig = Annotated[
ConvConfig | FreqCoordConvDownConfig | StandardConvDownConfig | LayerGroupConfig,
ConvConfig
| FreqCoordConvDownConfig
| StandardConvDownConfig
| LayerGroupConfig,
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Encoder."""

View File

@ -1,4 +1,3 @@
from matplotlib import axes, patches
from soundevent.plot import plot_geometry

View File

@ -36,7 +36,9 @@ def plot_match_gallery(
sharey="row",
)
for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples], strict=False):
for tp_ax, tp_match in zip(
axes[0], true_positives[:n_examples], strict=False
):
try:
plot_true_positive_match(
tp_match,
@ -53,7 +55,9 @@ def plot_match_gallery(
):
continue
for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples], strict=False):
for fp_ax, fp_match in zip(
axes[1], false_positives[:n_examples], strict=False
):
try:
plot_false_positive_match(
fp_match,
@ -70,7 +74,9 @@ def plot_match_gallery(
):
continue
for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples], strict=False):
for fn_ax, fn_match in zip(
axes[2], false_negatives[:n_examples], strict=False
):
try:
plot_false_negative_match(
fn_match,
@ -87,7 +93,9 @@ def plot_match_gallery(
):
continue
for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples], strict=False):
for ct_ax, ct_match in zip(
axes[3], cross_triggers[:n_examples], strict=False
):
try:
plot_cross_trigger_match(
ct_match,

View File

@ -8,7 +8,7 @@ from batdetect2.plotting.clips import plot_clip
from batdetect2.typing import (
AudioLoader,
PreprocessorProtocol,
RawPrediction,
Detection,
)
__all__ = [
@ -22,7 +22,7 @@ __all__ = [
class MatchProtocol(Protocol):
clip: data.Clip
gt: data.SoundEventAnnotation | None
pred: RawPrediction | None
pred: Detection | None
score: float
true_class: str | None
@ -341,4 +341,3 @@ def plot_cross_trigger_match(
ax.set_title("Cross Trigger")
return ax

View File

@ -0,0 +1,6 @@
from batdetect2.typing import ClipDetections
class ClipTransform:
def __init__(self, clip: ClipDetections):
pass

View File

@ -1,4 +1,3 @@
from pydantic import Field
from soundevent import data

View File

@ -7,7 +7,7 @@ from soundevent import data
from batdetect2.typing.postprocess import (
ClipDetectionsArray,
RawPrediction,
Detection,
)
from batdetect2.typing.targets import TargetProtocol
@ -30,7 +30,7 @@ decoding.
def to_raw_predictions(
detections: ClipDetectionsArray,
targets: TargetProtocol,
) -> List[RawPrediction]:
) -> List[Detection]:
predictions = []
for score, class_scores, time, freq, dims, feats in zip(
@ -39,7 +39,8 @@ def to_raw_predictions(
detections.times,
detections.frequencies,
detections.sizes,
detections.features, strict=False,
detections.features,
strict=False,
):
highest_scoring_class = targets.class_names[class_scores.argmax()]
@ -50,7 +51,7 @@ def to_raw_predictions(
)
predictions.append(
RawPrediction(
Detection(
detection_score=score,
geometry=geom,
class_scores=class_scores,
@ -62,7 +63,7 @@ def to_raw_predictions(
def convert_raw_predictions_to_clip_prediction(
raw_predictions: List[RawPrediction],
raw_predictions: List[Detection],
clip: data.Clip,
targets: TargetProtocol,
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
@ -85,7 +86,7 @@ def convert_raw_predictions_to_clip_prediction(
def convert_raw_prediction_to_sound_event_prediction(
raw_prediction: RawPrediction,
raw_prediction: Detection,
recording: data.Recording,
targets: TargetProtocol,
classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,

View File

@ -1,4 +1,3 @@
import torch
from loguru import logger

View File

@ -377,7 +377,10 @@ class PeakNormalize(torch.nn.Module):
SpectrogramTransform = Annotated[
PcenConfig | ScaleAmplitudeConfig | SpectralMeanSubstractionConfig | PeakNormalizeConfig,
PcenConfig
| ScaleAmplitudeConfig
| SpectralMeanSubstractionConfig
| PeakNormalizeConfig,
Field(discriminator="name"),
]

View File

@ -394,7 +394,8 @@ class MaskTime(torch.nn.Module):
size=num_masks,
)
masks = [
(start, start + size) for start, size in zip(mask_start, mask_size, strict=False)
(start, start + size)
for start, size in zip(mask_start, mask_size, strict=False)
]
return mask_time(spec, masks), clip_annotation
@ -460,7 +461,8 @@ class MaskFrequency(torch.nn.Module):
size=num_masks,
)
masks = [
(start, start + size) for start, size in zip(mask_start, mask_size, strict=False)
(start, start + size)
for start, size in zip(mask_start, mask_size, strict=False)
]
return mask_frequency(spec, masks), clip_annotation
@ -498,7 +500,12 @@ SpectrogramAugmentationConfig = Annotated[
]
AugmentationConfig = Annotated[
MixAudioConfig | AddEchoConfig | ScaleVolumeConfig | WarpConfig | MaskFrequencyConfig | MaskTimeConfig,
MixAudioConfig
| AddEchoConfig
| ScaleVolumeConfig
| WarpConfig
| MaskFrequencyConfig
| MaskTimeConfig,
Field(discriminator="name"),
]
"""Type alias for the discriminated union of individual augmentation config."""

View File

@ -10,7 +10,7 @@ from batdetect2.postprocess import to_raw_predictions
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule
from batdetect2.typing import (
BatDetect2Prediction,
ClipDetections,
EvaluatorProtocol,
ModelOutput,
TrainExample,
@ -24,7 +24,7 @@ class ValidationMetrics(Callback):
self.evaluator = evaluator
self._clip_annotations: List[data.ClipAnnotation] = []
self._predictions: List[BatDetect2Prediction] = []
self._predictions: List[ClipDetections] = []
def get_dataset(self, trainer: Trainer) -> ValidationDataset:
dataloaders = trainer.val_dataloaders
@ -100,9 +100,9 @@ class ValidationMetrics(Callback):
start_times=[ca.clip.start_time for ca in clip_annotations],
)
predictions = [
BatDetect2Prediction(
ClipDetections(
clip=clip_annotation.clip,
predictions=to_raw_predictions(
detections=to_raw_predictions(
clip_dets.numpy(), targets=model.targets
),
)

View File

@ -1,4 +1,3 @@
from pydantic import Field
from soundevent import data

View File

@ -18,7 +18,6 @@ The primary entry points are:
- `LossConfig`: The Pydantic model for configuring loss weights and parameters.
"""
import numpy as np
import torch
import torch.nn.functional as F

View File

@ -10,11 +10,11 @@ from batdetect2.typing.evaluate import (
)
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
from batdetect2.typing.postprocess import (
BatDetect2Prediction,
ClipDetections,
ClipDetectionsTensor,
GeometryDecoder,
PostprocessorProtocol,
RawPrediction,
Detection,
)
from batdetect2.typing.preprocess import (
AudioLoader,
@ -44,7 +44,7 @@ __all__ = [
"AudioLoader",
"Augmentation",
"BackboneModel",
"BatDetect2Prediction",
"ClipDetections",
"ClipDetectionsTensor",
"ClipLabeller",
"ClipMatches",
@ -65,7 +65,7 @@ __all__ = [
"PostprocessorProtocol",
"PreprocessorProtocol",
"ROITargetMapper",
"RawPrediction",
"Detection",
"Size",
"SoundEventDecoder",
"SoundEventEncoder",

View File

@ -2,7 +2,7 @@ from typing import Generic, List, Protocol, Sequence, TypeVar
from soundevent.data import PathLike
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.typing.postprocess import ClipDetections
__all__ = [
"OutputFormatterProtocol",
@ -12,9 +12,7 @@ T = TypeVar("T")
class OutputFormatterProtocol(Protocol, Generic[T]):
def format(
self, predictions: Sequence[BatDetect2Prediction]
) -> List[T]: ...
def format(self, predictions: Sequence[ClipDetections]) -> List[T]: ...
def save(
self,

View File

@ -13,7 +13,7 @@ from typing import (
from matplotlib.figure import Figure
from soundevent import data
from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction
from batdetect2.typing.postprocess import ClipDetections, Detection
from batdetect2.typing.targets import TargetProtocol
__all__ = [
@ -84,7 +84,7 @@ Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
class AffinityFunction(Protocol):
def __call__(
self,
detection: RawPrediction,
detection: Detection,
ground_truth: data.SoundEventAnnotation,
) -> float: ...
@ -93,7 +93,7 @@ class MetricsProtocol(Protocol):
def __call__(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
predictions: Sequence[Sequence[Detection]],
) -> Dict[str, float]: ...
@ -101,7 +101,7 @@ class PlotterProtocol(Protocol):
def __call__(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
predictions: Sequence[Sequence[Detection]],
) -> Iterable[Tuple[str, Figure]]: ...
@ -114,7 +114,7 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction],
predictions: Sequence[ClipDetections],
) -> EvaluationOutput: ...
def compute_metrics(

View File

@ -22,7 +22,7 @@ from batdetect2.typing.models import ModelOutput
from batdetect2.typing.targets import Position, Size
__all__ = [
"RawPrediction",
"Detection",
"PostprocessorProtocol",
"GeometryDecoder",
]
@ -46,7 +46,8 @@ class GeometryDecoder(Protocol):
) -> data.Geometry: ...
class RawPrediction(NamedTuple):
@dataclass
class Detection:
geometry: data.Geometry
detection_score: float
class_scores: np.ndarray
@ -82,9 +83,16 @@ class ClipDetectionsTensor(NamedTuple):
@dataclass
class BatDetect2Prediction:
class ClipDetections:
clip: data.Clip
predictions: List[RawPrediction]
detections: List[Detection]
@dataclass
class ClipPrediction:
clip: data.Clip
detection_score: float
class_scores: np.ndarray
class PostprocessorProtocol(Protocol):

View File

@ -220,7 +220,8 @@ def get_annotations_from_preds(
predictions["high_freqs"],
class_ind_best,
class_prob_best,
predictions["det_probs"], strict=False,
predictions["det_probs"],
strict=False,
)
]
return annotations

View File

@ -87,9 +87,7 @@ def save_ann_spec(
y_extent = [0, duration, min_freq, max_freq]
plt.close("all")
plt.figure(
0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100
)
plt.figure(0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100)
plt.imshow(
spec,
aspect="auto",

View File

@ -1,4 +1,3 @@
import numpy as np
import torch
from torch.nn import functional as F

View File

@ -10,8 +10,8 @@ from batdetect2.data.predictions import (
build_output_formatter,
)
from batdetect2.typing import (
BatDetect2Prediction,
RawPrediction,
ClipDetections,
Detection,
TargetProtocol,
)
@ -31,7 +31,7 @@ def test_roundtrip(
tmp_path: Path,
):
detections = [
RawPrediction(
Detection(
geometry=data.BoundingBox(
coordinates=list(np.random.uniform(size=[4]))
),
@ -44,7 +44,7 @@ def test_roundtrip(
for _ in range(10)
]
prediction = BatDetect2Prediction(clip=clip, predictions=detections)
prediction = ClipDetections(clip=clip, detections=detections)
path = tmp_path / "predictions.parquet"
@ -58,7 +58,7 @@ def test_roundtrip(
assert recovered[0].clip == prediction.clip
for recovered_prediction, detection in zip(
recovered[0].predictions, detections
recovered[0].detections, detections, strict=True
):
assert (
recovered_prediction.detection_score == detection.detection_score
@ -81,7 +81,7 @@ def test_multiple_clips(
clip2 = clip.model_copy(update={"uuid": uuid4()})
detections1 = [
RawPrediction(
Detection(
geometry=data.BoundingBox(
coordinates=list(np.random.uniform(size=[4]))
),
@ -94,7 +94,7 @@ def test_multiple_clips(
]
detections2 = [
RawPrediction(
Detection(
geometry=data.BoundingBox(
coordinates=list(np.random.uniform(size=[4]))
),
@ -107,8 +107,8 @@ def test_multiple_clips(
]
predictions = [
BatDetect2Prediction(clip=clip, predictions=detections1),
BatDetect2Prediction(clip=clip2, predictions=detections2),
ClipDetections(clip=clip, detections=detections1),
ClipDetections(clip=clip2, detections=detections2),
]
path = tmp_path / "multi_predictions.parquet"
@ -133,16 +133,18 @@ def test_complex_geometry(
):
# Create a polygon geometry
polygon = data.Polygon(
coordinates=[[
[0.0, 10000.0],
[0.1, 20000.0],
[0.2, 10000.0],
[0.0, 10000.0],
]]
coordinates=[
[
[0.0, 10000.0],
[0.1, 20000.0],
[0.2, 10000.0],
[0.0, 10000.0],
]
]
)
detections = [
RawPrediction(
Detection(
geometry=polygon,
detection_score=0.95,
class_scores=np.random.uniform(
@ -152,7 +154,7 @@ def test_complex_geometry(
)
]
prediction = BatDetect2Prediction(clip=clip, predictions=detections)
prediction = ClipDetections(clip=clip, detections=detections)
path = tmp_path / "complex_geometry.parquet"
sample_formatter.save(predictions=[prediction], path=path)
@ -160,9 +162,9 @@ def test_complex_geometry(
recovered = sample_formatter.load(path=path)
assert len(recovered) == 1
assert len(recovered[0].predictions) == 1
assert len(recovered[0].detections) == 1
recovered_pred = recovered[0].predictions[0]
recovered_pred = recovered[0].detections[0]
# Check if geometry is recovered correctly as a Polygon
assert isinstance(recovered_pred.geometry, data.Polygon)

View File

@ -6,8 +6,8 @@ from soundevent import data
from batdetect2.data.predictions import RawOutputConfig, build_output_formatter
from batdetect2.typing import (
BatDetect2Prediction,
RawPrediction,
ClipDetections,
Detection,
TargetProtocol,
)
@ -27,7 +27,7 @@ def test_roundtrip(
tmp_path: Path,
):
detections = [
RawPrediction(
Detection(
geometry=data.BoundingBox(
coordinates=list(np.random.uniform(size=[4]))
),
@ -40,7 +40,7 @@ def test_roundtrip(
for _ in range(10)
]
prediction = BatDetect2Prediction(clip=clip, predictions=detections)
prediction = ClipDetections(clip=clip, detections=detections)
path = tmp_path / "predictions"
@ -52,7 +52,9 @@ def test_roundtrip(
assert recovered[0].clip == prediction.clip
for recovered_prediction, detection in zip(
recovered[0].predictions, detections
recovered[0].detections,
detections,
strict=True,
):
assert (
recovered_prediction.detection_score == detection.detection_score

View File

View File

@ -0,0 +1,86 @@
from typing import Literal
import numpy as np
import pytest
from soundevent import data
from batdetect2.typing import Detection
@pytest.fixture
def clip(recording: data.Recording) -> data.Clip:
    """Provide a clip of the injected ``recording`` with start 0 and end 100."""
    test_clip = data.Clip(
        recording=recording,
        start_time=0,
        end_time=100,
    )
    return test_clip
@pytest.fixture
def create_detection():
    """Provide a factory that builds ``Detection`` objects for tests.

    The returned callable accepts keyword overrides for the detection score,
    the bounding box (given as start time, duration, low frequency and
    bandwidth) and two class scores. ``class_scores`` is always the
    two-element array ``[pip_score, myo_score]`` and ``features`` is always
    a zero vector of length 32.
    """

    def factory(
        detection_score: float = 0.5,
        start_time: float = 0.1,
        duration: float = 0.01,
        low_freq: float = 40_000,
        bandwidth: float = 30_000,
        pip_score: float = 0,
        myo_score: float = 0,
    ):
        # NOTE(review): the score order presumably mirrors the class order
        # of the targets used by the tests -- confirm against the fixtures.
        scores = np.array([pip_score, myo_score])
        empty_features = np.zeros([32])

        # The box is stored as [start, low, end, high]; the far corner is
        # derived from the duration and bandwidth arguments.
        box = data.BoundingBox(
            coordinates=[
                start_time,
                low_freq,
                start_time + duration,
                low_freq + bandwidth,
            ]
        )

        return Detection(
            detection_score=detection_score,
            class_scores=scores,
            features=empty_features,
            geometry=box,
        )

    return factory
@pytest.fixture
def create_annotation(
clip: data.Clip,
bat_tag: data.Tag,
myomyo_tag: data.Tag,
pippip_tag: data.Tag,
):
def factory(
start_time: float = 0.1,
duration: float = 0.01,
low_freq: float = 40_000,
bandwidth: float = 30_000,
is_target: bool = True,
class_name: Literal["pippip", "myomyo"] | None = None,
):
tags = [bat_tag] if is_target else []
if class_name is not None:
if class_name == "pippip":
tags.append(pippip_tag)
elif class_name == "myomyo":
tags.append(myomyo_tag)
return data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.BoundingBox(
coordinates=[
start_time,
low_freq,
start_time + duration,
low_freq + bandwidth,
]
),
recording=clip.recording,
),
tags=tags,
)
return factory

View File

@ -0,0 +1,77 @@
import numpy as np
import pytest
from soundevent import data
from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
from batdetect2.typing import ClipDetections
from batdetect2.typing.targets import TargetProtocol
def test_classification(
    clip: data.Clip,
    sample_targets: TargetProtocol,
    create_detection,
    create_annotation,
):
    """Check average precision of the sound event classification task.

    Two interleaved series of 100 detections are evaluated against 200
    annotations placed at the same start times: one series with ``pip_score``
    rising from 0 to 1, the other with ``myo_score`` falling from 1 to 0.
    Only every second "pippip" annotation and every third "myomyo"
    annotation is created with ``is_target=True``.
    """
    task_config = ClassificationTaskConfig.model_validate(
        {
            "name": "sound_event_classification",
            "metrics": [{"name": "average_precision"}],
        }
    )
    evaluator = build_task(task_config, targets=sample_targets)

    # Dummy predictions: rising pip scores, then falling myo scores offset
    # by 0.05 in start time.
    pip_detections = [
        create_detection(
            start_time=1 + 0.1 * position,
            pip_score=pip_score,
        )
        for position, pip_score in enumerate(np.linspace(0, 1, 100))
    ]
    myo_detections = [
        create_detection(
            start_time=1.05 + 0.1 * position,
            myo_score=myo_score,
        )
        for position, myo_score in enumerate(np.linspace(1, 0, 100))
    ]
    clip_detections = ClipDetections(
        clip=clip,
        detections=pip_detections + myo_detections,
    )

    # Dummy ground truth at the same start times as the predictions.
    pip_annotations = [
        create_annotation(
            start_time=1 + 0.1 * position,
            is_target=position % 2 == 0,
            class_name="pippip",
        )
        for position in range(100)
    ]
    myo_annotations = [
        create_annotation(
            start_time=1.05 + 0.1 * position,
            is_target=position % 3 == 0,
            class_name="myomyo",
        )
        for position in range(100)
    ]
    clip_annotation = data.ClipAnnotation(
        clip=clip,
        sound_events=pip_annotations + myo_annotations,
    )

    clip_evals = evaluator.evaluate([clip_annotation], [clip_detections])
    metrics = evaluator.compute_metrics(clip_evals)

    expected_metrics = {
        "classification/average_precision/pippip": 0.5,
        "classification/average_precision/myomyo": 0.371,
        "classification/mean_average_precision": (0.5 + 0.371) / 2,
    }
    for metric_name, expected_value in expected_metrics.items():
        assert metrics[metric_name] == pytest.approx(
            expected_value, abs=0.005
        )

View File

@ -0,0 +1,50 @@
import numpy as np
import pytest
from soundevent import data
from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.typing import ClipDetections
from batdetect2.typing.targets import TargetProtocol
def test_detection(
    clip: data.Clip,
    sample_targets: TargetProtocol,
    create_detection,
    create_annotation,
):
    """Check average precision of the sound event detection task.

    100 detections with scores rising from 0 to 1 are evaluated against 100
    annotations at the same start times, of which every second one is
    created with ``is_target=True``. The expected average precision is 0.5.
    """
    task_config = DetectionTaskConfig.model_validate(
        {
            "name": "sound_event_detection",
            "metrics": [{"name": "average_precision"}],
        }
    )
    evaluator = build_task(task_config, targets=sample_targets)

    # Dummy predictions with linearly increasing detection scores.
    scored_detections = [
        create_detection(
            start_time=1 + 0.1 * position,
            detection_score=detection_score,
        )
        for position, detection_score in enumerate(np.linspace(0, 1, 100))
    ]
    clip_detections = ClipDetections(
        clip=clip,
        detections=scored_detections,
    )

    # Dummy ground truth: only half of the annotations are targets.
    annotations = [
        create_annotation(
            start_time=1 + 0.1 * position,
            is_target=position % 2 == 0,
        )
        for position in range(100)
    ]
    clip_annotation = data.ClipAnnotation(
        clip=clip,
        sound_events=annotations,
    )

    # Run the task and aggregate the per-clip results.
    clip_evals = evaluator.evaluate([clip_annotation], [clip_detections])
    metrics = evaluator.compute_metrics(clip_evals)

    assert metrics["detection/average_precision"] == pytest.approx(0.5)

View File

@ -14,7 +14,7 @@ from batdetect2.postprocess.decoding import (
get_generic_tags,
get_prediction_features,
)
from batdetect2.typing import RawPrediction, TargetProtocol
from batdetect2.typing import Detection, TargetProtocol
@pytest.fixture
@ -209,7 +209,7 @@ def empty_detection_dataset() -> xr.Dataset:
@pytest.fixture
def sample_raw_predictions() -> List[RawPrediction]:
def sample_raw_predictions() -> List[Detection]:
"""Manually crafted RawPrediction objects using the actual type."""
pred1_classes = xr.DataArray(
@ -220,7 +220,7 @@ def sample_raw_predictions() -> List[RawPrediction]:
coords={"feature": ["f0", "f1", "f2", "f3"]},
dims=["feature"],
)
pred1 = RawPrediction(
pred1 = Detection(
detection_score=0.9,
geometry=data.BoundingBox(
coordinates=[
@ -242,7 +242,7 @@ def sample_raw_predictions() -> List[RawPrediction]:
coords={"feature": ["f0", "f1", "f2", "f3"]},
dims=["feature"],
)
pred2 = RawPrediction(
pred2 = Detection(
detection_score=0.8,
geometry=data.BoundingBox(
coordinates=[
@ -264,7 +264,7 @@ def sample_raw_predictions() -> List[RawPrediction]:
coords={"feature": ["f0", "f1", "f2", "f3"]},
dims=["feature"],
)
pred3 = RawPrediction(
pred3 = Detection(
detection_score=0.15,
geometry=data.BoundingBox(
coordinates=[
@ -281,7 +281,7 @@ def sample_raw_predictions() -> List[RawPrediction]:
def test_convert_raw_to_sound_event_basic(
sample_raw_predictions: List[RawPrediction],
sample_raw_predictions: List[Detection],
sample_recording: data.Recording,
dummy_targets: TargetProtocol,
):
@ -324,7 +324,7 @@ def test_convert_raw_to_sound_event_basic(
def test_convert_raw_to_sound_event_thresholding(
sample_raw_predictions: List[RawPrediction],
sample_raw_predictions: List[Detection],
sample_recording: data.Recording,
dummy_targets: TargetProtocol,
):
@ -352,7 +352,7 @@ def test_convert_raw_to_sound_event_thresholding(
def test_convert_raw_to_sound_event_no_threshold(
sample_raw_predictions: List[RawPrediction],
sample_raw_predictions: List[Detection],
sample_recording: data.Recording,
dummy_targets: TargetProtocol,
):
@ -380,7 +380,7 @@ def test_convert_raw_to_sound_event_no_threshold(
def test_convert_raw_to_sound_event_top_class(
sample_raw_predictions: List[RawPrediction],
sample_raw_predictions: List[Detection],
sample_recording: data.Recording,
dummy_targets: TargetProtocol,
):
@ -407,7 +407,7 @@ def test_convert_raw_to_sound_event_top_class(
def test_convert_raw_to_sound_event_all_below_threshold(
sample_raw_predictions: List[RawPrediction],
sample_raw_predictions: List[Detection],
sample_recording: data.Recording,
dummy_targets: TargetProtocol,
):
@ -433,7 +433,7 @@ def test_convert_raw_to_sound_event_all_below_threshold(
def test_convert_raw_list_to_clip_basic(
sample_raw_predictions: List[RawPrediction],
sample_raw_predictions: List[Detection],
sample_clip: data.Clip,
dummy_targets: TargetProtocol,
):
@ -485,7 +485,7 @@ def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets):
def test_convert_raw_list_to_clip_passes_args(
sample_raw_predictions: List[RawPrediction],
sample_raw_predictions: List[Detection],
sample_clip: data.Clip,
dummy_targets: TargetProtocol,
):

View File

@ -1,4 +1,3 @@
import numpy as np
import pytest
import xarray as xr