mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Create EvaluateTaskProtocol
This commit is contained in:
parent
daff74fdde
commit
b3af70761e
@ -53,7 +53,15 @@ def run_evaluate(
|
||||
num_workers=num_workers,
|
||||
)
|
||||
|
||||
evaluator = build_evaluator(config=evaluation_config, targets=targets)
|
||||
output_transform = build_output_transform(
|
||||
config=output_config.transform,
|
||||
targets=targets,
|
||||
)
|
||||
evaluator = build_evaluator(
|
||||
config=evaluation_config,
|
||||
targets=targets,
|
||||
transform=output_transform,
|
||||
)
|
||||
|
||||
logger = build_logger(
|
||||
evaluation_config.logger,
|
||||
@ -61,14 +69,9 @@ def run_evaluate(
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
)
|
||||
output_transform = build_output_transform(
|
||||
config=output_config.transform,
|
||||
targets=targets,
|
||||
)
|
||||
module = EvaluationModule(
|
||||
model,
|
||||
evaluator,
|
||||
output_transform=output_transform,
|
||||
)
|
||||
trainer = Trainer(logger=logger, enable_checkpointing=False)
|
||||
metrics = trainer.test(module, loader)
|
||||
|
||||
@ -5,8 +5,9 @@ from soundevent import data
|
||||
|
||||
from batdetect2.evaluate.config import EvaluationConfig
|
||||
from batdetect2.evaluate.tasks import build_task
|
||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.evaluate.types import EvaluationTaskProtocol, EvaluatorProtocol
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.postprocess.types import ClipDetections, ClipDetectionsTensor
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
@ -20,11 +21,23 @@ class Evaluator:
|
||||
def __init__(
|
||||
self,
|
||||
targets: TargetProtocol,
|
||||
tasks: Sequence[EvaluatorProtocol],
|
||||
transform: OutputTransformProtocol,
|
||||
tasks: Sequence[EvaluationTaskProtocol],
|
||||
):
|
||||
self.targets = targets
|
||||
self.transform = transform
|
||||
self.tasks = tasks
|
||||
|
||||
def to_clip_detections_batch(
|
||||
self,
|
||||
clip_detections: Sequence[ClipDetectionsTensor],
|
||||
clips: Sequence[data.Clip],
|
||||
) -> list[ClipDetections]:
|
||||
return [
|
||||
self.transform.to_clip_detections(detections=dets, clip=clip)
|
||||
for dets, clip in zip(clip_detections, clips, strict=False)
|
||||
]
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
@ -54,6 +67,7 @@ class Evaluator:
|
||||
def build_evaluator(
|
||||
config: EvaluationConfig | dict | None = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
transform: OutputTransformProtocol | None = None,
|
||||
) -> EvaluatorProtocol:
|
||||
targets = targets or build_targets()
|
||||
|
||||
@ -63,7 +77,10 @@ def build_evaluator(
|
||||
if not isinstance(config, EvaluationConfig):
|
||||
config = EvaluationConfig.model_validate(config)
|
||||
|
||||
transform = transform or build_output_transform(targets=targets)
|
||||
|
||||
return Evaluator(
|
||||
targets=targets,
|
||||
transform=transform,
|
||||
tasks=[build_task(task, targets=targets) for task in config.tasks],
|
||||
)
|
||||
|
||||
@ -8,7 +8,6 @@ from batdetect2.evaluate.dataset import TestDataset, TestExample
|
||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||
from batdetect2.logging import get_image_logger
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
|
||||
|
||||
@ -17,15 +16,11 @@ class EvaluationModule(LightningModule):
|
||||
self,
|
||||
model: Model,
|
||||
evaluator: EvaluatorProtocol,
|
||||
output_transform: OutputTransformProtocol | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.model = model
|
||||
self.evaluator = evaluator
|
||||
self.output_transform = output_transform or build_output_transform(
|
||||
targets=evaluator.targets
|
||||
)
|
||||
|
||||
self.clip_annotations: List[data.ClipAnnotation] = []
|
||||
self.predictions: List[ClipDetections] = []
|
||||
@ -39,15 +34,10 @@ class EvaluationModule(LightningModule):
|
||||
|
||||
outputs = self.model.detector(batch.spec)
|
||||
clip_detections = self.model.postprocessor(outputs)
|
||||
predictions = [
|
||||
self.output_transform.to_clip_detections(
|
||||
detections=clip_dets,
|
||||
clip=clip_annotation.clip,
|
||||
predictions = self.evaluator.to_clip_detections_batch(
|
||||
clip_detections,
|
||||
[clip_annotation.clip for clip_annotation in clip_annotations],
|
||||
)
|
||||
for clip_annotation, clip_dets in zip(
|
||||
clip_annotations, clip_detections, strict=False
|
||||
)
|
||||
]
|
||||
|
||||
self.clip_annotations.extend(clip_annotations)
|
||||
self.predictions.extend(predictions)
|
||||
|
||||
@ -11,7 +11,7 @@ from batdetect2.evaluate.tasks.clip_classification import (
|
||||
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
|
||||
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
||||
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
|
||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||
from batdetect2.evaluate.types import EvaluationTaskProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
@ -36,7 +36,7 @@ TaskConfig = Annotated[
|
||||
def build_task(
|
||||
config: TaskConfig,
|
||||
targets: TargetProtocol | None = None,
|
||||
) -> EvaluatorProtocol:
|
||||
) -> EvaluationTaskProtocol:
|
||||
targets = targets or build_targets()
|
||||
return tasks_registry.build(config, targets)
|
||||
|
||||
|
||||
@ -28,7 +28,7 @@ from batdetect2.evaluate.affinity import (
|
||||
)
|
||||
from batdetect2.evaluate.types import (
|
||||
AffinityFunction,
|
||||
EvaluatorProtocol,
|
||||
EvaluationTaskProtocol,
|
||||
)
|
||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
@ -39,7 +39,7 @@ __all__ = [
|
||||
"TaskImportConfig",
|
||||
]
|
||||
|
||||
tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry(
|
||||
tasks_registry: Registry[EvaluationTaskProtocol, [TargetProtocol]] = Registry(
|
||||
"tasks"
|
||||
)
|
||||
|
||||
@ -64,7 +64,7 @@ class BaseTaskConfig(BaseConfig):
|
||||
ignore_start_end: float = 0.01
|
||||
|
||||
|
||||
class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||
class BaseTask(EvaluationTaskProtocol, Generic[T_Output]):
|
||||
targets: TargetProtocol
|
||||
|
||||
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
|
||||
|
||||
@ -4,12 +4,18 @@ from typing import Generic, Iterable, Protocol, Sequence, TypeVar
|
||||
from matplotlib.figure import Figure
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
from batdetect2.outputs.types import OutputTransformProtocol
|
||||
from batdetect2.postprocess.types import (
|
||||
ClipDetections,
|
||||
ClipDetectionsTensor,
|
||||
Detection,
|
||||
)
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"AffinityFunction",
|
||||
"ClipMatches",
|
||||
"EvaluationTaskProtocol",
|
||||
"EvaluatorProtocol",
|
||||
"MatchEvaluation",
|
||||
"MatcherProtocol",
|
||||
@ -94,7 +100,7 @@ class PlotterProtocol(Protocol):
|
||||
EvaluationOutput = TypeVar("EvaluationOutput")
|
||||
|
||||
|
||||
class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
|
||||
class EvaluationTaskProtocol(Protocol, Generic[EvaluationOutput]):
|
||||
targets: TargetProtocol
|
||||
|
||||
def evaluate(
|
||||
@ -112,3 +118,30 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
|
||||
self,
|
||||
eval_outputs: EvaluationOutput,
|
||||
) -> Iterable[tuple[str, Figure]]: ...
|
||||
|
||||
|
||||
class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
|
||||
targets: TargetProtocol
|
||||
transform: OutputTransformProtocol
|
||||
|
||||
def to_clip_detections_batch(
|
||||
self,
|
||||
clip_detections: Sequence[ClipDetectionsTensor],
|
||||
clips: Sequence[data.Clip],
|
||||
) -> list[ClipDetections]: ...
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[ClipDetections],
|
||||
) -> EvaluationOutput: ...
|
||||
|
||||
def compute_metrics(
|
||||
self,
|
||||
eval_outputs: EvaluationOutput,
|
||||
) -> dict[str, float]: ...
|
||||
|
||||
def generate_plots(
|
||||
self,
|
||||
eval_outputs: EvaluationOutput,
|
||||
) -> Iterable[tuple[str, Figure]]: ...
|
||||
|
||||
@ -123,18 +123,20 @@ class OutputTransform(OutputTransformProtocol):
|
||||
|
||||
out.append(transformed)
|
||||
|
||||
return []
|
||||
return out
|
||||
|
||||
def transform_detection(
|
||||
self,
|
||||
detection: Detection,
|
||||
) -> Detection | None:
|
||||
for transform in self.detection_transform_steps:
|
||||
detection = transform(detection) # type: ignore
|
||||
transformed = transform(detection)
|
||||
|
||||
if detection is None:
|
||||
if transformed is None:
|
||||
return None
|
||||
|
||||
detection = transformed
|
||||
|
||||
return detection
|
||||
|
||||
def transform_clip_detections(
|
||||
|
||||
@ -18,7 +18,8 @@ from batdetect2.targets.types import TargetProtocol
|
||||
def _mock_clip_detections_tensor() -> ClipDetectionsTensor:
|
||||
return ClipDetectionsTensor(
|
||||
scores=torch.tensor([0.9], dtype=torch.float32),
|
||||
sizes=torch.tensor([[0.1, 1_000.0]], dtype=torch.float32),
|
||||
# NOTE: Time is scaled by 1000
|
||||
sizes=torch.tensor([[100, 1_000.0]], dtype=torch.float32),
|
||||
class_scores=torch.tensor([[0.8, 0.2]], dtype=torch.float32),
|
||||
times=torch.tensor([0.2], dtype=torch.float32),
|
||||
frequencies=torch.tensor([60_000.0], dtype=torch.float32),
|
||||
@ -26,29 +27,15 @@ def _mock_clip_detections_tensor() -> ClipDetectionsTensor:
|
||||
)
|
||||
|
||||
|
||||
def test_shift_time_to_clip_start(
|
||||
clip: data.Clip,
|
||||
sample_targets: TargetProtocol,
|
||||
):
|
||||
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
|
||||
def test_shift_time_to_clip_start(sample_targets: TargetProtocol):
|
||||
raw = _mock_clip_detections_tensor()
|
||||
transform = build_output_transform(targets=sample_targets)
|
||||
|
||||
detection = Detection(
|
||||
geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]),
|
||||
detection_score=0.9,
|
||||
class_scores=np.array([0.9]),
|
||||
features=np.array([1.0, 2.0]),
|
||||
)
|
||||
transformed = transform.to_detections(raw, start_time=2.5)
|
||||
start_time, _, end_time, _ = compute_bounds(transformed[0].geometry)
|
||||
|
||||
transformed = OutputTransform(targets=sample_targets)(
|
||||
[ClipDetections(clip=clip, detections=[detection])]
|
||||
)[0]
|
||||
|
||||
start_time, _, end_time, _ = compute_bounds(
|
||||
transformed.detections[0].geometry
|
||||
)
|
||||
|
||||
assert np.isclose(start_time, 2.6)
|
||||
assert np.isclose(end_time, 2.7)
|
||||
assert np.isclose(start_time, 2.7)
|
||||
assert np.isclose(end_time, 2.8)
|
||||
|
||||
|
||||
def test_to_clip_detections_shifts_by_clip_start(
|
||||
@ -58,14 +45,10 @@ def test_to_clip_detections_shifts_by_clip_start(
|
||||
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
|
||||
transform = build_output_transform(targets=sample_targets)
|
||||
raw = _mock_clip_detections_tensor()
|
||||
|
||||
shifted = transform.to_clip_detections(detections=raw, clip=clip)
|
||||
unshifted = transform.to_detections(detections=raw, start_time=0)
|
||||
|
||||
shifted_start, _, _, _ = compute_bounds(shifted.detections[0].geometry)
|
||||
unshifted_start, _, _, _ = compute_bounds(unshifted[0].geometry)
|
||||
|
||||
assert np.isclose(shifted_start - unshifted_start, clip.start_time)
|
||||
start_time, _, end_time, _ = compute_bounds(shifted.detections[0].geometry)
|
||||
assert np.isclose(start_time, 2.7)
|
||||
assert np.isclose(end_time, 2.8)
|
||||
|
||||
|
||||
def test_detection_and_clip_transforms_applied_in_order(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user