diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index 48b8293..f5d7da9 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -53,7 +53,15 @@ def run_evaluate( num_workers=num_workers, ) - evaluator = build_evaluator(config=evaluation_config, targets=targets) + output_transform = build_output_transform( + config=output_config.transform, + targets=targets, + ) + evaluator = build_evaluator( + config=evaluation_config, + targets=targets, + transform=output_transform, + ) logger = build_logger( evaluation_config.logger, @@ -61,14 +69,9 @@ def run_evaluate( experiment_name=experiment_name, run_name=run_name, ) - output_transform = build_output_transform( - config=output_config.transform, - targets=targets, - ) module = EvaluationModule( model, evaluator, - output_transform=output_transform, ) trainer = Trainer(logger=logger, enable_checkpointing=False) metrics = trainer.test(module, loader) diff --git a/src/batdetect2/evaluate/evaluator.py b/src/batdetect2/evaluate/evaluator.py index 30152b9..685079f 100644 --- a/src/batdetect2/evaluate/evaluator.py +++ b/src/batdetect2/evaluate/evaluator.py @@ -5,8 +5,9 @@ from soundevent import data from batdetect2.evaluate.config import EvaluationConfig from batdetect2.evaluate.tasks import build_task -from batdetect2.evaluate.types import EvaluatorProtocol -from batdetect2.postprocess.types import ClipDetections +from batdetect2.evaluate.types import EvaluationTaskProtocol, EvaluatorProtocol +from batdetect2.outputs import OutputTransformProtocol, build_output_transform +from batdetect2.postprocess.types import ClipDetections, ClipDetectionsTensor from batdetect2.targets import build_targets from batdetect2.targets.types import TargetProtocol @@ -20,11 +21,23 @@ class Evaluator: def __init__( self, targets: TargetProtocol, - tasks: Sequence[EvaluatorProtocol], + transform: OutputTransformProtocol, + tasks: Sequence[EvaluationTaskProtocol], ): self.targets 
= targets + self.transform = transform self.tasks = tasks + def to_clip_detections_batch( + self, + clip_detections: Sequence[ClipDetectionsTensor], + clips: Sequence[data.Clip], + ) -> list[ClipDetections]: + return [ + self.transform.to_clip_detections(detections=dets, clip=clip) + for dets, clip in zip(clip_detections, clips, strict=False) + ] + def evaluate( self, clip_annotations: Sequence[data.ClipAnnotation], @@ -54,6 +67,7 @@ class Evaluator: def build_evaluator( config: EvaluationConfig | dict | None = None, targets: TargetProtocol | None = None, + transform: OutputTransformProtocol | None = None, ) -> EvaluatorProtocol: targets = targets or build_targets() @@ -63,7 +77,10 @@ def build_evaluator( if not isinstance(config, EvaluationConfig): config = EvaluationConfig.model_validate(config) + transform = transform or build_output_transform(targets=targets) + return Evaluator( targets=targets, + transform=transform, tasks=[build_task(task, targets=targets) for task in config.tasks], ) diff --git a/src/batdetect2/evaluate/lightning.py b/src/batdetect2/evaluate/lightning.py index 84ee763..48703d5 100644 --- a/src/batdetect2/evaluate/lightning.py +++ b/src/batdetect2/evaluate/lightning.py @@ -8,7 +8,6 @@ from batdetect2.evaluate.dataset import TestDataset, TestExample from batdetect2.evaluate.types import EvaluatorProtocol from batdetect2.logging import get_image_logger from batdetect2.models import Model -from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.postprocess.types import ClipDetections @@ -17,15 +16,11 @@ class EvaluationModule(LightningModule): self, model: Model, evaluator: EvaluatorProtocol, - output_transform: OutputTransformProtocol | None = None, ): super().__init__() self.model = model self.evaluator = evaluator - self.output_transform = output_transform or build_output_transform( - targets=evaluator.targets - ) self.clip_annotations: List[data.ClipAnnotation] = [] self.predictions: 
List[ClipDetections] = [] @@ -39,15 +34,10 @@ class EvaluationModule(LightningModule): outputs = self.model.detector(batch.spec) clip_detections = self.model.postprocessor(outputs) - predictions = [ - self.output_transform.to_clip_detections( - detections=clip_dets, - clip=clip_annotation.clip, - ) - for clip_annotation, clip_dets in zip( - clip_annotations, clip_detections, strict=False - ) - ] + predictions = self.evaluator.to_clip_detections_batch( + clip_detections, + [clip_annotation.clip for clip_annotation in clip_annotations], + ) self.clip_annotations.extend(clip_annotations) self.predictions.extend(predictions) diff --git a/src/batdetect2/evaluate/tasks/__init__.py b/src/batdetect2/evaluate/tasks/__init__.py index 11b3f01..c35cdf6 100644 --- a/src/batdetect2/evaluate/tasks/__init__.py +++ b/src/batdetect2/evaluate/tasks/__init__.py @@ -11,7 +11,7 @@ from batdetect2.evaluate.tasks.clip_classification import ( from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig from batdetect2.evaluate.tasks.detection import DetectionTaskConfig from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig -from batdetect2.evaluate.types import EvaluatorProtocol +from batdetect2.evaluate.types import EvaluationTaskProtocol from batdetect2.postprocess.types import ClipDetections from batdetect2.targets import build_targets from batdetect2.targets.types import TargetProtocol @@ -36,7 +36,7 @@ TaskConfig = Annotated[ def build_task( config: TaskConfig, targets: TargetProtocol | None = None, -) -> EvaluatorProtocol: +) -> EvaluationTaskProtocol: targets = targets or build_targets() return tasks_registry.build(config, targets) diff --git a/src/batdetect2/evaluate/tasks/base.py b/src/batdetect2/evaluate/tasks/base.py index 77065cd..f9cdcaa 100644 --- a/src/batdetect2/evaluate/tasks/base.py +++ b/src/batdetect2/evaluate/tasks/base.py @@ -28,7 +28,7 @@ from batdetect2.evaluate.affinity import ( ) from batdetect2.evaluate.types import ( 
AffinityFunction, - EvaluatorProtocol, + EvaluationTaskProtocol, ) from batdetect2.postprocess.types import ClipDetections, Detection from batdetect2.targets.types import TargetProtocol @@ -39,7 +39,7 @@ __all__ = [ "TaskImportConfig", ] -tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry( +tasks_registry: Registry[EvaluationTaskProtocol, [TargetProtocol]] = Registry( "tasks" ) @@ -64,7 +64,7 @@ class BaseTaskConfig(BaseConfig): ignore_start_end: float = 0.01 -class BaseTask(EvaluatorProtocol, Generic[T_Output]): +class BaseTask(EvaluationTaskProtocol, Generic[T_Output]): targets: TargetProtocol metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]] diff --git a/src/batdetect2/evaluate/types.py b/src/batdetect2/evaluate/types.py index 58f1c86..7fbe3d9 100644 --- a/src/batdetect2/evaluate/types.py +++ b/src/batdetect2/evaluate/types.py @@ -4,12 +4,18 @@ from typing import Generic, Iterable, Protocol, Sequence, TypeVar from matplotlib.figure import Figure from soundevent import data -from batdetect2.postprocess.types import ClipDetections, Detection +from batdetect2.outputs.types import OutputTransformProtocol +from batdetect2.postprocess.types import ( + ClipDetections, + ClipDetectionsTensor, + Detection, +) from batdetect2.targets.types import TargetProtocol __all__ = [ "AffinityFunction", "ClipMatches", + "EvaluationTaskProtocol", "EvaluatorProtocol", "MatchEvaluation", "MatcherProtocol", @@ -94,7 +100,7 @@ class PlotterProtocol(Protocol): EvaluationOutput = TypeVar("EvaluationOutput") -class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]): +class EvaluationTaskProtocol(Protocol, Generic[EvaluationOutput]): targets: TargetProtocol def evaluate( @@ -112,3 +118,30 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]): self, eval_outputs: EvaluationOutput, ) -> Iterable[tuple[str, Figure]]: ... 
+ + +class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]): + targets: TargetProtocol + transform: OutputTransformProtocol + + def to_clip_detections_batch( + self, + clip_detections: Sequence[ClipDetectionsTensor], + clips: Sequence[data.Clip], + ) -> list[ClipDetections]: ... + + def evaluate( + self, + clip_annotations: Sequence[data.ClipAnnotation], + predictions: Sequence[ClipDetections], + ) -> EvaluationOutput: ... + + def compute_metrics( + self, + eval_outputs: EvaluationOutput, + ) -> dict[str, float]: ... + + def generate_plots( + self, + eval_outputs: EvaluationOutput, + ) -> Iterable[tuple[str, Figure]]: ... diff --git a/src/batdetect2/outputs/transforms/__init__.py b/src/batdetect2/outputs/transforms/__init__.py index dfcf5b3..e82f84e 100644 --- a/src/batdetect2/outputs/transforms/__init__.py +++ b/src/batdetect2/outputs/transforms/__init__.py @@ -123,18 +123,20 @@ class OutputTransform(OutputTransformProtocol): out.append(transformed) - return [] + return out def transform_detection( self, detection: Detection, ) -> Detection | None: for transform in self.detection_transform_steps: - detection = transform(detection) # type: ignore + transformed = transform(detection) - if detection is None: + if transformed is None: return None + detection = transformed + return detection def transform_clip_detections( diff --git a/tests/test_outputs/test_transform/test_transform.py b/tests/test_outputs/test_transform/test_transform.py index 840473f..4fe3de1 100644 --- a/tests/test_outputs/test_transform/test_transform.py +++ b/tests/test_outputs/test_transform/test_transform.py @@ -18,7 +18,8 @@ from batdetect2.targets.types import TargetProtocol def _mock_clip_detections_tensor() -> ClipDetectionsTensor: return ClipDetectionsTensor( scores=torch.tensor([0.9], dtype=torch.float32), - sizes=torch.tensor([[0.1, 1_000.0]], dtype=torch.float32), + # NOTE: Time is scaled by 1000 + sizes=torch.tensor([[100, 1_000.0]], dtype=torch.float32), 
        class_scores=torch.tensor([[0.8, 0.2]], dtype=torch.float32),
         times=torch.tensor([0.2], dtype=torch.float32),
         frequencies=torch.tensor([60_000.0], dtype=torch.float32),
