Create EvaluateTaskProtocol

This commit is contained in:
mbsantiago 2026-03-18 01:53:28 +00:00
parent daff74fdde
commit b3af70761e
8 changed files with 90 additions and 62 deletions

View File

@@ -53,7 +53,15 @@ def run_evaluate(
num_workers=num_workers,
)
evaluator = build_evaluator(config=evaluation_config, targets=targets)
output_transform = build_output_transform(
config=output_config.transform,
targets=targets,
)
evaluator = build_evaluator(
config=evaluation_config,
targets=targets,
transform=output_transform,
)
logger = build_logger(
evaluation_config.logger,
@@ -61,14 +69,9 @@ def run_evaluate(
experiment_name=experiment_name,
run_name=run_name,
)
output_transform = build_output_transform(
config=output_config.transform,
targets=targets,
)
module = EvaluationModule(
model,
evaluator,
output_transform=output_transform,
)
trainer = Trainer(logger=logger, enable_checkpointing=False)
metrics = trainer.test(module, loader)

View File

@@ -5,8 +5,9 @@ from soundevent import data
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.postprocess.types import ClipDetections
from batdetect2.evaluate.types import EvaluationTaskProtocol, EvaluatorProtocol
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections, ClipDetectionsTensor
from batdetect2.targets import build_targets
from batdetect2.targets.types import TargetProtocol
@@ -20,11 +21,23 @@ class Evaluator:
def __init__(
self,
targets: TargetProtocol,
tasks: Sequence[EvaluatorProtocol],
transform: OutputTransformProtocol,
tasks: Sequence[EvaluationTaskProtocol],
):
self.targets = targets
self.transform = transform
self.tasks = tasks
def to_clip_detections_batch(
self,
clip_detections: Sequence[ClipDetectionsTensor],
clips: Sequence[data.Clip],
) -> list[ClipDetections]:
return [
self.transform.to_clip_detections(detections=dets, clip=clip)
for dets, clip in zip(clip_detections, clips, strict=False)
]
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
@@ -54,6 +67,7 @@ class Evaluator:
def build_evaluator(
config: EvaluationConfig | dict | None = None,
targets: TargetProtocol | None = None,
transform: OutputTransformProtocol | None = None,
) -> EvaluatorProtocol:
targets = targets or build_targets()
@@ -63,7 +77,10 @@ def build_evaluator(
if not isinstance(config, EvaluationConfig):
config = EvaluationConfig.model_validate(config)
transform = transform or build_output_transform(targets=targets)
return Evaluator(
targets=targets,
transform=transform,
tasks=[build_task(task, targets=targets) for task in config.tasks],
)

View File

@@ -8,7 +8,6 @@ from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import get_image_logger
from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections
@@ -17,15 +16,11 @@ class EvaluationModule(LightningModule):
self,
model: Model,
evaluator: EvaluatorProtocol,
output_transform: OutputTransformProtocol | None = None,
):
super().__init__()
self.model = model
self.evaluator = evaluator
self.output_transform = output_transform or build_output_transform(
targets=evaluator.targets
)
self.clip_annotations: List[data.ClipAnnotation] = []
self.predictions: List[ClipDetections] = []
@@ -39,15 +34,10 @@ class EvaluationModule(LightningModule):
outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor(outputs)
predictions = [
self.output_transform.to_clip_detections(
detections=clip_dets,
clip=clip_annotation.clip,
)
for clip_annotation, clip_dets in zip(
clip_annotations, clip_detections, strict=False
)
]
predictions = self.evaluator.to_clip_detections_batch(
clip_detections,
[clip_annotation.clip for clip_annotation in clip_annotations],
)
self.clip_annotations.extend(clip_annotations)
self.predictions.extend(predictions)

View File

@@ -11,7 +11,7 @@ from batdetect2.evaluate.tasks.clip_classification import (
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.evaluate.types import EvaluationTaskProtocol
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets import build_targets
from batdetect2.targets.types import TargetProtocol
@@ -36,7 +36,7 @@ TaskConfig = Annotated[
def build_task(
config: TaskConfig,
targets: TargetProtocol | None = None,
) -> EvaluatorProtocol:
) -> EvaluationTaskProtocol:
targets = targets or build_targets()
return tasks_registry.build(config, targets)

View File

@@ -28,7 +28,7 @@ from batdetect2.evaluate.affinity import (
)
from batdetect2.evaluate.types import (
AffinityFunction,
EvaluatorProtocol,
EvaluationTaskProtocol,
)
from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol
@@ -39,7 +39,7 @@ __all__ = [
"TaskImportConfig",
]
tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry(
tasks_registry: Registry[EvaluationTaskProtocol, [TargetProtocol]] = Registry(
"tasks"
)
@@ -64,7 +64,7 @@ class BaseTaskConfig(BaseConfig):
ignore_start_end: float = 0.01
class BaseTask(EvaluatorProtocol, Generic[T_Output]):
class BaseTask(EvaluationTaskProtocol, Generic[T_Output]):
targets: TargetProtocol
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]

View File

@@ -4,12 +4,18 @@ from typing import Generic, Iterable, Protocol, Sequence, TypeVar
from matplotlib.figure import Figure
from soundevent import data
from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.outputs.types import OutputTransformProtocol
from batdetect2.postprocess.types import (
ClipDetections,
ClipDetectionsTensor,
Detection,
)
from batdetect2.targets.types import TargetProtocol
__all__ = [
"AffinityFunction",
"ClipMatches",
"EvaluationTaskProtocol",
"EvaluatorProtocol",
"MatchEvaluation",
"MatcherProtocol",
@@ -94,7 +100,7 @@ class PlotterProtocol(Protocol):
EvaluationOutput = TypeVar("EvaluationOutput")
class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
class EvaluationTaskProtocol(Protocol, Generic[EvaluationOutput]):
targets: TargetProtocol
def evaluate(
@@ -112,3 +118,30 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
self,
eval_outputs: EvaluationOutput,
) -> Iterable[tuple[str, Figure]]: ...
class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
targets: TargetProtocol
transform: OutputTransformProtocol
def to_clip_detections_batch(
self,
clip_detections: Sequence[ClipDetectionsTensor],
clips: Sequence[data.Clip],
) -> list[ClipDetections]: ...
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[ClipDetections],
) -> EvaluationOutput: ...
def compute_metrics(
self,
eval_outputs: EvaluationOutput,
) -> dict[str, float]: ...
def generate_plots(
self,
eval_outputs: EvaluationOutput,
) -> Iterable[tuple[str, Figure]]: ...

View File

@@ -123,18 +123,20 @@ class OutputTransform(OutputTransformProtocol):
out.append(transformed)
return []
return out
def transform_detection(
self,
detection: Detection,
) -> Detection | None:
for transform in self.detection_transform_steps:
detection = transform(detection) # type: ignore
transformed = transform(detection)
if detection is None:
if transformed is None:
return None
detection = transformed
return detection
def transform_clip_detections(

View File

@@ -18,7 +18,8 @@ from batdetect2.targets.types import TargetProtocol
def _mock_clip_detections_tensor() -> ClipDetectionsTensor:
return ClipDetectionsTensor(
scores=torch.tensor([0.9], dtype=torch.float32),
sizes=torch.tensor([[0.1, 1_000.0]], dtype=torch.float32),
# NOTE: Time is scaled by 1000
sizes=torch.tensor([[100, 1_000.0]], dtype=torch.float32),
class_scores=torch.tensor([[0.8, 0.2]], dtype=torch.float32),
times=torch.tensor([0.2], dtype=torch.float32),
frequencies=torch.tensor([60_000.0], dtype=torch.float32),
@@ -26,29 +27,15 @@ def _mock_clip_detections_tensor() -> ClipDetectionsTensor:
)
def test_shift_time_to_clip_start(
clip: data.Clip,
sample_targets: TargetProtocol,
):
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
def test_shift_time_to_clip_start(sample_targets: TargetProtocol):
raw = _mock_clip_detections_tensor()
transform = build_output_transform(targets=sample_targets)
detection = Detection(
geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]),
detection_score=0.9,
class_scores=np.array([0.9]),
features=np.array([1.0, 2.0]),
)
transformed = transform.to_detections(raw, start_time=2.5)
start_time, _, end_time, _ = compute_bounds(transformed[0].geometry)
transformed = OutputTransform(targets=sample_targets)(
[ClipDetections(clip=clip, detections=[detection])]
)[0]
start_time, _, end_time, _ = compute_bounds(
transformed.detections[0].geometry
)
assert np.isclose(start_time, 2.6)
assert np.isclose(end_time, 2.7)
assert np.isclose(start_time, 2.7)
assert np.isclose(end_time, 2.8)
def test_to_clip_detections_shifts_by_clip_start(
@@ -58,14 +45,10 @@ def test_to_clip_detections_shifts_by_clip_start(
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
transform = build_output_transform(targets=sample_targets)
raw = _mock_clip_detections_tensor()
shifted = transform.to_clip_detections(detections=raw, clip=clip)
unshifted = transform.to_detections(detections=raw, start_time=0)
shifted_start, _, _, _ = compute_bounds(shifted.detections[0].geometry)
unshifted_start, _, _, _ = compute_bounds(unshifted[0].geometry)
assert np.isclose(shifted_start - unshifted_start, clip.start_time)
start_time, _, end_time, _ = compute_bounds(shifted.detections[0].geometry)
assert np.isclose(start_time, 2.7)
assert np.isclose(end_time, 2.8)
def test_detection_and_clip_transforms_applied_in_order(