mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Run formatter
This commit is contained in:
parent
531ff69974
commit
0adb58e039
@ -32,11 +32,11 @@ from batdetect2.train import (
|
|||||||
)
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
AudioLoader,
|
AudioLoader,
|
||||||
BatDetect2Prediction,
|
ClipDetections,
|
||||||
EvaluatorProtocol,
|
EvaluatorProtocol,
|
||||||
PostprocessorProtocol,
|
PostprocessorProtocol,
|
||||||
PreprocessorProtocol,
|
PreprocessorProtocol,
|
||||||
RawPrediction,
|
Detection,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -110,7 +110,7 @@ class BatDetect2API:
|
|||||||
experiment_name: str | None = None,
|
experiment_name: str | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
save_predictions: bool = True,
|
save_predictions: bool = True,
|
||||||
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
|
) -> Tuple[Dict[str, float], List[List[Detection]]]:
|
||||||
return evaluate(
|
return evaluate(
|
||||||
self.model,
|
self.model,
|
||||||
test_annotations,
|
test_annotations,
|
||||||
@ -128,7 +128,7 @@ class BatDetect2API:
|
|||||||
def evaluate_predictions(
|
def evaluate_predictions(
|
||||||
self,
|
self,
|
||||||
annotations: Sequence[data.ClipAnnotation],
|
annotations: Sequence[data.ClipAnnotation],
|
||||||
predictions: Sequence[BatDetect2Prediction],
|
predictions: Sequence[ClipDetections],
|
||||||
output_dir: data.PathLike | None = None,
|
output_dir: data.PathLike | None = None,
|
||||||
):
|
):
|
||||||
clip_evals = self.evaluator.evaluate(
|
clip_evals = self.evaluator.evaluate(
|
||||||
@ -170,24 +170,24 @@ class BatDetect2API:
|
|||||||
tensor = torch.tensor(audio).unsqueeze(0)
|
tensor = torch.tensor(audio).unsqueeze(0)
|
||||||
return self.preprocessor(tensor)
|
return self.preprocessor(tensor)
|
||||||
|
|
||||||
def process_file(self, audio_file: str) -> BatDetect2Prediction:
|
def process_file(self, audio_file: str) -> ClipDetections:
|
||||||
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
||||||
wav = self.audio_loader.load_recording(recording)
|
wav = self.audio_loader.load_recording(recording)
|
||||||
detections = self.process_audio(wav)
|
detections = self.process_audio(wav)
|
||||||
return BatDetect2Prediction(
|
return ClipDetections(
|
||||||
clip=data.Clip(
|
clip=data.Clip(
|
||||||
uuid=recording.uuid,
|
uuid=recording.uuid,
|
||||||
recording=recording,
|
recording=recording,
|
||||||
start_time=0,
|
start_time=0,
|
||||||
end_time=recording.duration,
|
end_time=recording.duration,
|
||||||
),
|
),
|
||||||
predictions=detections,
|
detections=detections,
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_audio(
|
def process_audio(
|
||||||
self,
|
self,
|
||||||
audio: np.ndarray,
|
audio: np.ndarray,
|
||||||
) -> List[RawPrediction]:
|
) -> List[Detection]:
|
||||||
spec = self.generate_spectrogram(audio)
|
spec = self.generate_spectrogram(audio)
|
||||||
return self.process_spectrogram(spec)
|
return self.process_spectrogram(spec)
|
||||||
|
|
||||||
@ -195,7 +195,7 @@ class BatDetect2API:
|
|||||||
self,
|
self,
|
||||||
spec: torch.Tensor,
|
spec: torch.Tensor,
|
||||||
start_time: float = 0,
|
start_time: float = 0,
|
||||||
) -> List[RawPrediction]:
|
) -> List[Detection]:
|
||||||
if spec.ndim == 4 and spec.shape[0] > 1:
|
if spec.ndim == 4 and spec.shape[0] > 1:
|
||||||
raise ValueError("Batched spectrograms not supported.")
|
raise ValueError("Batched spectrograms not supported.")
|
||||||
|
|
||||||
@ -214,7 +214,7 @@ class BatDetect2API:
|
|||||||
def process_directory(
|
def process_directory(
|
||||||
self,
|
self,
|
||||||
audio_dir: data.PathLike,
|
audio_dir: data.PathLike,
|
||||||
) -> List[BatDetect2Prediction]:
|
) -> List[ClipDetections]:
|
||||||
files = list(get_audio_files(audio_dir))
|
files = list(get_audio_files(audio_dir))
|
||||||
return self.process_files(files)
|
return self.process_files(files)
|
||||||
|
|
||||||
@ -222,7 +222,7 @@ class BatDetect2API:
|
|||||||
self,
|
self,
|
||||||
audio_files: Sequence[data.PathLike],
|
audio_files: Sequence[data.PathLike],
|
||||||
num_workers: int | None = None,
|
num_workers: int | None = None,
|
||||||
) -> List[BatDetect2Prediction]:
|
) -> List[ClipDetections]:
|
||||||
return process_file_list(
|
return process_file_list(
|
||||||
self.model,
|
self.model,
|
||||||
audio_files,
|
audio_files,
|
||||||
@ -238,7 +238,7 @@ class BatDetect2API:
|
|||||||
clips: Sequence[data.Clip],
|
clips: Sequence[data.Clip],
|
||||||
batch_size: int | None = None,
|
batch_size: int | None = None,
|
||||||
num_workers: int | None = None,
|
num_workers: int | None = None,
|
||||||
) -> List[BatDetect2Prediction]:
|
) -> List[ClipDetections]:
|
||||||
return run_batch_inference(
|
return run_batch_inference(
|
||||||
self.model,
|
self.model,
|
||||||
clips,
|
clips,
|
||||||
@ -252,7 +252,7 @@ class BatDetect2API:
|
|||||||
|
|
||||||
def save_predictions(
|
def save_predictions(
|
||||||
self,
|
self,
|
||||||
predictions: Sequence[BatDetect2Prediction],
|
predictions: Sequence[ClipDetections],
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
audio_dir: data.PathLike | None = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
format: str | None = None,
|
format: str | None = None,
|
||||||
@ -274,7 +274,7 @@ class BatDetect2API:
|
|||||||
def load_predictions(
|
def load_predictions(
|
||||||
self,
|
self,
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
) -> List[BatDetect2Prediction]:
|
) -> List[ClipDetections]:
|
||||||
return self.formatter.load(path)
|
return self.formatter.load(path)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import DTypeLike
|
from numpy.typing import DTypeLike
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
|
|||||||
@ -264,7 +264,14 @@ class Not:
|
|||||||
|
|
||||||
|
|
||||||
SoundEventConditionConfig = Annotated[
|
SoundEventConditionConfig = Annotated[
|
||||||
HasTagConfig | HasAllTagsConfig | HasAnyTagConfig | DurationConfig | FrequencyConfig | AllOfConfig | AnyOfConfig | NotConfig,
|
HasTagConfig
|
||||||
|
| HasAllTagsConfig
|
||||||
|
| HasAnyTagConfig
|
||||||
|
| DurationConfig
|
||||||
|
| FrequencyConfig
|
||||||
|
| AllOfConfig
|
||||||
|
| AnyOfConfig
|
||||||
|
| NotConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -13,9 +13,9 @@ from batdetect2.data.predictions.base import (
|
|||||||
)
|
)
|
||||||
from batdetect2.targets import terms
|
from batdetect2.targets import terms
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
BatDetect2Prediction,
|
ClipDetections,
|
||||||
OutputFormatterProtocol,
|
OutputFormatterProtocol,
|
||||||
RawPrediction,
|
Detection,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -113,7 +113,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
self.annotation_note = annotation_note
|
self.annotation_note = annotation_note
|
||||||
|
|
||||||
def format(
|
def format(
|
||||||
self, predictions: Sequence[BatDetect2Prediction]
|
self, predictions: Sequence[ClipDetections]
|
||||||
) -> List[FileAnnotation]:
|
) -> List[FileAnnotation]:
|
||||||
return [
|
return [
|
||||||
self.format_prediction(prediction) for prediction in predictions
|
self.format_prediction(prediction) for prediction in predictions
|
||||||
@ -164,14 +164,12 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
highest_scoring = max(annotations, key=lambda x: x["class_prob"])
|
highest_scoring = max(annotations, key=lambda x: x["class_prob"])
|
||||||
return highest_scoring["class"]
|
return highest_scoring["class"]
|
||||||
|
|
||||||
def format_prediction(
|
def format_prediction(self, prediction: ClipDetections) -> FileAnnotation:
|
||||||
self, prediction: BatDetect2Prediction
|
|
||||||
) -> FileAnnotation:
|
|
||||||
recording = prediction.clip.recording
|
recording = prediction.clip.recording
|
||||||
|
|
||||||
annotations = [
|
annotations = [
|
||||||
self.format_sound_event_prediction(pred)
|
self.format_sound_event_prediction(pred)
|
||||||
for pred in prediction.predictions
|
for pred in prediction.detections
|
||||||
]
|
]
|
||||||
|
|
||||||
return FileAnnotation(
|
return FileAnnotation(
|
||||||
@ -196,7 +194,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
def format_sound_event_prediction(
|
def format_sound_event_prediction(
|
||||||
self, prediction: RawPrediction
|
self, prediction: Detection
|
||||||
) -> Annotation:
|
) -> Annotation:
|
||||||
start_time, low_freq, end_time, high_freq = compute_bounds(
|
start_time, low_freq, end_time, high_freq = compute_bounds(
|
||||||
prediction.geometry
|
prediction.geometry
|
||||||
|
|||||||
@ -14,9 +14,9 @@ from batdetect2.data.predictions.base import (
|
|||||||
prediction_formatters,
|
prediction_formatters,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
BatDetect2Prediction,
|
ClipDetections,
|
||||||
OutputFormatterProtocol,
|
OutputFormatterProtocol,
|
||||||
RawPrediction,
|
Detection,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -29,7 +29,7 @@ class ParquetOutputConfig(BaseConfig):
|
|||||||
include_geometry: bool = True
|
include_geometry: bool = True
|
||||||
|
|
||||||
|
|
||||||
class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
@ -44,13 +44,13 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
|||||||
|
|
||||||
def format(
|
def format(
|
||||||
self,
|
self,
|
||||||
predictions: Sequence[BatDetect2Prediction],
|
predictions: Sequence[ClipDetections],
|
||||||
) -> List[BatDetect2Prediction]:
|
) -> List[ClipDetections]:
|
||||||
return list(predictions)
|
return list(predictions)
|
||||||
|
|
||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
predictions: Sequence[BatDetect2Prediction],
|
predictions: Sequence[ClipDetections],
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
audio_dir: data.PathLike | None = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -73,12 +73,14 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
|||||||
|
|
||||||
if audio_dir is not None:
|
if audio_dir is not None:
|
||||||
recording = recording.model_copy(
|
recording = recording.model_copy(
|
||||||
update=dict(path=make_path_relative(recording.path, audio_dir))
|
update=dict(
|
||||||
|
path=make_path_relative(recording.path, audio_dir)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
recording_json = recording.model_dump_json(exclude_none=True)
|
recording_json = recording.model_dump_json(exclude_none=True)
|
||||||
|
|
||||||
for pred in prediction.predictions:
|
for pred in prediction.detections:
|
||||||
row = {
|
row = {
|
||||||
"clip_uuid": str(clip.uuid),
|
"clip_uuid": str(clip.uuid),
|
||||||
"clip_start_time": clip.start_time,
|
"clip_start_time": clip.start_time,
|
||||||
@ -116,7 +118,7 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
|||||||
logger.info(f"Saving {len(df)} predictions to {path}")
|
logger.info(f"Saving {len(df)} predictions to {path}")
|
||||||
df.to_parquet(path, index=False)
|
df.to_parquet(path, index=False)
|
||||||
|
|
||||||
def load(self, path: data.PathLike) -> List[BatDetect2Prediction]:
|
def load(self, path: data.PathLike) -> List[ClipDetections]:
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
if path.is_dir():
|
if path.is_dir():
|
||||||
# Try to find parquet files
|
# Try to find parquet files
|
||||||
@ -135,17 +137,16 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
|||||||
clip_uuid = row["clip_uuid"]
|
clip_uuid = row["clip_uuid"]
|
||||||
|
|
||||||
if clip_uuid not in predictions_by_clip:
|
if clip_uuid not in predictions_by_clip:
|
||||||
recording = data.Recording.model_validate_json(row["recording_info"])
|
recording = data.Recording.model_validate_json(
|
||||||
|
row["recording_info"]
|
||||||
|
)
|
||||||
clip = data.Clip(
|
clip = data.Clip(
|
||||||
uuid=UUID(clip_uuid),
|
uuid=UUID(clip_uuid),
|
||||||
recording=recording,
|
recording=recording,
|
||||||
start_time=row["clip_start_time"],
|
start_time=row["clip_start_time"],
|
||||||
end_time=row["clip_end_time"],
|
end_time=row["clip_end_time"],
|
||||||
)
|
)
|
||||||
predictions_by_clip[clip_uuid] = {
|
predictions_by_clip[clip_uuid] = {"clip": clip, "preds": []}
|
||||||
"clip": clip,
|
|
||||||
"preds": []
|
|
||||||
}
|
|
||||||
|
|
||||||
# Reconstruct geometry
|
# Reconstruct geometry
|
||||||
if "geometry" in row and row["geometry"]:
|
if "geometry" in row and row["geometry"]:
|
||||||
@ -156,14 +157,20 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
|||||||
row["start_time"],
|
row["start_time"],
|
||||||
row["low_freq"],
|
row["low_freq"],
|
||||||
row["end_time"],
|
row["end_time"],
|
||||||
row["high_freq"]
|
row["high_freq"],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
class_scores = np.array(row["class_scores"]) if "class_scores" in row else np.zeros(len(self.targets.class_names))
|
class_scores = (
|
||||||
features = np.array(row["features"]) if "features" in row else np.zeros(0)
|
np.array(row["class_scores"])
|
||||||
|
if "class_scores" in row
|
||||||
|
else np.zeros(len(self.targets.class_names))
|
||||||
|
)
|
||||||
|
features = (
|
||||||
|
np.array(row["features"]) if "features" in row else np.zeros(0)
|
||||||
|
)
|
||||||
|
|
||||||
pred = RawPrediction(
|
pred = Detection(
|
||||||
geometry=geometry,
|
geometry=geometry,
|
||||||
detection_score=row["detection_score"],
|
detection_score=row["detection_score"],
|
||||||
class_scores=class_scores,
|
class_scores=class_scores,
|
||||||
@ -174,9 +181,8 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
|||||||
results = []
|
results = []
|
||||||
for clip_data in predictions_by_clip.values():
|
for clip_data in predictions_by_clip.values():
|
||||||
results.append(
|
results.append(
|
||||||
BatDetect2Prediction(
|
ClipDetections(
|
||||||
clip=clip_data["clip"],
|
clip=clip_data["clip"], detections=clip_data["preds"]
|
||||||
predictions=clip_data["preds"]
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -15,9 +15,9 @@ from batdetect2.data.predictions.base import (
|
|||||||
prediction_formatters,
|
prediction_formatters,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
BatDetect2Prediction,
|
ClipDetections,
|
||||||
OutputFormatterProtocol,
|
OutputFormatterProtocol,
|
||||||
RawPrediction,
|
Detection,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -30,7 +30,7 @@ class RawOutputConfig(BaseConfig):
|
|||||||
include_geometry: bool = True
|
include_geometry: bool = True
|
||||||
|
|
||||||
|
|
||||||
class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
class RawFormatter(OutputFormatterProtocol[ClipDetections]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
@ -47,13 +47,13 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
|||||||
|
|
||||||
def format(
|
def format(
|
||||||
self,
|
self,
|
||||||
predictions: Sequence[BatDetect2Prediction],
|
predictions: Sequence[ClipDetections],
|
||||||
) -> List[BatDetect2Prediction]:
|
) -> List[ClipDetections]:
|
||||||
return list(predictions)
|
return list(predictions)
|
||||||
|
|
||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
predictions: Sequence[BatDetect2Prediction],
|
predictions: Sequence[ClipDetections],
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
audio_dir: data.PathLike | None = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -68,10 +68,10 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
|||||||
dataset = self.pred_to_xr(prediction, audio_dir)
|
dataset = self.pred_to_xr(prediction, audio_dir)
|
||||||
dataset.to_netcdf(path / f"{clip.uuid}.nc")
|
dataset.to_netcdf(path / f"{clip.uuid}.nc")
|
||||||
|
|
||||||
def load(self, path: data.PathLike) -> List[BatDetect2Prediction]:
|
def load(self, path: data.PathLike) -> List[ClipDetections]:
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
files = list(path.glob("*.nc"))
|
files = list(path.glob("*.nc"))
|
||||||
predictions: List[BatDetect2Prediction] = []
|
predictions: List[ClipDetections] = []
|
||||||
|
|
||||||
for filepath in files:
|
for filepath in files:
|
||||||
logger.debug(f"Loading clip predictions {filepath}")
|
logger.debug(f"Loading clip predictions {filepath}")
|
||||||
@ -83,7 +83,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
|||||||
|
|
||||||
def pred_to_xr(
|
def pred_to_xr(
|
||||||
self,
|
self,
|
||||||
prediction: BatDetect2Prediction,
|
prediction: ClipDetections,
|
||||||
audio_dir: data.PathLike | None = None,
|
audio_dir: data.PathLike | None = None,
|
||||||
) -> xr.Dataset:
|
) -> xr.Dataset:
|
||||||
clip = prediction.clip
|
clip = prediction.clip
|
||||||
@ -97,7 +97,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
|||||||
|
|
||||||
data = defaultdict(list)
|
data = defaultdict(list)
|
||||||
|
|
||||||
for pred in prediction.predictions:
|
for pred in prediction.detections:
|
||||||
detection_id = str(uuid4())
|
detection_id = str(uuid4())
|
||||||
|
|
||||||
data["detection_id"].append(detection_id)
|
data["detection_id"].append(detection_id)
|
||||||
@ -167,7 +167,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def pred_from_xr(self, dataset: xr.Dataset) -> BatDetect2Prediction:
|
def pred_from_xr(self, dataset: xr.Dataset) -> ClipDetections:
|
||||||
clip_data = dataset
|
clip_data = dataset
|
||||||
clip_id = clip_data.clip_id.item()
|
clip_id = clip_data.clip_id.item()
|
||||||
|
|
||||||
@ -219,7 +219,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
|||||||
features = np.zeros(0)
|
features = np.zeros(0)
|
||||||
|
|
||||||
sound_events.append(
|
sound_events.append(
|
||||||
RawPrediction(
|
Detection(
|
||||||
geometry=geometry,
|
geometry=geometry,
|
||||||
detection_score=score,
|
detection_score=score,
|
||||||
class_scores=class_scores,
|
class_scores=class_scores,
|
||||||
@ -227,9 +227,9 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return BatDetect2Prediction(
|
return ClipDetections(
|
||||||
clip=clip,
|
clip=clip,
|
||||||
predictions=sound_events,
|
detections=sound_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
@prediction_formatters.register(RawOutputConfig)
|
@prediction_formatters.register(RawOutputConfig)
|
||||||
|
|||||||
@ -9,9 +9,9 @@ from batdetect2.data.predictions.base import (
|
|||||||
prediction_formatters,
|
prediction_formatters,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
BatDetect2Prediction,
|
ClipDetections,
|
||||||
OutputFormatterProtocol,
|
OutputFormatterProtocol,
|
||||||
RawPrediction,
|
Detection,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -35,7 +35,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
|
|||||||
|
|
||||||
def format(
|
def format(
|
||||||
self,
|
self,
|
||||||
predictions: Sequence[BatDetect2Prediction],
|
predictions: Sequence[ClipDetections],
|
||||||
) -> List[data.ClipPrediction]:
|
) -> List[data.ClipPrediction]:
|
||||||
return [
|
return [
|
||||||
self.format_prediction(prediction) for prediction in predictions
|
self.format_prediction(prediction) for prediction in predictions
|
||||||
@ -63,20 +63,20 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
|
|||||||
|
|
||||||
def format_prediction(
|
def format_prediction(
|
||||||
self,
|
self,
|
||||||
prediction: BatDetect2Prediction,
|
prediction: ClipDetections,
|
||||||
) -> data.ClipPrediction:
|
) -> data.ClipPrediction:
|
||||||
recording = prediction.clip.recording
|
recording = prediction.clip.recording
|
||||||
return data.ClipPrediction(
|
return data.ClipPrediction(
|
||||||
clip=prediction.clip,
|
clip=prediction.clip,
|
||||||
sound_events=[
|
sound_events=[
|
||||||
self.format_sound_event_prediction(pred, recording)
|
self.format_sound_event_prediction(pred, recording)
|
||||||
for pred in prediction.predictions
|
for pred in prediction.detections
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def format_sound_event_prediction(
|
def format_sound_event_prediction(
|
||||||
self,
|
self,
|
||||||
prediction: RawPrediction,
|
prediction: Detection,
|
||||||
recording: data.Recording,
|
recording: data.Recording,
|
||||||
) -> data.SoundEventPrediction:
|
) -> data.SoundEventPrediction:
|
||||||
return data.SoundEventPrediction(
|
return data.SoundEventPrediction(
|
||||||
@ -89,7 +89,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_sound_event_tags(
|
def get_sound_event_tags(
|
||||||
self, prediction: RawPrediction
|
self, prediction: Detection
|
||||||
) -> List[data.PredictedTag]:
|
) -> List[data.PredictedTag]:
|
||||||
sorted_indices = np.argsort(prediction.class_scores)[::-1]
|
sorted_indices = np.argsort(prediction.class_scores)[::-1]
|
||||||
|
|
||||||
|
|||||||
@ -221,7 +221,11 @@ class ApplyAll:
|
|||||||
|
|
||||||
|
|
||||||
SoundEventTransformConfig = Annotated[
|
SoundEventTransformConfig = Annotated[
|
||||||
SetFrequencyBoundConfig | ReplaceTagConfig | MapTagValueConfig | ApplyIfConfig | ApplyAllConfig,
|
SetFrequencyBoundConfig
|
||||||
|
| ReplaceTagConfig
|
||||||
|
| MapTagValueConfig
|
||||||
|
| ApplyIfConfig
|
||||||
|
| ApplyAllConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from soundevent.geometry import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from batdetect2.core import BaseConfig, Registry
|
from batdetect2.core import BaseConfig, Registry
|
||||||
from batdetect2.typing import AffinityFunction, RawPrediction
|
from batdetect2.typing import AffinityFunction, Detection
|
||||||
|
|
||||||
affinity_functions: Registry[AffinityFunction, []] = Registry(
|
affinity_functions: Registry[AffinityFunction, []] = Registry(
|
||||||
"affinity_function"
|
"affinity_function"
|
||||||
@ -42,7 +42,7 @@ class TimeAffinity(AffinityFunction):
|
|||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
detection: RawPrediction,
|
detection: Detection,
|
||||||
ground_truth: data.SoundEventAnnotation,
|
ground_truth: data.SoundEventAnnotation,
|
||||||
) -> float:
|
) -> float:
|
||||||
target_geometry = ground_truth.sound_event.geometry
|
target_geometry = ground_truth.sound_event.geometry
|
||||||
@ -77,7 +77,7 @@ class IntervalIOU(AffinityFunction):
|
|||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
detection: RawPrediction,
|
detection: Detection,
|
||||||
ground_truth: data.SoundEventAnnotation,
|
ground_truth: data.SoundEventAnnotation,
|
||||||
) -> float:
|
) -> float:
|
||||||
target_geometry = ground_truth.sound_event.geometry
|
target_geometry = ground_truth.sound_event.geometry
|
||||||
@ -120,7 +120,7 @@ class BBoxIOU(AffinityFunction):
|
|||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
prediction: RawPrediction,
|
prediction: Detection,
|
||||||
gt: data.SoundEventAnnotation,
|
gt: data.SoundEventAnnotation,
|
||||||
):
|
):
|
||||||
target_geometry = gt.sound_event.geometry
|
target_geometry = gt.sound_event.geometry
|
||||||
@ -168,7 +168,7 @@ class GeometricIOU(AffinityFunction):
|
|||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
prediction: RawPrediction,
|
prediction: Detection,
|
||||||
gt: data.SoundEventAnnotation,
|
gt: data.SoundEventAnnotation,
|
||||||
):
|
):
|
||||||
target_geometry = gt.sound_event.geometry
|
target_geometry = gt.sound_event.geometry
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from batdetect2.logging import build_logger
|
|||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.typing.postprocess import RawPrediction
|
from batdetect2.typing import Detection
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
@ -38,7 +38,7 @@ def evaluate(
|
|||||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||||
experiment_name: str | None = None,
|
experiment_name: str | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
|
) -> Tuple[Dict[str, float], List[List[Detection]]]:
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
|
|
||||||
config = config or BatDetect2Config()
|
config = config or BatDetect2Config()
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from batdetect2.evaluate.config import EvaluationConfig
|
|||||||
from batdetect2.evaluate.tasks import build_task
|
from batdetect2.evaluate.tasks import build_task
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
|
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
|
||||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
from batdetect2.typing.postprocess import ClipDetections
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Evaluator",
|
"Evaluator",
|
||||||
@ -27,7 +27,7 @@ class Evaluator:
|
|||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
predictions: Sequence[BatDetect2Prediction],
|
predictions: Sequence[ClipDetections],
|
||||||
) -> List[Any]:
|
) -> List[Any]:
|
||||||
return [
|
return [
|
||||||
task.evaluate(clip_annotations, predictions) for task in self.tasks
|
task.evaluate(clip_annotations, predictions) for task in self.tasks
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from batdetect2.logging import get_image_logger
|
|||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.postprocess import to_raw_predictions
|
from batdetect2.postprocess import to_raw_predictions
|
||||||
from batdetect2.typing import EvaluatorProtocol
|
from batdetect2.typing import EvaluatorProtocol
|
||||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
from batdetect2.typing.postprocess import ClipDetections
|
||||||
|
|
||||||
|
|
||||||
class EvaluationModule(LightningModule):
|
class EvaluationModule(LightningModule):
|
||||||
@ -24,7 +24,7 @@ class EvaluationModule(LightningModule):
|
|||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
|
|
||||||
self.clip_annotations: List[data.ClipAnnotation] = []
|
self.clip_annotations: List[data.ClipAnnotation] = []
|
||||||
self.predictions: List[BatDetect2Prediction] = []
|
self.predictions: List[ClipDetections] = []
|
||||||
|
|
||||||
def test_step(self, batch: TestExample, batch_idx: int):
|
def test_step(self, batch: TestExample, batch_idx: int):
|
||||||
dataset = self.get_dataset()
|
dataset = self.get_dataset()
|
||||||
@ -39,9 +39,9 @@ class EvaluationModule(LightningModule):
|
|||||||
start_times=[ca.clip.start_time for ca in clip_annotations],
|
start_times=[ca.clip.start_time for ca in clip_annotations],
|
||||||
)
|
)
|
||||||
predictions = [
|
predictions = [
|
||||||
BatDetect2Prediction(
|
ClipDetections(
|
||||||
clip=clip_annotation.clip,
|
clip=clip_annotation.clip,
|
||||||
predictions=to_raw_predictions(
|
detections=to_raw_predictions(
|
||||||
clip_dets.numpy(),
|
clip_dets.numpy(),
|
||||||
targets=self.evaluator.targets,
|
targets=self.evaluator.targets,
|
||||||
),
|
),
|
||||||
|
|||||||
@ -21,7 +21,7 @@ from batdetect2.evaluate.metrics.common import (
|
|||||||
average_precision,
|
average_precision,
|
||||||
compute_precision_recall,
|
compute_precision_recall,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import RawPrediction, TargetProtocol
|
from batdetect2.typing import Detection, TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ClassificationMetric",
|
"ClassificationMetric",
|
||||||
@ -35,7 +35,7 @@ __all__ = [
|
|||||||
class MatchEval:
|
class MatchEval:
|
||||||
clip: data.Clip
|
clip: data.Clip
|
||||||
gt: data.SoundEventAnnotation | None
|
gt: data.SoundEventAnnotation | None
|
||||||
pred: RawPrediction | None
|
pred: Detection | None
|
||||||
|
|
||||||
is_prediction: bool
|
is_prediction: bool
|
||||||
is_ground_truth: bool
|
is_ground_truth: bool
|
||||||
|
|||||||
@ -159,7 +159,10 @@ class ClipDetectionPrecision:
|
|||||||
|
|
||||||
|
|
||||||
ClipDetectionMetricConfig = Annotated[
|
ClipDetectionMetricConfig = Annotated[
|
||||||
ClipDetectionAveragePrecisionConfig | ClipDetectionROCAUCConfig | ClipDetectionRecallConfig | ClipDetectionPrecisionConfig,
|
ClipDetectionAveragePrecisionConfig
|
||||||
|
| ClipDetectionROCAUCConfig
|
||||||
|
| ClipDetectionRecallConfig
|
||||||
|
| ClipDetectionPrecisionConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from soundevent import data
|
|||||||
|
|
||||||
from batdetect2.core import BaseConfig, Registry
|
from batdetect2.core import BaseConfig, Registry
|
||||||
from batdetect2.evaluate.metrics.common import average_precision
|
from batdetect2.evaluate.metrics.common import average_precision
|
||||||
from batdetect2.typing import RawPrediction
|
from batdetect2.typing import Detection
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DetectionMetricConfig",
|
"DetectionMetricConfig",
|
||||||
@ -27,7 +27,7 @@ __all__ = [
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MatchEval:
|
class MatchEval:
|
||||||
gt: data.SoundEventAnnotation | None
|
gt: data.SoundEventAnnotation | None
|
||||||
pred: RawPrediction | None
|
pred: Detection | None
|
||||||
|
|
||||||
is_prediction: bool
|
is_prediction: bool
|
||||||
is_ground_truth: bool
|
is_ground_truth: bool
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from soundevent import data
|
|||||||
|
|
||||||
from batdetect2.core import BaseConfig, Registry
|
from batdetect2.core import BaseConfig, Registry
|
||||||
from batdetect2.evaluate.metrics.common import average_precision
|
from batdetect2.evaluate.metrics.common import average_precision
|
||||||
from batdetect2.typing import RawPrediction
|
from batdetect2.typing import Detection
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -29,7 +29,7 @@ __all__ = [
|
|||||||
class MatchEval:
|
class MatchEval:
|
||||||
clip: data.Clip
|
clip: data.Clip
|
||||||
gt: data.SoundEventAnnotation | None
|
gt: data.SoundEventAnnotation | None
|
||||||
pred: RawPrediction | None
|
pred: Detection | None
|
||||||
|
|
||||||
is_ground_truth: bool
|
is_ground_truth: bool
|
||||||
is_generic: bool
|
is_generic: bool
|
||||||
@ -298,7 +298,11 @@ class BalancedAccuracy:
|
|||||||
|
|
||||||
|
|
||||||
TopClassMetricConfig = Annotated[
|
TopClassMetricConfig = Annotated[
|
||||||
TopClassAveragePrecisionConfig | TopClassROCAUCConfig | TopClassRecallConfig | TopClassPrecisionConfig | BalancedAccuracyConfig,
|
TopClassAveragePrecisionConfig
|
||||||
|
| TopClassROCAUCConfig
|
||||||
|
| TopClassRecallConfig
|
||||||
|
| TopClassPrecisionConfig
|
||||||
|
| BalancedAccuracyConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
|
|
||||||
|
|||||||
@ -322,7 +322,10 @@ class ROCCurve(BasePlot):
|
|||||||
|
|
||||||
|
|
||||||
ClassificationPlotConfig = Annotated[
|
ClassificationPlotConfig = Annotated[
|
||||||
PRCurveConfig | ROCCurveConfig | ThresholdPrecisionCurveConfig | ThresholdRecallCurveConfig,
|
PRCurveConfig
|
||||||
|
| ROCCurveConfig
|
||||||
|
| ThresholdPrecisionCurveConfig
|
||||||
|
| ThresholdRecallCurveConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -290,7 +290,10 @@ class ExampleDetectionPlot(BasePlot):
|
|||||||
|
|
||||||
|
|
||||||
DetectionPlotConfig = Annotated[
|
DetectionPlotConfig = Annotated[
|
||||||
PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig | ExampleDetectionPlotConfig,
|
PRCurveConfig
|
||||||
|
| ROCCurveConfig
|
||||||
|
| ScoreDistributionPlotConfig
|
||||||
|
| ExampleDetectionPlotConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -346,7 +346,10 @@ class ExampleClassificationPlot(BasePlot):
|
|||||||
|
|
||||||
|
|
||||||
TopClassPlotConfig = Annotated[
|
TopClassPlotConfig = Annotated[
|
||||||
PRCurveConfig | ROCCurveConfig | ConfusionMatrixConfig | ExampleClassificationPlotConfig,
|
PRCurveConfig
|
||||||
|
| ROCCurveConfig
|
||||||
|
| ConfusionMatrixConfig
|
||||||
|
| ExampleClassificationPlotConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -403,7 +406,8 @@ def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
|
|||||||
return matches
|
return matches
|
||||||
|
|
||||||
indices, pred_scores = zip(
|
indices, pred_scores = zip(
|
||||||
*[(index, match.score) for index, match in enumerate(matches)], strict=False
|
*[(index, match.score) for index, match in enumerate(matches)],
|
||||||
|
strict=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
|
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
|||||||
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
|
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
BatDetect2Prediction,
|
ClipDetections,
|
||||||
EvaluatorProtocol,
|
EvaluatorProtocol,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
@ -45,7 +45,7 @@ def build_task(
|
|||||||
|
|
||||||
def evaluate_task(
|
def evaluate_task(
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
predictions: Sequence[BatDetect2Prediction],
|
predictions: Sequence[ClipDetections],
|
||||||
task: Optional["str"] = None,
|
task: Optional["str"] = None,
|
||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
config: TaskConfig | dict | None = None,
|
config: TaskConfig | dict | None = None,
|
||||||
|
|||||||
@ -22,9 +22,9 @@ from batdetect2.evaluate.affinity import (
|
|||||||
)
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
AffinityFunction,
|
AffinityFunction,
|
||||||
BatDetect2Prediction,
|
ClipDetections,
|
||||||
EvaluatorProtocol,
|
EvaluatorProtocol,
|
||||||
RawPrediction,
|
Detection,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -96,7 +96,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
|||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
predictions: Sequence[BatDetect2Prediction],
|
predictions: Sequence[ClipDetections],
|
||||||
) -> List[T_Output]:
|
) -> List[T_Output]:
|
||||||
return [
|
return [
|
||||||
self.evaluate_clip(clip_annotation, preds)
|
self.evaluate_clip(clip_annotation, preds)
|
||||||
@ -108,7 +108,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
|||||||
def evaluate_clip(
|
def evaluate_clip(
|
||||||
self,
|
self,
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
prediction: BatDetect2Prediction,
|
prediction: ClipDetections,
|
||||||
) -> T_Output: ...
|
) -> T_Output: ...
|
||||||
|
|
||||||
def include_sound_event_annotation(
|
def include_sound_event_annotation(
|
||||||
@ -128,7 +128,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
|||||||
|
|
||||||
def include_prediction(
|
def include_prediction(
|
||||||
self,
|
self,
|
||||||
prediction: RawPrediction,
|
prediction: Detection,
|
||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
return is_in_bounds(
|
return is_in_bounds(
|
||||||
|
|||||||
@ -22,8 +22,8 @@ from batdetect2.evaluate.tasks.base import (
|
|||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
BatDetect2Prediction,
|
ClipDetections,
|
||||||
RawPrediction,
|
Detection,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -51,13 +51,13 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
|
|||||||
def evaluate_clip(
|
def evaluate_clip(
|
||||||
self,
|
self,
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
prediction: BatDetect2Prediction,
|
prediction: ClipDetections,
|
||||||
) -> ClipEval:
|
) -> ClipEval:
|
||||||
clip = clip_annotation.clip
|
clip = clip_annotation.clip
|
||||||
|
|
||||||
preds = [
|
preds = [
|
||||||
pred
|
pred
|
||||||
for pred in prediction.predictions
|
for pred in prediction.detections
|
||||||
if self.include_prediction(pred, clip)
|
if self.include_prediction(pred, clip)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -140,7 +140,7 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_class_score(pred: RawPrediction, class_idx: int) -> float:
|
def get_class_score(pred: Detection, class_idx: int) -> float:
|
||||||
return pred.class_scores[class_idx]
|
return pred.class_scores[class_idx]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from batdetect2.evaluate.tasks.base import (
|
|||||||
BaseTaskConfig,
|
BaseTaskConfig,
|
||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
|
from batdetect2.typing import ClipDetections, TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
class ClipClassificationTaskConfig(BaseTaskConfig):
|
class ClipClassificationTaskConfig(BaseTaskConfig):
|
||||||
@ -37,7 +37,7 @@ class ClipClassificationTask(BaseTask[ClipEval]):
|
|||||||
def evaluate_clip(
|
def evaluate_clip(
|
||||||
self,
|
self,
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
prediction: BatDetect2Prediction,
|
prediction: ClipDetections,
|
||||||
) -> ClipEval:
|
) -> ClipEval:
|
||||||
clip = clip_annotation.clip
|
clip = clip_annotation.clip
|
||||||
|
|
||||||
@ -54,7 +54,7 @@ class ClipClassificationTask(BaseTask[ClipEval]):
|
|||||||
gt_classes.add(class_name)
|
gt_classes.add(class_name)
|
||||||
|
|
||||||
pred_scores = defaultdict(float)
|
pred_scores = defaultdict(float)
|
||||||
for pred in prediction.predictions:
|
for pred in prediction.detections:
|
||||||
if not self.include_prediction(pred, clip):
|
if not self.include_prediction(pred, clip):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from batdetect2.evaluate.tasks.base import (
|
|||||||
BaseTaskConfig,
|
BaseTaskConfig,
|
||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
|
from batdetect2.typing import ClipDetections, TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
class ClipDetectionTaskConfig(BaseTaskConfig):
|
class ClipDetectionTaskConfig(BaseTaskConfig):
|
||||||
@ -36,7 +36,7 @@ class ClipDetectionTask(BaseTask[ClipEval]):
|
|||||||
def evaluate_clip(
|
def evaluate_clip(
|
||||||
self,
|
self,
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
prediction: BatDetect2Prediction,
|
prediction: ClipDetections,
|
||||||
) -> ClipEval:
|
) -> ClipEval:
|
||||||
clip = clip_annotation.clip
|
clip = clip_annotation.clip
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ class ClipDetectionTask(BaseTask[ClipEval]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
pred_score = 0
|
pred_score = 0
|
||||||
for pred in prediction.predictions:
|
for pred in prediction.detections:
|
||||||
if not self.include_prediction(pred, clip):
|
if not self.include_prediction(pred, clip):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@ -21,7 +21,7 @@ from batdetect2.evaluate.tasks.base import (
|
|||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import TargetProtocol
|
from batdetect2.typing import TargetProtocol
|
||||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
from batdetect2.typing.postprocess import ClipDetections
|
||||||
|
|
||||||
|
|
||||||
class DetectionTaskConfig(BaseSEDTaskConfig):
|
class DetectionTaskConfig(BaseSEDTaskConfig):
|
||||||
@ -37,7 +37,7 @@ class DetectionTask(BaseSEDTask[ClipEval]):
|
|||||||
def evaluate_clip(
|
def evaluate_clip(
|
||||||
self,
|
self,
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
prediction: BatDetect2Prediction,
|
prediction: ClipDetections,
|
||||||
) -> ClipEval:
|
) -> ClipEval:
|
||||||
clip = clip_annotation.clip
|
clip = clip_annotation.clip
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ class DetectionTask(BaseSEDTask[ClipEval]):
|
|||||||
]
|
]
|
||||||
preds = [
|
preds = [
|
||||||
pred
|
pred
|
||||||
for pred in prediction.predictions
|
for pred in prediction.detections
|
||||||
if self.include_prediction(pred, clip)
|
if self.include_prediction(pred, clip)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@ from batdetect2.evaluate.tasks.base import (
|
|||||||
BaseSEDTaskConfig,
|
BaseSEDTaskConfig,
|
||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
|
from batdetect2.typing import ClipDetections, TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
|
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
|
||||||
@ -36,7 +36,7 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
|
|||||||
def evaluate_clip(
|
def evaluate_clip(
|
||||||
self,
|
self,
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
prediction: BatDetect2Prediction,
|
prediction: ClipDetections,
|
||||||
) -> ClipEval:
|
) -> ClipEval:
|
||||||
clip = clip_annotation.clip
|
clip = clip_annotation.clip
|
||||||
|
|
||||||
@ -47,7 +47,7 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
|
|||||||
]
|
]
|
||||||
preds = [
|
preds = [
|
||||||
pred
|
pred
|
||||||
for pred in prediction.predictions
|
for pred in prediction.detections
|
||||||
if self.include_prediction(pred, clip)
|
if self.include_prediction(pred, clip)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -88,7 +88,8 @@ def select_device(warn=True) -> str:
|
|||||||
if warn:
|
if warn:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"No GPU available, using the CPU instead. Please consider using a GPU "
|
"No GPU available, using the CPU instead. Please consider using a GPU "
|
||||||
"to speed up training.", stacklevel=2
|
"to speed up training.",
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
return "cpu"
|
return "cpu"
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from batdetect2.inference.lightning import InferenceModule
|
|||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.preprocess.preprocessor import build_preprocessor
|
from batdetect2.preprocess.preprocessor import build_preprocessor
|
||||||
from batdetect2.targets.targets import build_targets
|
from batdetect2.targets.targets import build_targets
|
||||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
from batdetect2.typing.postprocess import ClipDetections
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
@ -30,7 +30,7 @@ def run_batch_inference(
|
|||||||
config: Optional["BatDetect2Config"] = None,
|
config: Optional["BatDetect2Config"] = None,
|
||||||
num_workers: int | None = None,
|
num_workers: int | None = None,
|
||||||
batch_size: int | None = None,
|
batch_size: int | None = None,
|
||||||
) -> List[BatDetect2Prediction]:
|
) -> List[ClipDetections]:
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
|
|
||||||
config = config or BatDetect2Config()
|
config = config or BatDetect2Config()
|
||||||
@ -70,7 +70,7 @@ def process_file_list(
|
|||||||
audio_loader: Optional["AudioLoader"] = None,
|
audio_loader: Optional["AudioLoader"] = None,
|
||||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||||
num_workers: int | None = None,
|
num_workers: int | None = None,
|
||||||
) -> List[BatDetect2Prediction]:
|
) -> List[ClipDetections]:
|
||||||
clip_config = config.inference.clipping
|
clip_config = config.inference.clipping
|
||||||
clips = get_clips_from_files(
|
clips = get_clips_from_files(
|
||||||
paths,
|
paths,
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from torch.utils.data import DataLoader
|
|||||||
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.postprocess import to_raw_predictions
|
from batdetect2.postprocess import to_raw_predictions
|
||||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
from batdetect2.typing.postprocess import ClipDetections
|
||||||
|
|
||||||
|
|
||||||
class InferenceModule(LightningModule):
|
class InferenceModule(LightningModule):
|
||||||
@ -19,7 +19,7 @@ class InferenceModule(LightningModule):
|
|||||||
batch: DatasetItem,
|
batch: DatasetItem,
|
||||||
batch_idx: int,
|
batch_idx: int,
|
||||||
dataloader_idx: int = 0,
|
dataloader_idx: int = 0,
|
||||||
) -> Sequence[BatDetect2Prediction]:
|
) -> Sequence[ClipDetections]:
|
||||||
dataset = self.get_dataset()
|
dataset = self.get_dataset()
|
||||||
|
|
||||||
clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx]
|
clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx]
|
||||||
@ -32,9 +32,9 @@ class InferenceModule(LightningModule):
|
|||||||
)
|
)
|
||||||
|
|
||||||
predictions = [
|
predictions = [
|
||||||
BatDetect2Prediction(
|
ClipDetections(
|
||||||
clip=clip,
|
clip=clip,
|
||||||
predictions=to_raw_predictions(
|
detections=to_raw_predictions(
|
||||||
clip_dets.numpy(),
|
clip_dets.numpy(),
|
||||||
targets=self.model.targets,
|
targets=self.model.targets,
|
||||||
),
|
),
|
||||||
|
|||||||
@ -76,7 +76,10 @@ class MLFlowLoggerConfig(BaseLoggerConfig):
|
|||||||
|
|
||||||
|
|
||||||
LoggerConfig = Annotated[
|
LoggerConfig = Annotated[
|
||||||
DVCLiveConfig | CSVLoggerConfig | TensorBoardLoggerConfig | MLFlowLoggerConfig,
|
DVCLiveConfig
|
||||||
|
| CSVLoggerConfig
|
||||||
|
| TensorBoardLoggerConfig
|
||||||
|
| MLFlowLoggerConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
|
|||||||
@ -41,7 +41,10 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
DecoderLayerConfig = Annotated[
|
DecoderLayerConfig = Annotated[
|
||||||
ConvConfig | FreqCoordConvUpConfig | StandardConvUpConfig | LayerGroupConfig,
|
ConvConfig
|
||||||
|
| FreqCoordConvUpConfig
|
||||||
|
| StandardConvUpConfig
|
||||||
|
| LayerGroupConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
||||||
|
|||||||
@ -14,7 +14,6 @@ logic for preprocessing inputs and postprocessing/decoding outputs resides in
|
|||||||
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
|
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
|||||||
@ -43,7 +43,10 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
EncoderLayerConfig = Annotated[
|
EncoderLayerConfig = Annotated[
|
||||||
ConvConfig | FreqCoordConvDownConfig | StandardConvDownConfig | LayerGroupConfig,
|
ConvConfig
|
||||||
|
| FreqCoordConvDownConfig
|
||||||
|
| StandardConvDownConfig
|
||||||
|
| LayerGroupConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configs usable in Encoder."""
|
"""Type alias for the discriminated union of block configs usable in Encoder."""
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
from matplotlib import axes, patches
|
from matplotlib import axes, patches
|
||||||
from soundevent.plot import plot_geometry
|
from soundevent.plot import plot_geometry
|
||||||
|
|
||||||
|
|||||||
@ -36,7 +36,9 @@ def plot_match_gallery(
|
|||||||
sharey="row",
|
sharey="row",
|
||||||
)
|
)
|
||||||
|
|
||||||
for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples], strict=False):
|
for tp_ax, tp_match in zip(
|
||||||
|
axes[0], true_positives[:n_examples], strict=False
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
plot_true_positive_match(
|
plot_true_positive_match(
|
||||||
tp_match,
|
tp_match,
|
||||||
@ -53,7 +55,9 @@ def plot_match_gallery(
|
|||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples], strict=False):
|
for fp_ax, fp_match in zip(
|
||||||
|
axes[1], false_positives[:n_examples], strict=False
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
plot_false_positive_match(
|
plot_false_positive_match(
|
||||||
fp_match,
|
fp_match,
|
||||||
@ -70,7 +74,9 @@ def plot_match_gallery(
|
|||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples], strict=False):
|
for fn_ax, fn_match in zip(
|
||||||
|
axes[2], false_negatives[:n_examples], strict=False
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
plot_false_negative_match(
|
plot_false_negative_match(
|
||||||
fn_match,
|
fn_match,
|
||||||
@ -87,7 +93,9 @@ def plot_match_gallery(
|
|||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples], strict=False):
|
for ct_ax, ct_match in zip(
|
||||||
|
axes[3], cross_triggers[:n_examples], strict=False
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
plot_cross_trigger_match(
|
plot_cross_trigger_match(
|
||||||
ct_match,
|
ct_match,
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from batdetect2.plotting.clips import plot_clip
|
|||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
AudioLoader,
|
AudioLoader,
|
||||||
PreprocessorProtocol,
|
PreprocessorProtocol,
|
||||||
RawPrediction,
|
Detection,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -22,7 +22,7 @@ __all__ = [
|
|||||||
class MatchProtocol(Protocol):
|
class MatchProtocol(Protocol):
|
||||||
clip: data.Clip
|
clip: data.Clip
|
||||||
gt: data.SoundEventAnnotation | None
|
gt: data.SoundEventAnnotation | None
|
||||||
pred: RawPrediction | None
|
pred: Detection | None
|
||||||
score: float
|
score: float
|
||||||
true_class: str | None
|
true_class: str | None
|
||||||
|
|
||||||
@ -341,4 +341,3 @@ def plot_cross_trigger_match(
|
|||||||
ax.set_title("Cross Trigger")
|
ax.set_title("Cross Trigger")
|
||||||
|
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
|
|||||||
6
src/batdetect2/postprocess/clips.py
Normal file
6
src/batdetect2/postprocess/clips.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
from batdetect2.typing import ClipDetections
|
||||||
|
|
||||||
|
|
||||||
|
class ClipTransform:
|
||||||
|
def __init__(self, clip: ClipDetections):
|
||||||
|
pass
|
||||||
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from soundevent import data
|
|||||||
|
|
||||||
from batdetect2.typing.postprocess import (
|
from batdetect2.typing.postprocess import (
|
||||||
ClipDetectionsArray,
|
ClipDetectionsArray,
|
||||||
RawPrediction,
|
Detection,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
@ -30,7 +30,7 @@ decoding.
|
|||||||
def to_raw_predictions(
|
def to_raw_predictions(
|
||||||
detections: ClipDetectionsArray,
|
detections: ClipDetectionsArray,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
) -> List[RawPrediction]:
|
) -> List[Detection]:
|
||||||
predictions = []
|
predictions = []
|
||||||
|
|
||||||
for score, class_scores, time, freq, dims, feats in zip(
|
for score, class_scores, time, freq, dims, feats in zip(
|
||||||
@ -39,7 +39,8 @@ def to_raw_predictions(
|
|||||||
detections.times,
|
detections.times,
|
||||||
detections.frequencies,
|
detections.frequencies,
|
||||||
detections.sizes,
|
detections.sizes,
|
||||||
detections.features, strict=False,
|
detections.features,
|
||||||
|
strict=False,
|
||||||
):
|
):
|
||||||
highest_scoring_class = targets.class_names[class_scores.argmax()]
|
highest_scoring_class = targets.class_names[class_scores.argmax()]
|
||||||
|
|
||||||
@ -50,7 +51,7 @@ def to_raw_predictions(
|
|||||||
)
|
)
|
||||||
|
|
||||||
predictions.append(
|
predictions.append(
|
||||||
RawPrediction(
|
Detection(
|
||||||
detection_score=score,
|
detection_score=score,
|
||||||
geometry=geom,
|
geometry=geom,
|
||||||
class_scores=class_scores,
|
class_scores=class_scores,
|
||||||
@ -62,7 +63,7 @@ def to_raw_predictions(
|
|||||||
|
|
||||||
|
|
||||||
def convert_raw_predictions_to_clip_prediction(
|
def convert_raw_predictions_to_clip_prediction(
|
||||||
raw_predictions: List[RawPrediction],
|
raw_predictions: List[Detection],
|
||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
@ -85,7 +86,7 @@ def convert_raw_predictions_to_clip_prediction(
|
|||||||
|
|
||||||
|
|
||||||
def convert_raw_prediction_to_sound_event_prediction(
|
def convert_raw_prediction_to_sound_event_prediction(
|
||||||
raw_prediction: RawPrediction,
|
raw_prediction: Detection,
|
||||||
recording: data.Recording,
|
recording: data.Recording,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
|
classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
|||||||
@ -377,7 +377,10 @@ class PeakNormalize(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
SpectrogramTransform = Annotated[
|
SpectrogramTransform = Annotated[
|
||||||
PcenConfig | ScaleAmplitudeConfig | SpectralMeanSubstractionConfig | PeakNormalizeConfig,
|
PcenConfig
|
||||||
|
| ScaleAmplitudeConfig
|
||||||
|
| SpectralMeanSubstractionConfig
|
||||||
|
| PeakNormalizeConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -394,7 +394,8 @@ class MaskTime(torch.nn.Module):
|
|||||||
size=num_masks,
|
size=num_masks,
|
||||||
)
|
)
|
||||||
masks = [
|
masks = [
|
||||||
(start, start + size) for start, size in zip(mask_start, mask_size, strict=False)
|
(start, start + size)
|
||||||
|
for start, size in zip(mask_start, mask_size, strict=False)
|
||||||
]
|
]
|
||||||
return mask_time(spec, masks), clip_annotation
|
return mask_time(spec, masks), clip_annotation
|
||||||
|
|
||||||
@ -460,7 +461,8 @@ class MaskFrequency(torch.nn.Module):
|
|||||||
size=num_masks,
|
size=num_masks,
|
||||||
)
|
)
|
||||||
masks = [
|
masks = [
|
||||||
(start, start + size) for start, size in zip(mask_start, mask_size, strict=False)
|
(start, start + size)
|
||||||
|
for start, size in zip(mask_start, mask_size, strict=False)
|
||||||
]
|
]
|
||||||
return mask_frequency(spec, masks), clip_annotation
|
return mask_frequency(spec, masks), clip_annotation
|
||||||
|
|
||||||
@ -498,7 +500,12 @@ SpectrogramAugmentationConfig = Annotated[
|
|||||||
]
|
]
|
||||||
|
|
||||||
AugmentationConfig = Annotated[
|
AugmentationConfig = Annotated[
|
||||||
MixAudioConfig | AddEchoConfig | ScaleVolumeConfig | WarpConfig | MaskFrequencyConfig | MaskTimeConfig,
|
MixAudioConfig
|
||||||
|
| AddEchoConfig
|
||||||
|
| ScaleVolumeConfig
|
||||||
|
| WarpConfig
|
||||||
|
| MaskFrequencyConfig
|
||||||
|
| MaskTimeConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of individual augmentation config."""
|
"""Type alias for the discriminated union of individual augmentation config."""
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from batdetect2.postprocess import to_raw_predictions
|
|||||||
from batdetect2.train.dataset import ValidationDataset
|
from batdetect2.train.dataset import ValidationDataset
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
BatDetect2Prediction,
|
ClipDetections,
|
||||||
EvaluatorProtocol,
|
EvaluatorProtocol,
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
TrainExample,
|
TrainExample,
|
||||||
@ -24,7 +24,7 @@ class ValidationMetrics(Callback):
|
|||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
|
|
||||||
self._clip_annotations: List[data.ClipAnnotation] = []
|
self._clip_annotations: List[data.ClipAnnotation] = []
|
||||||
self._predictions: List[BatDetect2Prediction] = []
|
self._predictions: List[ClipDetections] = []
|
||||||
|
|
||||||
def get_dataset(self, trainer: Trainer) -> ValidationDataset:
|
def get_dataset(self, trainer: Trainer) -> ValidationDataset:
|
||||||
dataloaders = trainer.val_dataloaders
|
dataloaders = trainer.val_dataloaders
|
||||||
@ -100,9 +100,9 @@ class ValidationMetrics(Callback):
|
|||||||
start_times=[ca.clip.start_time for ca in clip_annotations],
|
start_times=[ca.clip.start_time for ca in clip_annotations],
|
||||||
)
|
)
|
||||||
predictions = [
|
predictions = [
|
||||||
BatDetect2Prediction(
|
ClipDetections(
|
||||||
clip=clip_annotation.clip,
|
clip=clip_annotation.clip,
|
||||||
predictions=to_raw_predictions(
|
detections=to_raw_predictions(
|
||||||
clip_dets.numpy(), targets=model.targets
|
clip_dets.numpy(), targets=model.targets
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,6 @@ The primary entry points are:
|
|||||||
- `LossConfig`: The Pydantic model for configuring loss weights and parameters.
|
- `LossConfig`: The Pydantic model for configuring loss weights and parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|||||||
@ -10,11 +10,11 @@ from batdetect2.typing.evaluate import (
|
|||||||
)
|
)
|
||||||
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
||||||
from batdetect2.typing.postprocess import (
|
from batdetect2.typing.postprocess import (
|
||||||
BatDetect2Prediction,
|
ClipDetections,
|
||||||
ClipDetectionsTensor,
|
ClipDetectionsTensor,
|
||||||
GeometryDecoder,
|
GeometryDecoder,
|
||||||
PostprocessorProtocol,
|
PostprocessorProtocol,
|
||||||
RawPrediction,
|
Detection,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.preprocess import (
|
from batdetect2.typing.preprocess import (
|
||||||
AudioLoader,
|
AudioLoader,
|
||||||
@ -44,7 +44,7 @@ __all__ = [
|
|||||||
"AudioLoader",
|
"AudioLoader",
|
||||||
"Augmentation",
|
"Augmentation",
|
||||||
"BackboneModel",
|
"BackboneModel",
|
||||||
"BatDetect2Prediction",
|
"ClipDetections",
|
||||||
"ClipDetectionsTensor",
|
"ClipDetectionsTensor",
|
||||||
"ClipLabeller",
|
"ClipLabeller",
|
||||||
"ClipMatches",
|
"ClipMatches",
|
||||||
@ -65,7 +65,7 @@ __all__ = [
|
|||||||
"PostprocessorProtocol",
|
"PostprocessorProtocol",
|
||||||
"PreprocessorProtocol",
|
"PreprocessorProtocol",
|
||||||
"ROITargetMapper",
|
"ROITargetMapper",
|
||||||
"RawPrediction",
|
"Detection",
|
||||||
"Size",
|
"Size",
|
||||||
"SoundEventDecoder",
|
"SoundEventDecoder",
|
||||||
"SoundEventEncoder",
|
"SoundEventEncoder",
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from typing import Generic, List, Protocol, Sequence, TypeVar
|
|||||||
|
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
from batdetect2.typing.postprocess import ClipDetections
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"OutputFormatterProtocol",
|
"OutputFormatterProtocol",
|
||||||
@ -12,9 +12,7 @@ T = TypeVar("T")
|
|||||||
|
|
||||||
|
|
||||||
class OutputFormatterProtocol(Protocol, Generic[T]):
|
class OutputFormatterProtocol(Protocol, Generic[T]):
|
||||||
def format(
|
def format(self, predictions: Sequence[ClipDetections]) -> List[T]: ...
|
||||||
self, predictions: Sequence[BatDetect2Prediction]
|
|
||||||
) -> List[T]: ...
|
|
||||||
|
|
||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from typing import (
|
|||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction
|
from batdetect2.typing.postprocess import ClipDetections, Detection
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -84,7 +84,7 @@ Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True)
|
|||||||
class AffinityFunction(Protocol):
|
class AffinityFunction(Protocol):
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
detection: RawPrediction,
|
detection: Detection,
|
||||||
ground_truth: data.SoundEventAnnotation,
|
ground_truth: data.SoundEventAnnotation,
|
||||||
) -> float: ...
|
) -> float: ...
|
||||||
|
|
||||||
@ -93,7 +93,7 @@ class MetricsProtocol(Protocol):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
predictions: Sequence[Sequence[RawPrediction]],
|
predictions: Sequence[Sequence[Detection]],
|
||||||
) -> Dict[str, float]: ...
|
) -> Dict[str, float]: ...
|
||||||
|
|
||||||
|
|
||||||
@ -101,7 +101,7 @@ class PlotterProtocol(Protocol):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
predictions: Sequence[Sequence[RawPrediction]],
|
predictions: Sequence[Sequence[Detection]],
|
||||||
) -> Iterable[Tuple[str, Figure]]: ...
|
) -> Iterable[Tuple[str, Figure]]: ...
|
||||||
|
|
||||||
|
|
||||||
@ -114,7 +114,7 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
|
|||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
predictions: Sequence[BatDetect2Prediction],
|
predictions: Sequence[ClipDetections],
|
||||||
) -> EvaluationOutput: ...
|
) -> EvaluationOutput: ...
|
||||||
|
|
||||||
def compute_metrics(
|
def compute_metrics(
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from batdetect2.typing.models import ModelOutput
|
|||||||
from batdetect2.typing.targets import Position, Size
|
from batdetect2.typing.targets import Position, Size
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"RawPrediction",
|
"Detection",
|
||||||
"PostprocessorProtocol",
|
"PostprocessorProtocol",
|
||||||
"GeometryDecoder",
|
"GeometryDecoder",
|
||||||
]
|
]
|
||||||
@ -46,7 +46,8 @@ class GeometryDecoder(Protocol):
|
|||||||
) -> data.Geometry: ...
|
) -> data.Geometry: ...
|
||||||
|
|
||||||
|
|
||||||
class RawPrediction(NamedTuple):
|
@dataclass
|
||||||
|
class Detection:
|
||||||
geometry: data.Geometry
|
geometry: data.Geometry
|
||||||
detection_score: float
|
detection_score: float
|
||||||
class_scores: np.ndarray
|
class_scores: np.ndarray
|
||||||
@ -82,9 +83,16 @@ class ClipDetectionsTensor(NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatDetect2Prediction:
|
class ClipDetections:
|
||||||
clip: data.Clip
|
clip: data.Clip
|
||||||
predictions: List[RawPrediction]
|
detections: List[Detection]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClipPrediction:
|
||||||
|
clip: data.Clip
|
||||||
|
detection_score: float
|
||||||
|
class_scores: np.ndarray
|
||||||
|
|
||||||
|
|
||||||
class PostprocessorProtocol(Protocol):
|
class PostprocessorProtocol(Protocol):
|
||||||
|
|||||||
@ -220,7 +220,8 @@ def get_annotations_from_preds(
|
|||||||
predictions["high_freqs"],
|
predictions["high_freqs"],
|
||||||
class_ind_best,
|
class_ind_best,
|
||||||
class_prob_best,
|
class_prob_best,
|
||||||
predictions["det_probs"], strict=False,
|
predictions["det_probs"],
|
||||||
|
strict=False,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
return annotations
|
return annotations
|
||||||
|
|||||||
@ -87,9 +87,7 @@ def save_ann_spec(
|
|||||||
y_extent = [0, duration, min_freq, max_freq]
|
y_extent = [0, duration, min_freq, max_freq]
|
||||||
|
|
||||||
plt.close("all")
|
plt.close("all")
|
||||||
plt.figure(
|
plt.figure(0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100)
|
||||||
0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100
|
|
||||||
)
|
|
||||||
plt.imshow(
|
plt.imshow(
|
||||||
spec,
|
spec,
|
||||||
aspect="auto",
|
aspect="auto",
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|||||||
@ -10,8 +10,8 @@ from batdetect2.data.predictions import (
|
|||||||
build_output_formatter,
|
build_output_formatter,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
BatDetect2Prediction,
|
ClipDetections,
|
||||||
RawPrediction,
|
Detection,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ def test_roundtrip(
|
|||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
):
|
):
|
||||||
detections = [
|
detections = [
|
||||||
RawPrediction(
|
Detection(
|
||||||
geometry=data.BoundingBox(
|
geometry=data.BoundingBox(
|
||||||
coordinates=list(np.random.uniform(size=[4]))
|
coordinates=list(np.random.uniform(size=[4]))
|
||||||
),
|
),
|
||||||
@ -44,7 +44,7 @@ def test_roundtrip(
|
|||||||
for _ in range(10)
|
for _ in range(10)
|
||||||
]
|
]
|
||||||
|
|
||||||
prediction = BatDetect2Prediction(clip=clip, predictions=detections)
|
prediction = ClipDetections(clip=clip, detections=detections)
|
||||||
|
|
||||||
path = tmp_path / "predictions.parquet"
|
path = tmp_path / "predictions.parquet"
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ def test_roundtrip(
|
|||||||
assert recovered[0].clip == prediction.clip
|
assert recovered[0].clip == prediction.clip
|
||||||
|
|
||||||
for recovered_prediction, detection in zip(
|
for recovered_prediction, detection in zip(
|
||||||
recovered[0].predictions, detections
|
recovered[0].detections, detections, strict=True
|
||||||
):
|
):
|
||||||
assert (
|
assert (
|
||||||
recovered_prediction.detection_score == detection.detection_score
|
recovered_prediction.detection_score == detection.detection_score
|
||||||
@ -81,7 +81,7 @@ def test_multiple_clips(
|
|||||||
clip2 = clip.model_copy(update={"uuid": uuid4()})
|
clip2 = clip.model_copy(update={"uuid": uuid4()})
|
||||||
|
|
||||||
detections1 = [
|
detections1 = [
|
||||||
RawPrediction(
|
Detection(
|
||||||
geometry=data.BoundingBox(
|
geometry=data.BoundingBox(
|
||||||
coordinates=list(np.random.uniform(size=[4]))
|
coordinates=list(np.random.uniform(size=[4]))
|
||||||
),
|
),
|
||||||
@ -94,7 +94,7 @@ def test_multiple_clips(
|
|||||||
]
|
]
|
||||||
|
|
||||||
detections2 = [
|
detections2 = [
|
||||||
RawPrediction(
|
Detection(
|
||||||
geometry=data.BoundingBox(
|
geometry=data.BoundingBox(
|
||||||
coordinates=list(np.random.uniform(size=[4]))
|
coordinates=list(np.random.uniform(size=[4]))
|
||||||
),
|
),
|
||||||
@ -107,8 +107,8 @@ def test_multiple_clips(
|
|||||||
]
|
]
|
||||||
|
|
||||||
predictions = [
|
predictions = [
|
||||||
BatDetect2Prediction(clip=clip, predictions=detections1),
|
ClipDetections(clip=clip, detections=detections1),
|
||||||
BatDetect2Prediction(clip=clip2, predictions=detections2),
|
ClipDetections(clip=clip2, detections=detections2),
|
||||||
]
|
]
|
||||||
|
|
||||||
path = tmp_path / "multi_predictions.parquet"
|
path = tmp_path / "multi_predictions.parquet"
|
||||||
@ -133,16 +133,18 @@ def test_complex_geometry(
|
|||||||
):
|
):
|
||||||
# Create a polygon geometry
|
# Create a polygon geometry
|
||||||
polygon = data.Polygon(
|
polygon = data.Polygon(
|
||||||
coordinates=[[
|
coordinates=[
|
||||||
|
[
|
||||||
[0.0, 10000.0],
|
[0.0, 10000.0],
|
||||||
[0.1, 20000.0],
|
[0.1, 20000.0],
|
||||||
[0.2, 10000.0],
|
[0.2, 10000.0],
|
||||||
[0.0, 10000.0],
|
[0.0, 10000.0],
|
||||||
]]
|
]
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
detections = [
|
detections = [
|
||||||
RawPrediction(
|
Detection(
|
||||||
geometry=polygon,
|
geometry=polygon,
|
||||||
detection_score=0.95,
|
detection_score=0.95,
|
||||||
class_scores=np.random.uniform(
|
class_scores=np.random.uniform(
|
||||||
@ -152,7 +154,7 @@ def test_complex_geometry(
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
prediction = BatDetect2Prediction(clip=clip, predictions=detections)
|
prediction = ClipDetections(clip=clip, detections=detections)
|
||||||
|
|
||||||
path = tmp_path / "complex_geometry.parquet"
|
path = tmp_path / "complex_geometry.parquet"
|
||||||
sample_formatter.save(predictions=[prediction], path=path)
|
sample_formatter.save(predictions=[prediction], path=path)
|
||||||
@ -160,9 +162,9 @@ def test_complex_geometry(
|
|||||||
recovered = sample_formatter.load(path=path)
|
recovered = sample_formatter.load(path=path)
|
||||||
|
|
||||||
assert len(recovered) == 1
|
assert len(recovered) == 1
|
||||||
assert len(recovered[0].predictions) == 1
|
assert len(recovered[0].detections) == 1
|
||||||
|
|
||||||
recovered_pred = recovered[0].predictions[0]
|
recovered_pred = recovered[0].detections[0]
|
||||||
|
|
||||||
# Check if geometry is recovered correctly as a Polygon
|
# Check if geometry is recovered correctly as a Polygon
|
||||||
assert isinstance(recovered_pred.geometry, data.Polygon)
|
assert isinstance(recovered_pred.geometry, data.Polygon)
|
||||||
|
|||||||
@ -6,8 +6,8 @@ from soundevent import data
|
|||||||
|
|
||||||
from batdetect2.data.predictions import RawOutputConfig, build_output_formatter
|
from batdetect2.data.predictions import RawOutputConfig, build_output_formatter
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
BatDetect2Prediction,
|
ClipDetections,
|
||||||
RawPrediction,
|
Detection,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ def test_roundtrip(
|
|||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
):
|
):
|
||||||
detections = [
|
detections = [
|
||||||
RawPrediction(
|
Detection(
|
||||||
geometry=data.BoundingBox(
|
geometry=data.BoundingBox(
|
||||||
coordinates=list(np.random.uniform(size=[4]))
|
coordinates=list(np.random.uniform(size=[4]))
|
||||||
),
|
),
|
||||||
@ -40,7 +40,7 @@ def test_roundtrip(
|
|||||||
for _ in range(10)
|
for _ in range(10)
|
||||||
]
|
]
|
||||||
|
|
||||||
prediction = BatDetect2Prediction(clip=clip, predictions=detections)
|
prediction = ClipDetections(clip=clip, detections=detections)
|
||||||
|
|
||||||
path = tmp_path / "predictions"
|
path = tmp_path / "predictions"
|
||||||
|
|
||||||
@ -52,7 +52,9 @@ def test_roundtrip(
|
|||||||
assert recovered[0].clip == prediction.clip
|
assert recovered[0].clip == prediction.clip
|
||||||
|
|
||||||
for recovered_prediction, detection in zip(
|
for recovered_prediction, detection in zip(
|
||||||
recovered[0].predictions, detections
|
recovered[0].detections,
|
||||||
|
detections,
|
||||||
|
strict=True,
|
||||||
):
|
):
|
||||||
assert (
|
assert (
|
||||||
recovered_prediction.detection_score == detection.detection_score
|
recovered_prediction.detection_score == detection.detection_score
|
||||||
|
|||||||
0
tests/test_evaluate/__init__.py
Normal file
0
tests/test_evaluate/__init__.py
Normal file
0
tests/test_evaluate/test_tasks/__init__.py
Normal file
0
tests/test_evaluate/test_tasks/__init__.py
Normal file
86
tests/test_evaluate/test_tasks/conftest.py
Normal file
86
tests/test_evaluate/test_tasks/conftest.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.typing import Detection
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def clip(recording: data.Recording) -> data.Clip:
|
||||||
|
return data.Clip(recording=recording, start_time=0, end_time=100)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def create_detection():
|
||||||
|
def factory(
|
||||||
|
detection_score: float = 0.5,
|
||||||
|
start_time: float = 0.1,
|
||||||
|
duration: float = 0.01,
|
||||||
|
low_freq: float = 40_000,
|
||||||
|
bandwidth: float = 30_000,
|
||||||
|
pip_score: float = 0,
|
||||||
|
myo_score: float = 0,
|
||||||
|
):
|
||||||
|
return Detection(
|
||||||
|
detection_score=detection_score,
|
||||||
|
class_scores=np.array(
|
||||||
|
[
|
||||||
|
pip_score,
|
||||||
|
myo_score,
|
||||||
|
]
|
||||||
|
),
|
||||||
|
features=np.zeros([32]),
|
||||||
|
geometry=data.BoundingBox(
|
||||||
|
coordinates=[
|
||||||
|
start_time,
|
||||||
|
low_freq,
|
||||||
|
start_time + duration,
|
||||||
|
low_freq + bandwidth,
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return factory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def create_annotation(
|
||||||
|
clip: data.Clip,
|
||||||
|
bat_tag: data.Tag,
|
||||||
|
myomyo_tag: data.Tag,
|
||||||
|
pippip_tag: data.Tag,
|
||||||
|
):
|
||||||
|
def factory(
|
||||||
|
start_time: float = 0.1,
|
||||||
|
duration: float = 0.01,
|
||||||
|
low_freq: float = 40_000,
|
||||||
|
bandwidth: float = 30_000,
|
||||||
|
is_target: bool = True,
|
||||||
|
class_name: Literal["pippip", "myomyo"] | None = None,
|
||||||
|
):
|
||||||
|
tags = [bat_tag] if is_target else []
|
||||||
|
|
||||||
|
if class_name is not None:
|
||||||
|
if class_name == "pippip":
|
||||||
|
tags.append(pippip_tag)
|
||||||
|
elif class_name == "myomyo":
|
||||||
|
tags.append(myomyo_tag)
|
||||||
|
|
||||||
|
return data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.BoundingBox(
|
||||||
|
coordinates=[
|
||||||
|
start_time,
|
||||||
|
low_freq,
|
||||||
|
start_time + duration,
|
||||||
|
low_freq + bandwidth,
|
||||||
|
]
|
||||||
|
),
|
||||||
|
recording=clip.recording,
|
||||||
|
),
|
||||||
|
tags=tags,
|
||||||
|
)
|
||||||
|
|
||||||
|
return factory
|
||||||
77
tests/test_evaluate/test_tasks/test_classification.py
Normal file
77
tests/test_evaluate/test_tasks/test_classification.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.evaluate.tasks import build_task
|
||||||
|
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
|
||||||
|
from batdetect2.typing import ClipDetections
|
||||||
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
def test_classification(
|
||||||
|
clip: data.Clip,
|
||||||
|
sample_targets: TargetProtocol,
|
||||||
|
create_detection,
|
||||||
|
create_annotation,
|
||||||
|
):
|
||||||
|
config = ClassificationTaskConfig.model_validate(
|
||||||
|
{
|
||||||
|
"name": "sound_event_classification",
|
||||||
|
"metrics": [{"name": "average_precision"}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
evaluator = build_task(config, targets=sample_targets)
|
||||||
|
|
||||||
|
# Create a dummy prediction
|
||||||
|
prediction = ClipDetections(
|
||||||
|
clip=clip,
|
||||||
|
detections=[
|
||||||
|
create_detection(
|
||||||
|
start_time=1 + 0.1 * index,
|
||||||
|
pip_score=score,
|
||||||
|
)
|
||||||
|
for index, score in enumerate(np.linspace(0, 1, 100))
|
||||||
|
]
|
||||||
|
+ [
|
||||||
|
create_detection(
|
||||||
|
start_time=1.05 + 0.1 * index,
|
||||||
|
myo_score=score,
|
||||||
|
)
|
||||||
|
for index, score in enumerate(np.linspace(1, 0, 100))
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a dummy annotation
|
||||||
|
gt = data.ClipAnnotation(
|
||||||
|
clip=clip,
|
||||||
|
sound_events=[
|
||||||
|
create_annotation(
|
||||||
|
start_time=1 + 0.1 * index,
|
||||||
|
is_target=index % 2 == 0,
|
||||||
|
class_name="pippip",
|
||||||
|
)
|
||||||
|
for index in range(100)
|
||||||
|
]
|
||||||
|
+ [
|
||||||
|
create_annotation(
|
||||||
|
start_time=1.05 + 0.1 * index,
|
||||||
|
is_target=index % 3 == 0,
|
||||||
|
class_name="myomyo",
|
||||||
|
)
|
||||||
|
for index in range(100)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
evals = evaluator.evaluate([gt], [prediction])
|
||||||
|
metrics = evaluator.compute_metrics(evals)
|
||||||
|
|
||||||
|
assert metrics["classification/average_precision/pippip"] == pytest.approx(
|
||||||
|
0.5, abs=0.005
|
||||||
|
)
|
||||||
|
assert metrics["classification/average_precision/myomyo"] == pytest.approx(
|
||||||
|
0.371, abs=0.005
|
||||||
|
)
|
||||||
|
assert metrics["classification/mean_average_precision"] == pytest.approx(
|
||||||
|
(0.5 + 0.371) / 2, abs=0.005
|
||||||
|
)
|
||||||
50
tests/test_evaluate/test_tasks/test_detection.py
Normal file
50
tests/test_evaluate/test_tasks/test_detection.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.evaluate.tasks import build_task
|
||||||
|
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
||||||
|
from batdetect2.typing import ClipDetections
|
||||||
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
def test_detection(
|
||||||
|
clip: data.Clip,
|
||||||
|
sample_targets: TargetProtocol,
|
||||||
|
create_detection,
|
||||||
|
create_annotation,
|
||||||
|
):
|
||||||
|
config = DetectionTaskConfig.model_validate(
|
||||||
|
{
|
||||||
|
"name": "sound_event_detection",
|
||||||
|
"metrics": [{"name": "average_precision"}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
evaluator = build_task(config, targets=sample_targets)
|
||||||
|
|
||||||
|
# Create a dummy prediction
|
||||||
|
prediction = ClipDetections(
|
||||||
|
clip=clip,
|
||||||
|
detections=[
|
||||||
|
create_detection(start_time=1 + 0.1 * index, detection_score=score)
|
||||||
|
for index, score in enumerate(np.linspace(0, 1, 100))
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a dummy annotation
|
||||||
|
gt = data.ClipAnnotation(
|
||||||
|
clip=clip,
|
||||||
|
sound_events=[
|
||||||
|
# Only half of the annotations are targets
|
||||||
|
create_annotation(
|
||||||
|
start_time=1 + 0.1 * index,
|
||||||
|
is_target=index % 2 == 0,
|
||||||
|
)
|
||||||
|
for index in range(100)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the task
|
||||||
|
evals = evaluator.evaluate([gt], [prediction])
|
||||||
|
metrics = evaluator.compute_metrics(evals)
|
||||||
|
assert metrics["detection/average_precision"] == pytest.approx(0.5)
|
||||||
@ -14,7 +14,7 @@ from batdetect2.postprocess.decoding import (
|
|||||||
get_generic_tags,
|
get_generic_tags,
|
||||||
get_prediction_features,
|
get_prediction_features,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import RawPrediction, TargetProtocol
|
from batdetect2.typing import Detection, TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -209,7 +209,7 @@ def empty_detection_dataset() -> xr.Dataset:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_raw_predictions() -> List[RawPrediction]:
|
def sample_raw_predictions() -> List[Detection]:
|
||||||
"""Manually crafted RawPrediction objects using the actual type."""
|
"""Manually crafted RawPrediction objects using the actual type."""
|
||||||
|
|
||||||
pred1_classes = xr.DataArray(
|
pred1_classes = xr.DataArray(
|
||||||
@ -220,7 +220,7 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
|||||||
coords={"feature": ["f0", "f1", "f2", "f3"]},
|
coords={"feature": ["f0", "f1", "f2", "f3"]},
|
||||||
dims=["feature"],
|
dims=["feature"],
|
||||||
)
|
)
|
||||||
pred1 = RawPrediction(
|
pred1 = Detection(
|
||||||
detection_score=0.9,
|
detection_score=0.9,
|
||||||
geometry=data.BoundingBox(
|
geometry=data.BoundingBox(
|
||||||
coordinates=[
|
coordinates=[
|
||||||
@ -242,7 +242,7 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
|||||||
coords={"feature": ["f0", "f1", "f2", "f3"]},
|
coords={"feature": ["f0", "f1", "f2", "f3"]},
|
||||||
dims=["feature"],
|
dims=["feature"],
|
||||||
)
|
)
|
||||||
pred2 = RawPrediction(
|
pred2 = Detection(
|
||||||
detection_score=0.8,
|
detection_score=0.8,
|
||||||
geometry=data.BoundingBox(
|
geometry=data.BoundingBox(
|
||||||
coordinates=[
|
coordinates=[
|
||||||
@ -264,7 +264,7 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
|||||||
coords={"feature": ["f0", "f1", "f2", "f3"]},
|
coords={"feature": ["f0", "f1", "f2", "f3"]},
|
||||||
dims=["feature"],
|
dims=["feature"],
|
||||||
)
|
)
|
||||||
pred3 = RawPrediction(
|
pred3 = Detection(
|
||||||
detection_score=0.15,
|
detection_score=0.15,
|
||||||
geometry=data.BoundingBox(
|
geometry=data.BoundingBox(
|
||||||
coordinates=[
|
coordinates=[
|
||||||
@ -281,7 +281,7 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_to_sound_event_basic(
|
def test_convert_raw_to_sound_event_basic(
|
||||||
sample_raw_predictions: List[RawPrediction],
|
sample_raw_predictions: List[Detection],
|
||||||
sample_recording: data.Recording,
|
sample_recording: data.Recording,
|
||||||
dummy_targets: TargetProtocol,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
@ -324,7 +324,7 @@ def test_convert_raw_to_sound_event_basic(
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_to_sound_event_thresholding(
|
def test_convert_raw_to_sound_event_thresholding(
|
||||||
sample_raw_predictions: List[RawPrediction],
|
sample_raw_predictions: List[Detection],
|
||||||
sample_recording: data.Recording,
|
sample_recording: data.Recording,
|
||||||
dummy_targets: TargetProtocol,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
@ -352,7 +352,7 @@ def test_convert_raw_to_sound_event_thresholding(
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_to_sound_event_no_threshold(
|
def test_convert_raw_to_sound_event_no_threshold(
|
||||||
sample_raw_predictions: List[RawPrediction],
|
sample_raw_predictions: List[Detection],
|
||||||
sample_recording: data.Recording,
|
sample_recording: data.Recording,
|
||||||
dummy_targets: TargetProtocol,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
@ -380,7 +380,7 @@ def test_convert_raw_to_sound_event_no_threshold(
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_to_sound_event_top_class(
|
def test_convert_raw_to_sound_event_top_class(
|
||||||
sample_raw_predictions: List[RawPrediction],
|
sample_raw_predictions: List[Detection],
|
||||||
sample_recording: data.Recording,
|
sample_recording: data.Recording,
|
||||||
dummy_targets: TargetProtocol,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
@ -407,7 +407,7 @@ def test_convert_raw_to_sound_event_top_class(
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_to_sound_event_all_below_threshold(
|
def test_convert_raw_to_sound_event_all_below_threshold(
|
||||||
sample_raw_predictions: List[RawPrediction],
|
sample_raw_predictions: List[Detection],
|
||||||
sample_recording: data.Recording,
|
sample_recording: data.Recording,
|
||||||
dummy_targets: TargetProtocol,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
@ -433,7 +433,7 @@ def test_convert_raw_to_sound_event_all_below_threshold(
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_list_to_clip_basic(
|
def test_convert_raw_list_to_clip_basic(
|
||||||
sample_raw_predictions: List[RawPrediction],
|
sample_raw_predictions: List[Detection],
|
||||||
sample_clip: data.Clip,
|
sample_clip: data.Clip,
|
||||||
dummy_targets: TargetProtocol,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
@ -485,7 +485,7 @@ def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets):
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_list_to_clip_passes_args(
|
def test_convert_raw_list_to_clip_passes_args(
|
||||||
sample_raw_predictions: List[RawPrediction],
|
sample_raw_predictions: List[Detection],
|
||||||
sample_clip: data.Clip,
|
sample_clip: data.Clip,
|
||||||
dummy_targets: TargetProtocol,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user