@@ -26,29 +27,15 @@
     )
 
 
-def test_shift_time_to_clip_start(
-    clip: data.Clip,
-    sample_targets: TargetProtocol,
-):
-    clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
+def test_shift_time_to_clip_start(sample_targets: TargetProtocol):
+    raw = _mock_clip_detections_tensor()
+    transform = build_output_transform(targets=sample_targets)
 
-    detection = Detection(
-        geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]),
-        detection_score=0.9,
-        class_scores=np.array([0.9]),
-        features=np.array([1.0, 2.0]),
-    )
+    transformed = transform.to_detections(raw, start_time=2.5)
+    start_time, _, end_time, _ = compute_bounds(transformed[0].geometry)
 
-    transformed = OutputTransform(targets=sample_targets)(
-        [ClipDetections(clip=clip, detections=[detection])]
-    )[0]
-
-    start_time, _, end_time, _ = compute_bounds(
-        transformed.detections[0].geometry
-    )
-
-    assert np.isclose(start_time, 2.6)
-    assert np.isclose(end_time, 2.7)
+    assert np.isclose(start_time, 2.7)
+    assert np.isclose(end_time, 2.8)
 
 
 def test_to_clip_detections_shifts_by_clip_start(
@@ -58,14 +45,11 @@
     clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
     transform = build_output_transform(targets=sample_targets)
     raw = _mock_clip_detections_tensor()
     shifted = transform.to_clip_detections(detections=raw, clip=clip)
-    unshifted = transform.to_detections(detections=raw, start_time=0)
-
-    shifted_start, _, _, _ = compute_bounds(shifted.detections[0].geometry)
-    unshifted_start, _, _, _ = compute_bounds(unshifted[0].geometry)
-
-    assert np.isclose(shifted_start - unshifted_start, clip.start_time)
+    start_time, _, end_time, _ = compute_bounds(shifted.detections[0].geometry)
+    assert np.isclose(start_time, 2.7)
+    assert np.isclose(end_time, 2.8)
 
 
 def test_detection_and_clip_transforms_applied_in_order(