mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Create EvaluateTaskProtocol
This commit is contained in:
parent
daff74fdde
commit
b3af70761e
@ -53,7 +53,15 @@ def run_evaluate(
|
|||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = build_evaluator(config=evaluation_config, targets=targets)
|
output_transform = build_output_transform(
|
||||||
|
config=output_config.transform,
|
||||||
|
targets=targets,
|
||||||
|
)
|
||||||
|
evaluator = build_evaluator(
|
||||||
|
config=evaluation_config,
|
||||||
|
targets=targets,
|
||||||
|
transform=output_transform,
|
||||||
|
)
|
||||||
|
|
||||||
logger = build_logger(
|
logger = build_logger(
|
||||||
evaluation_config.logger,
|
evaluation_config.logger,
|
||||||
@ -61,14 +69,9 @@ def run_evaluate(
|
|||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
)
|
)
|
||||||
output_transform = build_output_transform(
|
|
||||||
config=output_config.transform,
|
|
||||||
targets=targets,
|
|
||||||
)
|
|
||||||
module = EvaluationModule(
|
module = EvaluationModule(
|
||||||
model,
|
model,
|
||||||
evaluator,
|
evaluator,
|
||||||
output_transform=output_transform,
|
|
||||||
)
|
)
|
||||||
trainer = Trainer(logger=logger, enable_checkpointing=False)
|
trainer = Trainer(logger=logger, enable_checkpointing=False)
|
||||||
metrics = trainer.test(module, loader)
|
metrics = trainer.test(module, loader)
|
||||||
|
|||||||
@ -5,8 +5,9 @@ from soundevent import data
|
|||||||
|
|
||||||
from batdetect2.evaluate.config import EvaluationConfig
|
from batdetect2.evaluate.config import EvaluationConfig
|
||||||
from batdetect2.evaluate.tasks import build_task
|
from batdetect2.evaluate.tasks import build_task
|
||||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
from batdetect2.evaluate.types import EvaluationTaskProtocol, EvaluatorProtocol
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
|
from batdetect2.postprocess.types import ClipDetections, ClipDetectionsTensor
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
@ -20,11 +21,23 @@ class Evaluator:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
tasks: Sequence[EvaluatorProtocol],
|
transform: OutputTransformProtocol,
|
||||||
|
tasks: Sequence[EvaluationTaskProtocol],
|
||||||
):
|
):
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
|
self.transform = transform
|
||||||
self.tasks = tasks
|
self.tasks = tasks
|
||||||
|
|
||||||
|
def to_clip_detections_batch(
|
||||||
|
self,
|
||||||
|
clip_detections: Sequence[ClipDetectionsTensor],
|
||||||
|
clips: Sequence[data.Clip],
|
||||||
|
) -> list[ClipDetections]:
|
||||||
|
return [
|
||||||
|
self.transform.to_clip_detections(detections=dets, clip=clip)
|
||||||
|
for dets, clip in zip(clip_detections, clips, strict=False)
|
||||||
|
]
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
@ -54,6 +67,7 @@ class Evaluator:
|
|||||||
def build_evaluator(
|
def build_evaluator(
|
||||||
config: EvaluationConfig | dict | None = None,
|
config: EvaluationConfig | dict | None = None,
|
||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
|
transform: OutputTransformProtocol | None = None,
|
||||||
) -> EvaluatorProtocol:
|
) -> EvaluatorProtocol:
|
||||||
targets = targets or build_targets()
|
targets = targets or build_targets()
|
||||||
|
|
||||||
@ -63,7 +77,10 @@ def build_evaluator(
|
|||||||
if not isinstance(config, EvaluationConfig):
|
if not isinstance(config, EvaluationConfig):
|
||||||
config = EvaluationConfig.model_validate(config)
|
config = EvaluationConfig.model_validate(config)
|
||||||
|
|
||||||
|
transform = transform or build_output_transform(targets=targets)
|
||||||
|
|
||||||
return Evaluator(
|
return Evaluator(
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
transform=transform,
|
||||||
tasks=[build_task(task, targets=targets) for task in config.tasks],
|
tasks=[build_task(task, targets=targets) for task in config.tasks],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -8,7 +8,6 @@ from batdetect2.evaluate.dataset import TestDataset, TestExample
|
|||||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||||
from batdetect2.logging import get_image_logger
|
from batdetect2.logging import get_image_logger
|
||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
|
|
||||||
|
|
||||||
@ -17,15 +16,11 @@ class EvaluationModule(LightningModule):
|
|||||||
self,
|
self,
|
||||||
model: Model,
|
model: Model,
|
||||||
evaluator: EvaluatorProtocol,
|
evaluator: EvaluatorProtocol,
|
||||||
output_transform: OutputTransformProtocol | None = None,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
self.output_transform = output_transform or build_output_transform(
|
|
||||||
targets=evaluator.targets
|
|
||||||
)
|
|
||||||
|
|
||||||
self.clip_annotations: List[data.ClipAnnotation] = []
|
self.clip_annotations: List[data.ClipAnnotation] = []
|
||||||
self.predictions: List[ClipDetections] = []
|
self.predictions: List[ClipDetections] = []
|
||||||
@ -39,15 +34,10 @@ class EvaluationModule(LightningModule):
|
|||||||
|
|
||||||
outputs = self.model.detector(batch.spec)
|
outputs = self.model.detector(batch.spec)
|
||||||
clip_detections = self.model.postprocessor(outputs)
|
clip_detections = self.model.postprocessor(outputs)
|
||||||
predictions = [
|
predictions = self.evaluator.to_clip_detections_batch(
|
||||||
self.output_transform.to_clip_detections(
|
clip_detections,
|
||||||
detections=clip_dets,
|
[clip_annotation.clip for clip_annotation in clip_annotations],
|
||||||
clip=clip_annotation.clip,
|
|
||||||
)
|
)
|
||||||
for clip_annotation, clip_dets in zip(
|
|
||||||
clip_annotations, clip_detections, strict=False
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
self.clip_annotations.extend(clip_annotations)
|
self.clip_annotations.extend(clip_annotations)
|
||||||
self.predictions.extend(predictions)
|
self.predictions.extend(predictions)
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from batdetect2.evaluate.tasks.clip_classification import (
|
|||||||
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
|
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
|
||||||
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
||||||
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
|
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
|
||||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
from batdetect2.evaluate.types import EvaluationTaskProtocol
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
@ -36,7 +36,7 @@ TaskConfig = Annotated[
|
|||||||
def build_task(
|
def build_task(
|
||||||
config: TaskConfig,
|
config: TaskConfig,
|
||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
) -> EvaluatorProtocol:
|
) -> EvaluationTaskProtocol:
|
||||||
targets = targets or build_targets()
|
targets = targets or build_targets()
|
||||||
return tasks_registry.build(config, targets)
|
return tasks_registry.build(config, targets)
|
||||||
|
|
||||||
|
|||||||
@ -28,7 +28,7 @@ from batdetect2.evaluate.affinity import (
|
|||||||
)
|
)
|
||||||
from batdetect2.evaluate.types import (
|
from batdetect2.evaluate.types import (
|
||||||
AffinityFunction,
|
AffinityFunction,
|
||||||
EvaluatorProtocol,
|
EvaluationTaskProtocol,
|
||||||
)
|
)
|
||||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
@ -39,7 +39,7 @@ __all__ = [
|
|||||||
"TaskImportConfig",
|
"TaskImportConfig",
|
||||||
]
|
]
|
||||||
|
|
||||||
tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry(
|
tasks_registry: Registry[EvaluationTaskProtocol, [TargetProtocol]] = Registry(
|
||||||
"tasks"
|
"tasks"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -64,7 +64,7 @@ class BaseTaskConfig(BaseConfig):
|
|||||||
ignore_start_end: float = 0.01
|
ignore_start_end: float = 0.01
|
||||||
|
|
||||||
|
|
||||||
class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
class BaseTask(EvaluationTaskProtocol, Generic[T_Output]):
|
||||||
targets: TargetProtocol
|
targets: TargetProtocol
|
||||||
|
|
||||||
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
|
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
|
||||||
|
|||||||
@ -4,12 +4,18 @@ from typing import Generic, Iterable, Protocol, Sequence, TypeVar
|
|||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
from batdetect2.outputs.types import OutputTransformProtocol
|
||||||
|
from batdetect2.postprocess.types import (
|
||||||
|
ClipDetections,
|
||||||
|
ClipDetectionsTensor,
|
||||||
|
Detection,
|
||||||
|
)
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AffinityFunction",
|
"AffinityFunction",
|
||||||
"ClipMatches",
|
"ClipMatches",
|
||||||
|
"EvaluationTaskProtocol",
|
||||||
"EvaluatorProtocol",
|
"EvaluatorProtocol",
|
||||||
"MatchEvaluation",
|
"MatchEvaluation",
|
||||||
"MatcherProtocol",
|
"MatcherProtocol",
|
||||||
@ -94,7 +100,7 @@ class PlotterProtocol(Protocol):
|
|||||||
EvaluationOutput = TypeVar("EvaluationOutput")
|
EvaluationOutput = TypeVar("EvaluationOutput")
|
||||||
|
|
||||||
|
|
||||||
class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
|
class EvaluationTaskProtocol(Protocol, Generic[EvaluationOutput]):
|
||||||
targets: TargetProtocol
|
targets: TargetProtocol
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
@ -112,3 +118,30 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
|
|||||||
self,
|
self,
|
||||||
eval_outputs: EvaluationOutput,
|
eval_outputs: EvaluationOutput,
|
||||||
) -> Iterable[tuple[str, Figure]]: ...
|
) -> Iterable[tuple[str, Figure]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
|
||||||
|
targets: TargetProtocol
|
||||||
|
transform: OutputTransformProtocol
|
||||||
|
|
||||||
|
def to_clip_detections_batch(
|
||||||
|
self,
|
||||||
|
clip_detections: Sequence[ClipDetectionsTensor],
|
||||||
|
clips: Sequence[data.Clip],
|
||||||
|
) -> list[ClipDetections]: ...
|
||||||
|
|
||||||
|
def evaluate(
|
||||||
|
self,
|
||||||
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
|
predictions: Sequence[ClipDetections],
|
||||||
|
) -> EvaluationOutput: ...
|
||||||
|
|
||||||
|
def compute_metrics(
|
||||||
|
self,
|
||||||
|
eval_outputs: EvaluationOutput,
|
||||||
|
) -> dict[str, float]: ...
|
||||||
|
|
||||||
|
def generate_plots(
|
||||||
|
self,
|
||||||
|
eval_outputs: EvaluationOutput,
|
||||||
|
) -> Iterable[tuple[str, Figure]]: ...
|
||||||
|
|||||||
@ -123,18 +123,20 @@ class OutputTransform(OutputTransformProtocol):
|
|||||||
|
|
||||||
out.append(transformed)
|
out.append(transformed)
|
||||||
|
|
||||||
return []
|
return out
|
||||||
|
|
||||||
def transform_detection(
|
def transform_detection(
|
||||||
self,
|
self,
|
||||||
detection: Detection,
|
detection: Detection,
|
||||||
) -> Detection | None:
|
) -> Detection | None:
|
||||||
for transform in self.detection_transform_steps:
|
for transform in self.detection_transform_steps:
|
||||||
detection = transform(detection) # type: ignore
|
transformed = transform(detection)
|
||||||
|
|
||||||
if detection is None:
|
if transformed is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
detection = transformed
|
||||||
|
|
||||||
return detection
|
return detection
|
||||||
|
|
||||||
def transform_clip_detections(
|
def transform_clip_detections(
|
||||||
|
|||||||
@ -18,7 +18,8 @@ from batdetect2.targets.types import TargetProtocol
|
|||||||
def _mock_clip_detections_tensor() -> ClipDetectionsTensor:
|
def _mock_clip_detections_tensor() -> ClipDetectionsTensor:
|
||||||
return ClipDetectionsTensor(
|
return ClipDetectionsTensor(
|
||||||
scores=torch.tensor([0.9], dtype=torch.float32),
|
scores=torch.tensor([0.9], dtype=torch.float32),
|
||||||
sizes=torch.tensor([[0.1, 1_000.0]], dtype=torch.float32),
|
# NOTE: Time is scaled by 1000
|
||||||
|
sizes=torch.tensor([[100, 1_000.0]], dtype=torch.float32),
|
||||||
class_scores=torch.tensor([[0.8, 0.2]], dtype=torch.float32),
|
class_scores=torch.tensor([[0.8, 0.2]], dtype=torch.float32),
|
||||||
times=torch.tensor([0.2], dtype=torch.float32),
|
times=torch.tensor([0.2], dtype=torch.float32),
|
||||||
frequencies=torch.tensor([60_000.0], dtype=torch.float32),
|
frequencies=torch.tensor([60_000.0], dtype=torch.float32),
|
||||||
@ -26,29 +27,15 @@ def _mock_clip_detections_tensor() -> ClipDetectionsTensor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_shift_time_to_clip_start(
|
def test_shift_time_to_clip_start(sample_targets: TargetProtocol):
|
||||||
clip: data.Clip,
|
raw = _mock_clip_detections_tensor()
|
||||||
sample_targets: TargetProtocol,
|
transform = build_output_transform(targets=sample_targets)
|
||||||
):
|
|
||||||
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
|
|
||||||
|
|
||||||
detection = Detection(
|
transformed = transform.to_detections(raw, start_time=2.5)
|
||||||
geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]),
|
start_time, _, end_time, _ = compute_bounds(transformed[0].geometry)
|
||||||
detection_score=0.9,
|
|
||||||
class_scores=np.array([0.9]),
|
|
||||||
features=np.array([1.0, 2.0]),
|
|
||||||
)
|
|
||||||
|
|
||||||
transformed = OutputTransform(targets=sample_targets)(
|
assert np.isclose(start_time, 2.7)
|
||||||
[ClipDetections(clip=clip, detections=[detection])]
|
assert np.isclose(end_time, 2.8)
|
||||||
)[0]
|
|
||||||
|
|
||||||
start_time, _, end_time, _ = compute_bounds(
|
|
||||||
transformed.detections[0].geometry
|
|
||||||
)
|
|
||||||
|
|
||||||
assert np.isclose(start_time, 2.6)
|
|
||||||
assert np.isclose(end_time, 2.7)
|
|
||||||
|
|
||||||
|
|
||||||
def test_to_clip_detections_shifts_by_clip_start(
|
def test_to_clip_detections_shifts_by_clip_start(
|
||||||
@ -58,14 +45,10 @@ def test_to_clip_detections_shifts_by_clip_start(
|
|||||||
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
|
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
|
||||||
transform = build_output_transform(targets=sample_targets)
|
transform = build_output_transform(targets=sample_targets)
|
||||||
raw = _mock_clip_detections_tensor()
|
raw = _mock_clip_detections_tensor()
|
||||||
|
|
||||||
shifted = transform.to_clip_detections(detections=raw, clip=clip)
|
shifted = transform.to_clip_detections(detections=raw, clip=clip)
|
||||||
unshifted = transform.to_detections(detections=raw, start_time=0)
|
start_time, _, end_time, _ = compute_bounds(shifted.detections[0].geometry)
|
||||||
|
assert np.isclose(start_time, 2.7)
|
||||||
shifted_start, _, _, _ = compute_bounds(shifted.detections[0].geometry)
|
assert np.isclose(end_time, 2.8)
|
||||||
unshifted_start, _, _, _ = compute_bounds(unshifted[0].geometry)
|
|
||||||
|
|
||||||
assert np.isclose(shifted_start - unshifted_start, clip.start_time)
|
|
||||||
|
|
||||||
|
|
||||||
def test_detection_and_clip_transforms_applied_in_order(
|
def test_detection_and_clip_transforms_applied_in_order(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user