Create EvaluationTaskProtocol

This commit is contained in:
mbsantiago 2026-03-18 01:53:28 +00:00
parent daff74fdde
commit b3af70761e
8 changed files with 90 additions and 62 deletions

View File

@ -53,7 +53,15 @@ def run_evaluate(
num_workers=num_workers, num_workers=num_workers,
) )
evaluator = build_evaluator(config=evaluation_config, targets=targets) output_transform = build_output_transform(
config=output_config.transform,
targets=targets,
)
evaluator = build_evaluator(
config=evaluation_config,
targets=targets,
transform=output_transform,
)
logger = build_logger( logger = build_logger(
evaluation_config.logger, evaluation_config.logger,
@ -61,14 +69,9 @@ def run_evaluate(
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name, run_name=run_name,
) )
output_transform = build_output_transform(
config=output_config.transform,
targets=targets,
)
module = EvaluationModule( module = EvaluationModule(
model, model,
evaluator, evaluator,
output_transform=output_transform,
) )
trainer = Trainer(logger=logger, enable_checkpointing=False) trainer = Trainer(logger=logger, enable_checkpointing=False)
metrics = trainer.test(module, loader) metrics = trainer.test(module, loader)

View File

@ -5,8 +5,9 @@ from soundevent import data
from batdetect2.evaluate.config import EvaluationConfig from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.tasks import build_task from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.types import EvaluatorProtocol from batdetect2.evaluate.types import EvaluationTaskProtocol, EvaluatorProtocol
from batdetect2.postprocess.types import ClipDetections from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections, ClipDetectionsTensor
from batdetect2.targets import build_targets from batdetect2.targets import build_targets
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
@ -20,11 +21,23 @@ class Evaluator:
def __init__( def __init__(
self, self,
targets: TargetProtocol, targets: TargetProtocol,
tasks: Sequence[EvaluatorProtocol], transform: OutputTransformProtocol,
tasks: Sequence[EvaluationTaskProtocol],
): ):
self.targets = targets self.targets = targets
self.transform = transform
self.tasks = tasks self.tasks = tasks
def to_clip_detections_batch(
self,
clip_detections: Sequence[ClipDetectionsTensor],
clips: Sequence[data.Clip],
) -> list[ClipDetections]:
return [
self.transform.to_clip_detections(detections=dets, clip=clip)
for dets, clip in zip(clip_detections, clips, strict=False)
]
def evaluate( def evaluate(
self, self,
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
@ -54,6 +67,7 @@ class Evaluator:
def build_evaluator( def build_evaluator(
config: EvaluationConfig | dict | None = None, config: EvaluationConfig | dict | None = None,
targets: TargetProtocol | None = None, targets: TargetProtocol | None = None,
transform: OutputTransformProtocol | None = None,
) -> EvaluatorProtocol: ) -> EvaluatorProtocol:
targets = targets or build_targets() targets = targets or build_targets()
@ -63,7 +77,10 @@ def build_evaluator(
if not isinstance(config, EvaluationConfig): if not isinstance(config, EvaluationConfig):
config = EvaluationConfig.model_validate(config) config = EvaluationConfig.model_validate(config)
transform = transform or build_output_transform(targets=targets)
return Evaluator( return Evaluator(
targets=targets, targets=targets,
transform=transform,
tasks=[build_task(task, targets=targets) for task in config.tasks], tasks=[build_task(task, targets=targets) for task in config.tasks],
) )

View File

@ -8,7 +8,6 @@ from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.evaluate.types import EvaluatorProtocol from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import get_image_logger from batdetect2.logging import get_image_logger
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections from batdetect2.postprocess.types import ClipDetections
@ -17,15 +16,11 @@ class EvaluationModule(LightningModule):
self, self,
model: Model, model: Model,
evaluator: EvaluatorProtocol, evaluator: EvaluatorProtocol,
output_transform: OutputTransformProtocol | None = None,
): ):
super().__init__() super().__init__()
self.model = model self.model = model
self.evaluator = evaluator self.evaluator = evaluator
self.output_transform = output_transform or build_output_transform(
targets=evaluator.targets
)
self.clip_annotations: List[data.ClipAnnotation] = [] self.clip_annotations: List[data.ClipAnnotation] = []
self.predictions: List[ClipDetections] = [] self.predictions: List[ClipDetections] = []
@ -39,15 +34,10 @@ class EvaluationModule(LightningModule):
outputs = self.model.detector(batch.spec) outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor(outputs) clip_detections = self.model.postprocessor(outputs)
predictions = [ predictions = self.evaluator.to_clip_detections_batch(
self.output_transform.to_clip_detections( clip_detections,
detections=clip_dets, [clip_annotation.clip for clip_annotation in clip_annotations],
clip=clip_annotation.clip,
) )
for clip_annotation, clip_dets in zip(
clip_annotations, clip_detections, strict=False
)
]
self.clip_annotations.extend(clip_annotations) self.clip_annotations.extend(clip_annotations)
self.predictions.extend(predictions) self.predictions.extend(predictions)

View File

@ -11,7 +11,7 @@ from batdetect2.evaluate.tasks.clip_classification import (
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
from batdetect2.evaluate.types import EvaluatorProtocol from batdetect2.evaluate.types import EvaluationTaskProtocol
from batdetect2.postprocess.types import ClipDetections from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets import build_targets from batdetect2.targets import build_targets
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
@ -36,7 +36,7 @@ TaskConfig = Annotated[
def build_task( def build_task(
config: TaskConfig, config: TaskConfig,
targets: TargetProtocol | None = None, targets: TargetProtocol | None = None,
) -> EvaluatorProtocol: ) -> EvaluationTaskProtocol:
targets = targets or build_targets() targets = targets or build_targets()
return tasks_registry.build(config, targets) return tasks_registry.build(config, targets)

View File

@ -28,7 +28,7 @@ from batdetect2.evaluate.affinity import (
) )
from batdetect2.evaluate.types import ( from batdetect2.evaluate.types import (
AffinityFunction, AffinityFunction,
EvaluatorProtocol, EvaluationTaskProtocol,
) )
from batdetect2.postprocess.types import ClipDetections, Detection from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
@ -39,7 +39,7 @@ __all__ = [
"TaskImportConfig", "TaskImportConfig",
] ]
tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry( tasks_registry: Registry[EvaluationTaskProtocol, [TargetProtocol]] = Registry(
"tasks" "tasks"
) )
@ -64,7 +64,7 @@ class BaseTaskConfig(BaseConfig):
ignore_start_end: float = 0.01 ignore_start_end: float = 0.01
class BaseTask(EvaluatorProtocol, Generic[T_Output]): class BaseTask(EvaluationTaskProtocol, Generic[T_Output]):
targets: TargetProtocol targets: TargetProtocol
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]] metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]

View File

@ -4,12 +4,18 @@ from typing import Generic, Iterable, Protocol, Sequence, TypeVar
from matplotlib.figure import Figure from matplotlib.figure import Figure
from soundevent import data from soundevent import data
from batdetect2.postprocess.types import ClipDetections, Detection from batdetect2.outputs.types import OutputTransformProtocol
from batdetect2.postprocess.types import (
ClipDetections,
ClipDetectionsTensor,
Detection,
)
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
__all__ = [ __all__ = [
"AffinityFunction", "AffinityFunction",
"ClipMatches", "ClipMatches",
"EvaluationTaskProtocol",
"EvaluatorProtocol", "EvaluatorProtocol",
"MatchEvaluation", "MatchEvaluation",
"MatcherProtocol", "MatcherProtocol",
@ -94,7 +100,7 @@ class PlotterProtocol(Protocol):
EvaluationOutput = TypeVar("EvaluationOutput") EvaluationOutput = TypeVar("EvaluationOutput")
class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]): class EvaluationTaskProtocol(Protocol, Generic[EvaluationOutput]):
targets: TargetProtocol targets: TargetProtocol
def evaluate( def evaluate(
@ -112,3 +118,30 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
self, self,
eval_outputs: EvaluationOutput, eval_outputs: EvaluationOutput,
) -> Iterable[tuple[str, Figure]]: ... ) -> Iterable[tuple[str, Figure]]: ...
class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
    """Structural interface for a complete evaluator.

    An evaluator carries the targets and the output transform, converts
    raw detections into ``ClipDetections``, and exposes the evaluate /
    metrics / plots steps over an ``EvaluationOutput`` value.
    """

    # Target definitions the evaluator works with.
    targets: TargetProtocol
    # Transform used to turn raw detections into ``ClipDetections``.
    transform: OutputTransformProtocol

    def to_clip_detections_batch(
        self,
        clip_detections: Sequence[ClipDetectionsTensor],
        clips: Sequence[data.Clip],
    ) -> list[ClipDetections]:
        """Convert raw detections and their clips into ``ClipDetections``."""
        ...

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[ClipDetections],
    ) -> EvaluationOutput:
        """Evaluate predictions against the given clip annotations."""
        ...

    def compute_metrics(
        self,
        eval_outputs: EvaluationOutput,
    ) -> dict[str, float]:
        """Return float-valued metrics keyed by string."""
        ...

    def generate_plots(
        self,
        eval_outputs: EvaluationOutput,
    ) -> Iterable[tuple[str, Figure]]:
        """Yield ``(str, Figure)`` pairs built from the evaluation output."""
        ...

View File

@ -123,18 +123,20 @@ class OutputTransform(OutputTransformProtocol):
out.append(transformed) out.append(transformed)
return [] return out
def transform_detection( def transform_detection(
self, self,
detection: Detection, detection: Detection,
) -> Detection | None: ) -> Detection | None:
for transform in self.detection_transform_steps: for transform in self.detection_transform_steps:
detection = transform(detection) # type: ignore transformed = transform(detection)
if detection is None: if transformed is None:
return None return None
detection = transformed
return detection return detection
def transform_clip_detections( def transform_clip_detections(

View File

@ -18,7 +18,8 @@ from batdetect2.targets.types import TargetProtocol
def _mock_clip_detections_tensor() -> ClipDetectionsTensor: def _mock_clip_detections_tensor() -> ClipDetectionsTensor:
return ClipDetectionsTensor( return ClipDetectionsTensor(
scores=torch.tensor([0.9], dtype=torch.float32), scores=torch.tensor([0.9], dtype=torch.float32),
sizes=torch.tensor([[0.1, 1_000.0]], dtype=torch.float32), # NOTE: Time is scaled by 1000
sizes=torch.tensor([[100, 1_000.0]], dtype=torch.float32),
class_scores=torch.tensor([[0.8, 0.2]], dtype=torch.float32), class_scores=torch.tensor([[0.8, 0.2]], dtype=torch.float32),
times=torch.tensor([0.2], dtype=torch.float32), times=torch.tensor([0.2], dtype=torch.float32),
frequencies=torch.tensor([60_000.0], dtype=torch.float32), frequencies=torch.tensor([60_000.0], dtype=torch.float32),
@ -26,29 +27,15 @@ def _mock_clip_detections_tensor() -> ClipDetectionsTensor:
) )
def test_shift_time_to_clip_start( def test_shift_time_to_clip_start(sample_targets: TargetProtocol):
clip: data.Clip, raw = _mock_clip_detections_tensor()
sample_targets: TargetProtocol, transform = build_output_transform(targets=sample_targets)
):
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
detection = Detection( transformed = transform.to_detections(raw, start_time=2.5)
geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]), start_time, _, end_time, _ = compute_bounds(transformed[0].geometry)
detection_score=0.9,
class_scores=np.array([0.9]),
features=np.array([1.0, 2.0]),
)
transformed = OutputTransform(targets=sample_targets)( assert np.isclose(start_time, 2.7)
[ClipDetections(clip=clip, detections=[detection])] assert np.isclose(end_time, 2.8)
)[0]
start_time, _, end_time, _ = compute_bounds(
transformed.detections[0].geometry
)
assert np.isclose(start_time, 2.6)
assert np.isclose(end_time, 2.7)
def test_to_clip_detections_shifts_by_clip_start( def test_to_clip_detections_shifts_by_clip_start(
@ -58,14 +45,10 @@ def test_to_clip_detections_shifts_by_clip_start(
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0}) clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
transform = build_output_transform(targets=sample_targets) transform = build_output_transform(targets=sample_targets)
raw = _mock_clip_detections_tensor() raw = _mock_clip_detections_tensor()
shifted = transform.to_clip_detections(detections=raw, clip=clip) shifted = transform.to_clip_detections(detections=raw, clip=clip)
unshifted = transform.to_detections(detections=raw, start_time=0) start_time, _, end_time, _ = compute_bounds(shifted.detections[0].geometry)
assert np.isclose(start_time, 2.7)
shifted_start, _, _, _ = compute_bounds(shifted.detections[0].geometry) assert np.isclose(end_time, 2.8)
unshifted_start, _, _, _ = compute_bounds(unshifted[0].geometry)
assert np.isclose(shifted_start - unshifted_start, clip.start_time)
def test_detection_and_clip_transforms_applied_in_order( def test_detection_and_clip_transforms_applied_in_order(