mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Separate targets and ROI
This commit is contained in:
parent
716b3a3778
commit
2b235e28bb
@ -32,5 +32,6 @@ classification_targets:
|
|||||||
value: Rhinolophus ferrumequinum
|
value: Rhinolophus ferrumequinum
|
||||||
|
|
||||||
roi:
|
roi:
|
||||||
name: anchor_bbox
|
default:
|
||||||
anchor: top-left
|
name: anchor_bbox
|
||||||
|
anchor: top-left
|
||||||
|
|||||||
2
justfile
2
justfile
@ -20,7 +20,7 @@ install:
|
|||||||
# Testing & Coverage
|
# Testing & Coverage
|
||||||
# Run tests using pytest.
|
# Run tests using pytest.
|
||||||
test:
|
test:
|
||||||
uv run pytest {{TESTS_DIR}}
|
uv run pytest -n auto {{TESTS_DIR}}
|
||||||
|
|
||||||
# Run tests and generate coverage data.
|
# Run tests and generate coverage data.
|
||||||
coverage:
|
coverage:
|
||||||
|
|||||||
@ -88,6 +88,7 @@ dev = [
|
|||||||
"pandas-stubs>=2.2.2.240807",
|
"pandas-stubs>=2.2.2.240807",
|
||||||
"python-lsp-server>=1.13.0",
|
"python-lsp-server>=1.13.0",
|
||||||
"deepdiff>=8.6.1",
|
"deepdiff>=8.6.1",
|
||||||
|
"pytest-xdist[psutil]>=3.8.0",
|
||||||
]
|
]
|
||||||
dvclive = ["dvclive>=3.48.2"]
|
dvclive = ["dvclive>=3.48.2"]
|
||||||
mlflow = ["mlflow>=3.1.1"]
|
mlflow = ["mlflow>=3.1.1"]
|
||||||
|
|||||||
@ -50,7 +50,13 @@ from batdetect2.postprocess import (
|
|||||||
build_postprocessor,
|
build_postprocessor,
|
||||||
)
|
)
|
||||||
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||||
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
|
from batdetect2.targets import (
|
||||||
|
ROIMapperProtocol,
|
||||||
|
TargetConfig,
|
||||||
|
TargetProtocol,
|
||||||
|
build_roi_mapping,
|
||||||
|
build_targets,
|
||||||
|
)
|
||||||
from batdetect2.train import (
|
from batdetect2.train import (
|
||||||
DEFAULT_CHECKPOINT_DIR,
|
DEFAULT_CHECKPOINT_DIR,
|
||||||
TrainingConfig,
|
TrainingConfig,
|
||||||
@ -70,6 +76,7 @@ class BatDetect2API:
|
|||||||
outputs_config: OutputsConfig,
|
outputs_config: OutputsConfig,
|
||||||
logging_config: AppLoggingConfig,
|
logging_config: AppLoggingConfig,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
|
roi_mapper: ROIMapperProtocol,
|
||||||
audio_loader: AudioLoader,
|
audio_loader: AudioLoader,
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
postprocessor: PostprocessorProtocol,
|
postprocessor: PostprocessorProtocol,
|
||||||
@ -86,6 +93,7 @@ class BatDetect2API:
|
|||||||
self.outputs_config = outputs_config
|
self.outputs_config = outputs_config
|
||||||
self.logging_config = logging_config
|
self.logging_config = logging_config
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
|
self.roi_mapper = roi_mapper
|
||||||
self.audio_loader = audio_loader
|
self.audio_loader = audio_loader
|
||||||
self.preprocessor = preprocessor
|
self.preprocessor = preprocessor
|
||||||
self.postprocessor = postprocessor
|
self.postprocessor = postprocessor
|
||||||
@ -125,6 +133,7 @@ class BatDetect2API:
|
|||||||
val_annotations=val_annotations,
|
val_annotations=val_annotations,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
targets=self.targets,
|
targets=self.targets,
|
||||||
|
roi_mapper=self.roi_mapper,
|
||||||
model_config=model_config or self.model_config,
|
model_config=model_config or self.model_config,
|
||||||
audio_loader=self.audio_loader,
|
audio_loader=self.audio_loader,
|
||||||
preprocessor=self.preprocessor,
|
preprocessor=self.preprocessor,
|
||||||
@ -171,6 +180,7 @@ class BatDetect2API:
|
|||||||
val_annotations=val_annotations,
|
val_annotations=val_annotations,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
targets=self.targets,
|
targets=self.targets,
|
||||||
|
roi_mapper=self.roi_mapper,
|
||||||
model_config=model_config or self.model_config,
|
model_config=model_config or self.model_config,
|
||||||
preprocessor=self.preprocessor,
|
preprocessor=self.preprocessor,
|
||||||
audio_loader=self.audio_loader,
|
audio_loader=self.audio_loader,
|
||||||
@ -205,6 +215,7 @@ class BatDetect2API:
|
|||||||
self.model,
|
self.model,
|
||||||
test_annotations,
|
test_annotations,
|
||||||
targets=self.targets,
|
targets=self.targets,
|
||||||
|
roi_mapper=self.roi_mapper,
|
||||||
audio_loader=self.audio_loader,
|
audio_loader=self.audio_loader,
|
||||||
preprocessor=self.preprocessor,
|
preprocessor=self.preprocessor,
|
||||||
audio_config=audio_config or self.audio_config,
|
audio_config=audio_config or self.audio_config,
|
||||||
@ -391,6 +402,7 @@ class BatDetect2API:
|
|||||||
self.model,
|
self.model,
|
||||||
audio_files,
|
audio_files,
|
||||||
targets=self.targets,
|
targets=self.targets,
|
||||||
|
roi_mapper=self.roi_mapper,
|
||||||
audio_loader=self.audio_loader,
|
audio_loader=self.audio_loader,
|
||||||
preprocessor=self.preprocessor,
|
preprocessor=self.preprocessor,
|
||||||
output_transform=self.output_transform,
|
output_transform=self.output_transform,
|
||||||
@ -416,6 +428,7 @@ class BatDetect2API:
|
|||||||
self.model,
|
self.model,
|
||||||
clips,
|
clips,
|
||||||
targets=self.targets,
|
targets=self.targets,
|
||||||
|
roi_mapper=self.roi_mapper,
|
||||||
audio_loader=self.audio_loader,
|
audio_loader=self.audio_loader,
|
||||||
preprocessor=self.preprocessor,
|
preprocessor=self.preprocessor,
|
||||||
output_transform=self.output_transform,
|
output_transform=self.output_transform,
|
||||||
@ -472,6 +485,7 @@ class BatDetect2API:
|
|||||||
config: BatDetect2Config,
|
config: BatDetect2Config,
|
||||||
) -> "BatDetect2API":
|
) -> "BatDetect2API":
|
||||||
targets = build_targets(config=config.model.targets)
|
targets = build_targets(config=config.model.targets)
|
||||||
|
roi_mapper = build_roi_mapping(config=config.model.targets.roi)
|
||||||
|
|
||||||
audio_loader = build_audio_loader(config=config.audio)
|
audio_loader = build_audio_loader(config=config.audio)
|
||||||
|
|
||||||
@ -492,11 +506,13 @@ class BatDetect2API:
|
|||||||
output_transform = build_output_transform(
|
output_transform = build_output_transform(
|
||||||
config=config.outputs.transform,
|
config=config.outputs.transform,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = build_evaluator(
|
evaluator = build_evaluator(
|
||||||
config=config.evaluation,
|
config=config.evaluation,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
transform=output_transform,
|
transform=output_transform,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -504,7 +520,8 @@ class BatDetect2API:
|
|||||||
# to avoid device mismatch errors
|
# to avoid device mismatch errors
|
||||||
model = build_model(
|
model = build_model(
|
||||||
config=config.model,
|
config=config.model,
|
||||||
targets=build_targets(config=config.model.targets),
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
preprocessor=build_preprocessor(
|
preprocessor=build_preprocessor(
|
||||||
input_samplerate=audio_loader.samplerate,
|
input_samplerate=audio_loader.samplerate,
|
||||||
config=config.model.preprocess,
|
config=config.model.preprocess,
|
||||||
@ -524,6 +541,7 @@ class BatDetect2API:
|
|||||||
outputs_config=config.outputs,
|
outputs_config=config.outputs,
|
||||||
logging_config=config.logging,
|
logging_config=config.logging,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
@ -561,15 +579,18 @@ class BatDetect2API:
|
|||||||
and targets_config != model_config.targets
|
and targets_config != model_config.targets
|
||||||
):
|
):
|
||||||
targets = build_targets(config=targets_config)
|
targets = build_targets(config=targets_config)
|
||||||
|
roi_mapper = build_roi_mapping(config=targets_config.roi)
|
||||||
model = build_model_with_new_targets(
|
model = build_model_with_new_targets(
|
||||||
model=model,
|
model=model,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
)
|
)
|
||||||
model_config = model_config.model_copy(
|
model_config = model_config.model_copy(
|
||||||
update={"targets": targets_config}
|
update={"targets": targets_config}
|
||||||
)
|
)
|
||||||
|
|
||||||
targets = build_targets(config=model_config.targets)
|
targets = build_targets(config=model_config.targets)
|
||||||
|
roi_mapper = build_roi_mapping(config=model_config.targets.roi)
|
||||||
|
|
||||||
audio_loader = build_audio_loader(config=audio_config)
|
audio_loader = build_audio_loader(config=audio_config)
|
||||||
|
|
||||||
@ -591,11 +612,13 @@ class BatDetect2API:
|
|||||||
output_transform = build_output_transform(
|
output_transform = build_output_transform(
|
||||||
config=outputs_config.transform,
|
config=outputs_config.transform,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = build_evaluator(
|
evaluator = build_evaluator(
|
||||||
config=evaluation_config,
|
config=evaluation_config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
transform=output_transform,
|
transform=output_transform,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -608,6 +631,7 @@ class BatDetect2API:
|
|||||||
outputs_config=outputs_config,
|
outputs_config=outputs_config,
|
||||||
logging_config=logging_config,
|
logging_config=logging_config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
|
|||||||
@ -16,7 +16,7 @@ from batdetect2.outputs import OutputsConfig, build_output_transform
|
|||||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||||
|
|
||||||
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
||||||
|
|
||||||
@ -25,6 +25,7 @@ def run_evaluate(
|
|||||||
model: Model,
|
model: Model,
|
||||||
test_annotations: Sequence[data.ClipAnnotation],
|
test_annotations: Sequence[data.ClipAnnotation],
|
||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
|
roi_mapper: ROIMapperProtocol | None = None,
|
||||||
audio_loader: AudioLoader | None = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
@ -46,6 +47,7 @@ def run_evaluate(
|
|||||||
|
|
||||||
preprocessor = preprocessor or model.preprocessor
|
preprocessor = preprocessor or model.preprocessor
|
||||||
targets = targets or model.targets
|
targets = targets or model.targets
|
||||||
|
roi_mapper = roi_mapper or model.roi_mapper
|
||||||
|
|
||||||
loader = build_test_loader(
|
loader = build_test_loader(
|
||||||
test_annotations,
|
test_annotations,
|
||||||
@ -57,6 +59,7 @@ def run_evaluate(
|
|||||||
output_transform = build_output_transform(
|
output_transform = build_output_transform(
|
||||||
config=output_config.transform,
|
config=output_config.transform,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
)
|
)
|
||||||
evaluator = build_evaluator(
|
evaluator = build_evaluator(
|
||||||
config=evaluation_config,
|
config=evaluation_config,
|
||||||
|
|||||||
@ -8,8 +8,8 @@ from batdetect2.evaluate.tasks import build_task
|
|||||||
from batdetect2.evaluate.types import EvaluationTaskProtocol, EvaluatorProtocol
|
from batdetect2.evaluate.types import EvaluationTaskProtocol, EvaluatorProtocol
|
||||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
from batdetect2.postprocess.types import ClipDetections, ClipDetectionsTensor
|
from batdetect2.postprocess.types import ClipDetections, ClipDetectionsTensor
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_roi_mapping, build_targets
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Evaluator",
|
"Evaluator",
|
||||||
@ -67,17 +67,23 @@ class Evaluator:
|
|||||||
def build_evaluator(
|
def build_evaluator(
|
||||||
config: EvaluationConfig | dict | None = None,
|
config: EvaluationConfig | dict | None = None,
|
||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
|
roi_mapper: ROIMapperProtocol | None = None,
|
||||||
transform: OutputTransformProtocol | None = None,
|
transform: OutputTransformProtocol | None = None,
|
||||||
) -> EvaluatorProtocol:
|
) -> EvaluatorProtocol:
|
||||||
targets = targets or build_targets()
|
targets = targets or build_targets()
|
||||||
|
|
||||||
|
roi_mapper = roi_mapper or build_roi_mapping()
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
config = EvaluationConfig()
|
config = EvaluationConfig()
|
||||||
|
|
||||||
if not isinstance(config, EvaluationConfig):
|
if not isinstance(config, EvaluationConfig):
|
||||||
config = EvaluationConfig.model_validate(config)
|
config = EvaluationConfig.model_validate(config)
|
||||||
|
|
||||||
transform = transform or build_output_transform(targets=targets)
|
transform = transform or build_output_transform(
|
||||||
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
|
)
|
||||||
|
|
||||||
return Evaluator(
|
return Evaluator(
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
|||||||
@ -18,13 +18,14 @@ from batdetect2.outputs import (
|
|||||||
)
|
)
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
def run_batch_inference(
|
def run_batch_inference(
|
||||||
model: Model,
|
model: Model,
|
||||||
clips: Sequence[data.Clip],
|
clips: Sequence[data.Clip],
|
||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
|
roi_mapper: ROIMapperProtocol | None = None,
|
||||||
audio_loader: AudioLoader | None = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
@ -45,10 +46,12 @@ def run_batch_inference(
|
|||||||
|
|
||||||
preprocessor = preprocessor or model.preprocessor
|
preprocessor = preprocessor or model.preprocessor
|
||||||
targets = targets or model.targets
|
targets = targets or model.targets
|
||||||
|
roi_mapper = roi_mapper or model.roi_mapper
|
||||||
|
|
||||||
output_transform = output_transform or build_output_transform(
|
output_transform = output_transform or build_output_transform(
|
||||||
config=output_config.transform,
|
config=output_config.transform,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
loader = build_inference_loader(
|
loader = build_inference_loader(
|
||||||
@ -78,6 +81,7 @@ def process_file_list(
|
|||||||
model: Model,
|
model: Model,
|
||||||
paths: Sequence[data.PathLike],
|
paths: Sequence[data.PathLike],
|
||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
|
roi_mapper: ROIMapperProtocol | None = None,
|
||||||
audio_loader: AudioLoader | None = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
@ -101,6 +105,7 @@ def process_file_list(
|
|||||||
model,
|
model,
|
||||||
clips,
|
clips,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
|||||||
@ -20,7 +20,8 @@ class InferenceModule(LightningModule):
|
|||||||
self.model = model
|
self.model = model
|
||||||
self.detection_threshold = detection_threshold
|
self.detection_threshold = detection_threshold
|
||||||
self.output_transform = output_transform or build_output_transform(
|
self.output_transform = output_transform or build_output_transform(
|
||||||
targets=model.targets
|
targets=model.targets,
|
||||||
|
roi_mapper=model.roi_mapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
def predict_step(
|
def predict_step(
|
||||||
|
|||||||
@ -74,7 +74,7 @@ from batdetect2.postprocess.types import (
|
|||||||
from batdetect2.preprocess.config import PreprocessingConfig
|
from batdetect2.preprocess.config import PreprocessingConfig
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.targets.config import TargetConfig
|
from batdetect2.targets.config import TargetConfig
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BBoxHead",
|
"BBoxHead",
|
||||||
@ -186,12 +186,15 @@ class Model(torch.nn.Module):
|
|||||||
targets : TargetProtocol
|
targets : TargetProtocol
|
||||||
Describes the set of target classes; used when building heads and
|
Describes the set of target classes; used when building heads and
|
||||||
during training target construction.
|
during training target construction.
|
||||||
|
roi_mapper : ROIMapperProtocol
|
||||||
|
Maps geometries to target-size channels and back.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
detector: DetectionModel
|
detector: DetectionModel
|
||||||
preprocessor: PreprocessorProtocol
|
preprocessor: PreprocessorProtocol
|
||||||
postprocessor: PostprocessorProtocol
|
postprocessor: PostprocessorProtocol
|
||||||
targets: TargetProtocol
|
targets: TargetProtocol
|
||||||
|
roi_mapper: ROIMapperProtocol
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -199,12 +202,14 @@ class Model(torch.nn.Module):
|
|||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
postprocessor: PostprocessorProtocol,
|
postprocessor: PostprocessorProtocol,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
|
roi_mapper: ROIMapperProtocol,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.detector = detector
|
self.detector = detector
|
||||||
self.preprocessor = preprocessor
|
self.preprocessor = preprocessor
|
||||||
self.postprocessor = postprocessor
|
self.postprocessor = postprocessor
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
|
self.roi_mapper = roi_mapper
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
|
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
|
||||||
"""Run the full detection pipeline on a waveform tensor.
|
"""Run the full detection pipeline on a waveform tensor.
|
||||||
@ -234,6 +239,7 @@ class Model(torch.nn.Module):
|
|||||||
def build_model(
|
def build_model(
|
||||||
config: ModelConfig | None = None,
|
config: ModelConfig | None = None,
|
||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
|
roi_mapper: ROIMapperProtocol | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
postprocessor: PostprocessorProtocol | None = None,
|
postprocessor: PostprocessorProtocol | None = None,
|
||||||
) -> Model:
|
) -> Model:
|
||||||
@ -272,10 +278,19 @@ def build_model(
|
|||||||
"""
|
"""
|
||||||
from batdetect2.postprocess import build_postprocessor
|
from batdetect2.postprocess import build_postprocessor
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_roi_mapping, build_targets
|
||||||
|
|
||||||
config = config or ModelConfig()
|
config = config or ModelConfig()
|
||||||
targets = targets or build_targets(config=config.targets)
|
targets = targets or build_targets(config=config.targets)
|
||||||
|
|
||||||
|
targets_config = getattr(targets, "config", None)
|
||||||
|
roi_config = (
|
||||||
|
targets_config.roi
|
||||||
|
if isinstance(targets_config, TargetConfig)
|
||||||
|
else config.targets.roi
|
||||||
|
)
|
||||||
|
|
||||||
|
roi_mapper = roi_mapper or build_roi_mapping(config=roi_config)
|
||||||
preprocessor = preprocessor or build_preprocessor(
|
preprocessor = preprocessor or build_preprocessor(
|
||||||
config=config.preprocess,
|
config=config.preprocess,
|
||||||
input_samplerate=config.samplerate,
|
input_samplerate=config.samplerate,
|
||||||
@ -286,6 +301,7 @@ def build_model(
|
|||||||
)
|
)
|
||||||
detector = build_detector(
|
detector = build_detector(
|
||||||
num_classes=len(targets.class_names),
|
num_classes=len(targets.class_names),
|
||||||
|
num_sizes=len(roi_mapper.dimension_names),
|
||||||
config=config.architecture,
|
config=config.architecture,
|
||||||
)
|
)
|
||||||
return Model(
|
return Model(
|
||||||
@ -293,16 +309,19 @@ def build_model(
|
|||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_model_with_new_targets(
|
def build_model_with_new_targets(
|
||||||
model: Model,
|
model: Model,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
|
roi_mapper: ROIMapperProtocol,
|
||||||
) -> Model:
|
) -> Model:
|
||||||
"""Build a new model with a different target set."""
|
"""Build a new model with a different target set."""
|
||||||
detector = build_detector(
|
detector = build_detector(
|
||||||
num_classes=len(targets.class_names),
|
num_classes=len(targets.class_names),
|
||||||
|
num_sizes=len(roi_mapper.dimension_names),
|
||||||
backbone=model.detector.backbone,
|
backbone=model.detector.backbone,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -311,4 +330,5 @@ def build_model_with_new_targets(
|
|||||||
postprocessor=model.postprocessor,
|
postprocessor=model.postprocessor,
|
||||||
preprocessor=model.preprocessor,
|
preprocessor=model.preprocessor,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -136,6 +136,7 @@ class Detector(DetectionModel):
|
|||||||
|
|
||||||
def build_detector(
|
def build_detector(
|
||||||
num_classes: int,
|
num_classes: int,
|
||||||
|
num_sizes: int = 2,
|
||||||
config: BackboneConfig | None = None,
|
config: BackboneConfig | None = None,
|
||||||
backbone: BackboneModel | None = None,
|
backbone: BackboneModel | None = None,
|
||||||
) -> DetectionModel:
|
) -> DetectionModel:
|
||||||
@ -181,6 +182,7 @@ def build_detector(
|
|||||||
)
|
)
|
||||||
bbox_head = BBoxHead(
|
bbox_head = BBoxHead(
|
||||||
in_channels=backbone.out_channels,
|
in_channels=backbone.out_channels,
|
||||||
|
num_sizes=num_sizes,
|
||||||
)
|
)
|
||||||
return Detector(
|
return Detector(
|
||||||
backbone=backbone,
|
backbone=backbone,
|
||||||
|
|||||||
@ -165,14 +165,15 @@ class BBoxHead(nn.Module):
|
|||||||
1×1 convolution with 2 output channels (duration, bandwidth).
|
1×1 convolution with 2 output channels (duration, bandwidth).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels: int):
|
def __init__(self, in_channels: int, num_sizes: int = 2):
|
||||||
"""Initialise the BBoxHead."""
|
"""Initialise the BBoxHead."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
self.num_sizes = num_sizes
|
||||||
|
|
||||||
self.bbox = nn.Conv2d(
|
self.bbox = nn.Conv2d(
|
||||||
in_channels=self.in_channels,
|
in_channels=self.in_channels,
|
||||||
out_channels=2,
|
out_channels=self.num_sizes,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
padding=0,
|
padding=0,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -28,7 +28,7 @@ from batdetect2.postprocess.types import (
|
|||||||
ClipDetectionsTensor,
|
ClipDetectionsTensor,
|
||||||
Detection,
|
Detection,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ClipDetectionsTransformConfig",
|
"ClipDetectionsTransformConfig",
|
||||||
@ -55,10 +55,12 @@ class OutputTransform(OutputTransformProtocol):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
|
roi_mapper: ROIMapperProtocol,
|
||||||
detection_transform_steps: Sequence[DetectionTransform] = (),
|
detection_transform_steps: Sequence[DetectionTransform] = (),
|
||||||
clip_transform_steps: Sequence[ClipDetectionsTransform] = (),
|
clip_transform_steps: Sequence[ClipDetectionsTransform] = (),
|
||||||
):
|
):
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
|
self.roi_mapper = roi_mapper
|
||||||
self.detection_transform_steps = list(detection_transform_steps)
|
self.detection_transform_steps = list(detection_transform_steps)
|
||||||
self.clip_transform_steps = list(clip_transform_steps)
|
self.clip_transform_steps = list(clip_transform_steps)
|
||||||
|
|
||||||
@ -89,7 +91,11 @@ class OutputTransform(OutputTransformProtocol):
|
|||||||
detections: ClipDetectionsTensor,
|
detections: ClipDetectionsTensor,
|
||||||
start_time: float = 0,
|
start_time: float = 0,
|
||||||
) -> list[Detection]:
|
) -> list[Detection]:
|
||||||
decoded = to_detections(detections.numpy(), targets=self.targets)
|
decoded = to_detections(
|
||||||
|
detections.numpy(),
|
||||||
|
targets=self.targets,
|
||||||
|
roi_mapper=self.roi_mapper,
|
||||||
|
)
|
||||||
shifted = shift_detections_to_start_time(
|
shifted = shift_detections_to_start_time(
|
||||||
decoded,
|
decoded,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
@ -151,8 +157,9 @@ class OutputTransform(OutputTransformProtocol):
|
|||||||
def build_output_transform(
|
def build_output_transform(
|
||||||
config: OutputTransformConfig | dict | None = None,
|
config: OutputTransformConfig | dict | None = None,
|
||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
|
roi_mapper: ROIMapperProtocol | None = None,
|
||||||
) -> OutputTransformProtocol:
|
) -> OutputTransformProtocol:
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_roi_mapping, build_targets
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
config = OutputTransformConfig()
|
config = OutputTransformConfig()
|
||||||
@ -161,9 +168,11 @@ def build_output_transform(
|
|||||||
config = OutputTransformConfig.model_validate(config)
|
config = OutputTransformConfig.model_validate(config)
|
||||||
|
|
||||||
targets = targets or build_targets()
|
targets = targets or build_targets()
|
||||||
|
roi_mapper = roi_mapper or build_roi_mapping()
|
||||||
|
|
||||||
return OutputTransform(
|
return OutputTransform(
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
detection_transform_steps=[
|
detection_transform_steps=[
|
||||||
detection_transform_registry.build(transform_config)
|
detection_transform_registry.build(transform_config)
|
||||||
for transform_config in config.detection_transforms
|
for transform_config in config.detection_transforms
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.postprocess.types import ClipDetectionsArray, Detection
|
from batdetect2.postprocess.types import ClipDetectionsArray, Detection
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||||
@ -25,6 +25,7 @@ DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
|
|||||||
def to_detections(
|
def to_detections(
|
||||||
detections: ClipDetectionsArray,
|
detections: ClipDetectionsArray,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
|
roi_mapper: ROIMapperProtocol,
|
||||||
) -> List[Detection]:
|
) -> List[Detection]:
|
||||||
predictions = []
|
predictions = []
|
||||||
|
|
||||||
@ -39,7 +40,7 @@ def to_detections(
|
|||||||
):
|
):
|
||||||
highest_scoring_class = targets.class_names[class_scores.argmax()]
|
highest_scoring_class = targets.class_names[class_scores.argmax()]
|
||||||
|
|
||||||
geom = targets.decode_roi(
|
geom = roi_mapper.decode(
|
||||||
(time, freq),
|
(time, freq),
|
||||||
dims,
|
dims,
|
||||||
class_name=highest_scoring_class,
|
class_name=highest_scoring_class,
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from soundevent import data, plot
|
|||||||
from batdetect2.plotting.clips import plot_clip
|
from batdetect2.plotting.clips import plot_clip
|
||||||
from batdetect2.plotting.common import create_ax
|
from batdetect2.plotting.common import create_ax
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"plot_clip_annotation",
|
"plot_clip_annotation",
|
||||||
@ -48,6 +48,7 @@ def plot_clip_annotation(
|
|||||||
def plot_anchor_points(
|
def plot_anchor_points(
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
|
roi_mapper: ROIMapperProtocol,
|
||||||
figsize: tuple[int, int] | None = None,
|
figsize: tuple[int, int] | None = None,
|
||||||
ax: Axes | None = None,
|
ax: Axes | None = None,
|
||||||
size: int = 1,
|
size: int = 1,
|
||||||
@ -63,7 +64,11 @@ def plot_anchor_points(
|
|||||||
if not targets.filter(sound_event):
|
if not targets.filter(sound_event):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
position, _ = targets.encode_roi(sound_event)
|
class_name = targets.encode_class(sound_event)
|
||||||
|
position, _ = roi_mapper.encode(
|
||||||
|
sound_event.sound_event,
|
||||||
|
class_name=class_name,
|
||||||
|
)
|
||||||
positions.append(position)
|
positions.append(position)
|
||||||
|
|
||||||
X, Y = zip(*positions, strict=False)
|
X, Y = zip(*positions, strict=False)
|
||||||
|
|||||||
@ -10,7 +10,10 @@ from batdetect2.targets.config import TargetConfig
|
|||||||
from batdetect2.targets.rois import (
|
from batdetect2.targets.rois import (
|
||||||
AnchorBBoxMapperConfig,
|
AnchorBBoxMapperConfig,
|
||||||
ROIMapperConfig,
|
ROIMapperConfig,
|
||||||
|
ROIMapperProtocol,
|
||||||
|
ROIMappingConfig,
|
||||||
build_roi_mapper,
|
build_roi_mapper,
|
||||||
|
build_roi_mapping,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.targets import (
|
from batdetect2.targets.targets import (
|
||||||
Targets,
|
Targets,
|
||||||
@ -30,12 +33,15 @@ from batdetect2.targets.types import (
|
|||||||
Size,
|
Size,
|
||||||
SoundEventDecoder,
|
SoundEventDecoder,
|
||||||
SoundEventEncoder,
|
SoundEventEncoder,
|
||||||
|
SoundEventFilter,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AnchorBBoxMapperConfig",
|
"AnchorBBoxMapperConfig",
|
||||||
"Position",
|
"Position",
|
||||||
|
"ROIMappingConfig",
|
||||||
|
"ROIMapperProtocol",
|
||||||
"ROIMapperConfig",
|
"ROIMapperConfig",
|
||||||
"ROITargetMapper",
|
"ROITargetMapper",
|
||||||
"Size",
|
"Size",
|
||||||
@ -46,6 +52,7 @@ __all__ = [
|
|||||||
"TargetConfig",
|
"TargetConfig",
|
||||||
"TargetProtocol",
|
"TargetProtocol",
|
||||||
"Targets",
|
"Targets",
|
||||||
|
"build_roi_mapping",
|
||||||
"build_roi_mapper",
|
"build_roi_mapper",
|
||||||
"build_sound_event_decoder",
|
"build_sound_event_decoder",
|
||||||
"build_sound_event_encoder",
|
"build_sound_event_encoder",
|
||||||
|
|||||||
@ -14,7 +14,6 @@ from batdetect2.data.conditions import (
|
|||||||
SoundEventConditionConfig,
|
SoundEventConditionConfig,
|
||||||
build_sound_event_condition,
|
build_sound_event_condition,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.rois import ROIMapperConfig
|
|
||||||
from batdetect2.targets.terms import call_type, generic_class
|
from batdetect2.targets.terms import call_type, generic_class
|
||||||
from batdetect2.targets.types import SoundEventDecoder, SoundEventEncoder
|
from batdetect2.targets.types import SoundEventDecoder, SoundEventEncoder
|
||||||
|
|
||||||
@ -39,8 +38,6 @@ class TargetClassConfig(BaseConfig):
|
|||||||
|
|
||||||
assign_tags: List[data.Tag] = Field(default_factory=list)
|
assign_tags: List[data.Tag] = Field(default_factory=list)
|
||||||
|
|
||||||
roi: ROIMapperConfig | None = None
|
|
||||||
|
|
||||||
_match_if: SoundEventConditionConfig = PrivateAttr()
|
_match_if: SoundEventConditionConfig = PrivateAttr()
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from batdetect2.targets.classes import (
|
|||||||
DEFAULT_DETECTION_CLASS,
|
DEFAULT_DETECTION_CLASS,
|
||||||
TargetClassConfig,
|
TargetClassConfig,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig
|
from batdetect2.targets.rois import ROIMappingConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TargetConfig",
|
"TargetConfig",
|
||||||
@ -25,7 +25,7 @@ class TargetConfig(BaseConfig):
|
|||||||
default_factory=lambda: DEFAULT_CLASSES
|
default_factory=lambda: DEFAULT_CLASSES
|
||||||
)
|
)
|
||||||
|
|
||||||
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
|
roi: ROIMappingConfig = Field(default_factory=ROIMappingConfig)
|
||||||
|
|
||||||
@field_validator("classification_targets")
|
@field_validator("classification_targets")
|
||||||
def check_unique_class_names(cls, v: List[TargetClassConfig]):
|
def check_unique_class_names(cls, v: List[TargetClassConfig]):
|
||||||
|
|||||||
@ -29,7 +29,12 @@ from batdetect2.core.arrays import spec_to_xarray
|
|||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.targets.types import Position, ROITargetMapper, Size
|
from batdetect2.targets.types import (
|
||||||
|
Position,
|
||||||
|
ROIMapperProtocol,
|
||||||
|
ROITargetMapper,
|
||||||
|
Size,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Anchor",
|
"Anchor",
|
||||||
@ -40,12 +45,15 @@ __all__ = [
|
|||||||
"DEFAULT_TIME_SCALE",
|
"DEFAULT_TIME_SCALE",
|
||||||
"PeakEnergyBBoxMapper",
|
"PeakEnergyBBoxMapper",
|
||||||
"PeakEnergyBBoxMapperConfig",
|
"PeakEnergyBBoxMapperConfig",
|
||||||
|
"ROIMappingConfig",
|
||||||
|
"ROIMapperProtocol",
|
||||||
"ROIMapperConfig",
|
"ROIMapperConfig",
|
||||||
"ROIMapperImportConfig",
|
"ROIMapperImportConfig",
|
||||||
"ROITargetMapper",
|
"ROITargetMapper",
|
||||||
"SIZE_HEIGHT",
|
"SIZE_HEIGHT",
|
||||||
"SIZE_ORDER",
|
"SIZE_ORDER",
|
||||||
"SIZE_WIDTH",
|
"SIZE_WIDTH",
|
||||||
|
"build_roi_mapping",
|
||||||
"build_roi_mapper",
|
"build_roi_mapper",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -456,6 +464,59 @@ implementations by using the `name` field as a discriminator.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ROIMappingConfig(BaseConfig):
|
||||||
|
"""Configuration for class-aware ROI mapping.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
default : ROIMapperConfig
|
||||||
|
Default mapper used when no class-specific override exists.
|
||||||
|
overrides : dict[str, ROIMapperConfig]
|
||||||
|
Optional class-specific mapper overrides by class name.
|
||||||
|
"""
|
||||||
|
|
||||||
|
default: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
|
||||||
|
overrides: dict[str, ROIMapperConfig] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassAwareROIMapper(ROIMapperProtocol):
|
||||||
|
"""Apply a default ROI mapper with optional per-class overrides."""
|
||||||
|
|
||||||
|
dimension_names: list[str]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
default_mapper: ROITargetMapper,
|
||||||
|
overrides: dict[str, ROITargetMapper] | None = None,
|
||||||
|
):
|
||||||
|
self.default_mapper = default_mapper
|
||||||
|
self.overrides = overrides or {}
|
||||||
|
self.dimension_names = list(default_mapper.dimension_names)
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self,
|
||||||
|
sound_event: data.SoundEvent,
|
||||||
|
class_name: str | None = None,
|
||||||
|
) -> tuple[Position, Size]:
|
||||||
|
mapper = self._select_mapper(class_name)
|
||||||
|
return mapper.encode(sound_event)
|
||||||
|
|
||||||
|
def decode(
|
||||||
|
self,
|
||||||
|
position: Position,
|
||||||
|
size: Size,
|
||||||
|
class_name: str | None = None,
|
||||||
|
) -> data.Geometry:
|
||||||
|
mapper = self._select_mapper(class_name)
|
||||||
|
return mapper.decode(position, size)
|
||||||
|
|
||||||
|
def _select_mapper(self, class_name: str | None = None) -> ROITargetMapper:
|
||||||
|
if class_name is not None and class_name in self.overrides:
|
||||||
|
return self.overrides[class_name]
|
||||||
|
|
||||||
|
return self.default_mapper
|
||||||
|
|
||||||
|
|
||||||
def build_roi_mapper(
|
def build_roi_mapper(
|
||||||
config: ROIMapperConfig | None = None,
|
config: ROIMapperConfig | None = None,
|
||||||
) -> ROITargetMapper:
|
) -> ROITargetMapper:
|
||||||
@ -480,6 +541,36 @@ def build_roi_mapper(
|
|||||||
return roi_mapper_registry.build(config)
|
return roi_mapper_registry.build(config)
|
||||||
|
|
||||||
|
|
||||||
|
def build_roi_mapping(
|
||||||
|
config: ROIMappingConfig | None = None,
|
||||||
|
) -> ROIMapperProtocol:
|
||||||
|
"""Build a class-aware ROI mapper and validate consistency."""
|
||||||
|
config = config or ROIMappingConfig()
|
||||||
|
|
||||||
|
default_mapper = build_roi_mapper(config.default)
|
||||||
|
overrides = {
|
||||||
|
class_name: build_roi_mapper(mapper_config)
|
||||||
|
for class_name, mapper_config in config.overrides.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
expected = list(default_mapper.dimension_names)
|
||||||
|
|
||||||
|
for class_name, mapper in overrides.items():
|
||||||
|
actual = list(mapper.dimension_names)
|
||||||
|
|
||||||
|
if actual != expected:
|
||||||
|
raise ValueError(
|
||||||
|
"All ROI mappers must share the same dimension order. "
|
||||||
|
f"Default dimensions: {expected}, "
|
||||||
|
f"class '{class_name}' dimensions: {actual}."
|
||||||
|
)
|
||||||
|
|
||||||
|
return ClassAwareROIMapper(
|
||||||
|
default_mapper=default_mapper,
|
||||||
|
overrides=overrides,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
VALID_ANCHORS = [
|
VALID_ANCHORS = [
|
||||||
"bottom-left",
|
"bottom-left",
|
||||||
"bottom-right",
|
"bottom-right",
|
||||||
|
|||||||
@ -12,21 +12,21 @@ from batdetect2.targets.classes import (
|
|||||||
get_class_names_from_config,
|
get_class_names_from_config,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.config import TargetConfig
|
from batdetect2.targets.config import TargetConfig
|
||||||
from batdetect2.targets.rois import (
|
from batdetect2.targets.types import (
|
||||||
AnchorBBoxMapperConfig,
|
Position,
|
||||||
build_roi_mapper,
|
ROIMapperProtocol,
|
||||||
|
Size,
|
||||||
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.types import Position, Size, TargetProtocol
|
|
||||||
|
|
||||||
|
|
||||||
class Targets(TargetProtocol):
|
class Targets(TargetProtocol):
|
||||||
"""Encapsulates the complete configured target definition pipeline.
|
"""Encapsulates the configured target class definition pipeline.
|
||||||
|
|
||||||
This class implements the `TargetProtocol`, holding the configured
|
This class implements the `TargetProtocol`, holding the configured
|
||||||
functions for filtering, transforming, encoding (tags to class name),
|
functions for filtering, encoding (tags to class name), and decoding
|
||||||
decoding (class name to tags), and mapping ROIs (geometry to position/size
|
(class name to tags). Geometry ROI mapping is handled separately by
|
||||||
and back). It provides a high-level interface to apply these steps and
|
``ROIMapperProtocol``.
|
||||||
access relevant metadata like class names and dimension names.
|
|
||||||
|
|
||||||
Instances are typically created using the `build_targets` factory function
|
Instances are typically created using the `build_targets` factory function
|
||||||
or the `load_targets` convenience loader.
|
or the `load_targets` convenience loader.
|
||||||
@ -39,14 +39,10 @@ class Targets(TargetProtocol):
|
|||||||
generic_class_tags
|
generic_class_tags
|
||||||
A list of `soundevent.data.Tag` objects representing the configured
|
A list of `soundevent.data.Tag` objects representing the configured
|
||||||
generic class category (used when no specific class matches).
|
generic class category (used when no specific class matches).
|
||||||
dimension_names
|
|
||||||
The names of the size dimensions handled by the ROI mapper
|
|
||||||
(e.g., ['width', 'height']).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class_names: list[str]
|
class_names: list[str]
|
||||||
detection_class_tags: list[data.Tag]
|
detection_class_tags: list[data.Tag]
|
||||||
dimension_names: list[str]
|
|
||||||
detection_class_name: str
|
detection_class_name: str
|
||||||
|
|
||||||
def __init__(self, config: TargetConfig):
|
def __init__(self, config: TargetConfig):
|
||||||
@ -63,10 +59,6 @@ class Targets(TargetProtocol):
|
|||||||
config.classification_targets
|
config.classification_targets
|
||||||
)
|
)
|
||||||
|
|
||||||
self._roi_mapper = build_roi_mapper(config.roi)
|
|
||||||
|
|
||||||
self.dimension_names = self._roi_mapper.dimension_names
|
|
||||||
|
|
||||||
self.class_names = get_class_names_from_config(
|
self.class_names = get_class_names_from_config(
|
||||||
config.classification_targets
|
config.classification_targets
|
||||||
)
|
)
|
||||||
@ -74,21 +66,6 @@ class Targets(TargetProtocol):
|
|||||||
self.detection_class_name = config.detection_target.name
|
self.detection_class_name = config.detection_target.name
|
||||||
self.detection_class_tags = config.detection_target.assign_tags
|
self.detection_class_tags = config.detection_target.assign_tags
|
||||||
|
|
||||||
self._roi_mapper_overrides = {
|
|
||||||
class_config.name: build_roi_mapper(class_config.roi)
|
|
||||||
for class_config in config.classification_targets
|
|
||||||
if class_config.roi is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
for class_name in self._roi_mapper_overrides:
|
|
||||||
if class_name not in self.class_names:
|
|
||||||
# TODO: improve this warning
|
|
||||||
logger.warning(
|
|
||||||
"The ROI mapper overrides contains a class ({class_name}) "
|
|
||||||
"not present in the class names.",
|
|
||||||
class_name=class_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
||||||
"""Apply the configured filter to a sound event annotation.
|
"""Apply the configured filter to a sound event annotation.
|
||||||
|
|
||||||
@ -147,75 +124,10 @@ class Targets(TargetProtocol):
|
|||||||
"""
|
"""
|
||||||
return self._decode_fn(class_label)
|
return self._decode_fn(class_label)
|
||||||
|
|
||||||
def encode_roi(
|
|
||||||
self, sound_event: data.SoundEventAnnotation
|
|
||||||
) -> tuple[Position, Size]:
|
|
||||||
"""Extract the target reference position from the annotation's roi.
|
|
||||||
|
|
||||||
Delegates to the internal ROI mapper's `get_roi_position` method.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event : data.SoundEventAnnotation
|
|
||||||
The annotation containing the geometry (ROI).
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
tuple[float, float]
|
|
||||||
The reference position `(time, frequency)`.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If the annotation lacks geometry.
|
|
||||||
"""
|
|
||||||
class_name = self.encode_class(sound_event)
|
|
||||||
|
|
||||||
if class_name in self._roi_mapper_overrides:
|
|
||||||
return self._roi_mapper_overrides[class_name].encode(
|
|
||||||
sound_event.sound_event
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._roi_mapper.encode(sound_event.sound_event)
|
|
||||||
|
|
||||||
def decode_roi(
|
|
||||||
self,
|
|
||||||
position: Position,
|
|
||||||
size: Size,
|
|
||||||
class_name: str | None = None,
|
|
||||||
) -> data.Geometry:
|
|
||||||
"""Recover an approximate geometric ROI from a position and dimensions.
|
|
||||||
|
|
||||||
Delegates to the internal ROI mapper's `recover_roi` method, which
|
|
||||||
un-scales the dimensions and reconstructs the geometry (typically a
|
|
||||||
`BoundingBox`).
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
pos
|
|
||||||
The reference position `(time, frequency)`.
|
|
||||||
dims
|
|
||||||
NumPy array with size dimensions (e.g., from model prediction),
|
|
||||||
matching the order in `self.dimension_names`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
data.Geometry
|
|
||||||
The reconstructed geometry (typically `BoundingBox`).
|
|
||||||
"""
|
|
||||||
if class_name in self._roi_mapper_overrides:
|
|
||||||
return self._roi_mapper_overrides[class_name].decode(
|
|
||||||
position,
|
|
||||||
size,
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._roi_mapper.decode(position, size)
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
||||||
classification_targets=DEFAULT_CLASSES,
|
classification_targets=DEFAULT_CLASSES,
|
||||||
detection_target=DEFAULT_DETECTION_CLASS,
|
detection_target=DEFAULT_DETECTION_CLASS,
|
||||||
roi=AnchorBBoxMapperConfig(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -292,6 +204,7 @@ def load_targets(
|
|||||||
def iterate_encoded_sound_events(
|
def iterate_encoded_sound_events(
|
||||||
sound_events: Iterable[data.SoundEventAnnotation],
|
sound_events: Iterable[data.SoundEventAnnotation],
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
|
roi_mapper: ROIMapperProtocol,
|
||||||
) -> Iterable[tuple[str | None, Position, Size]]:
|
) -> Iterable[tuple[str | None, Position, Size]]:
|
||||||
for sound_event in sound_events:
|
for sound_event in sound_events:
|
||||||
if not targets.filter(sound_event):
|
if not targets.filter(sound_event):
|
||||||
@ -303,6 +216,9 @@ def iterate_encoded_sound_events(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
class_name = targets.encode_class(sound_event)
|
class_name = targets.encode_class(sound_event)
|
||||||
position, size = targets.encode_roi(sound_event)
|
position, size = roi_mapper.encode(
|
||||||
|
sound_event.sound_event,
|
||||||
|
class_name=class_name,
|
||||||
|
)
|
||||||
|
|
||||||
yield class_name, position, size
|
yield class_name, position, size
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from soundevent import data
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Position",
|
"Position",
|
||||||
|
"ROIMapperProtocol",
|
||||||
"ROITargetMapper",
|
"ROITargetMapper",
|
||||||
"Size",
|
"Size",
|
||||||
"SoundEventDecoder",
|
"SoundEventDecoder",
|
||||||
@ -26,7 +27,6 @@ class TargetProtocol(Protocol):
|
|||||||
class_names: list[str]
|
class_names: list[str]
|
||||||
detection_class_tags: list[data.Tag]
|
detection_class_tags: list[data.Tag]
|
||||||
detection_class_name: str
|
detection_class_name: str
|
||||||
dimension_names: list[str]
|
|
||||||
|
|
||||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ...
|
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ...
|
||||||
|
|
||||||
@ -37,6 +37,23 @@ class TargetProtocol(Protocol):
|
|||||||
|
|
||||||
def decode_class(self, class_label: str) -> list[data.Tag]: ...
|
def decode_class(self, class_label: str) -> list[data.Tag]: ...
|
||||||
|
|
||||||
|
|
||||||
|
class ROIMapperProtocol(Protocol):
|
||||||
|
dimension_names: list[str]
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self,
|
||||||
|
sound_event: data.SoundEvent,
|
||||||
|
class_name: str | None = None,
|
||||||
|
) -> tuple[Position, Size]: ...
|
||||||
|
|
||||||
|
def decode(
|
||||||
|
self,
|
||||||
|
position: Position,
|
||||||
|
size: Size,
|
||||||
|
class_name: str | None = None,
|
||||||
|
) -> data.Geometry: ...
|
||||||
|
|
||||||
def encode_roi(
|
def encode_roi(
|
||||||
self,
|
self,
|
||||||
sound_event: data.SoundEventAnnotation,
|
sound_event: data.SoundEventAnnotation,
|
||||||
|
|||||||
@ -93,7 +93,8 @@ class ValidationMetrics(Callback):
|
|||||||
model = pl_module.model
|
model = pl_module.model
|
||||||
if self.output_transform is None:
|
if self.output_transform is None:
|
||||||
self.output_transform = build_output_transform(
|
self.output_transform = build_output_transform(
|
||||||
targets=model.targets
|
targets=model.targets,
|
||||||
|
roi_mapper=model.roi_mapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
output_transform = self.output_transform
|
output_transform = self.output_transform
|
||||||
|
|||||||
@ -40,7 +40,7 @@ def build_checkpoint_callback(
|
|||||||
if run_name is not None:
|
if run_name is not None:
|
||||||
checkpoint_dir = checkpoint_dir / run_name
|
checkpoint_dir = checkpoint_dir / run_name
|
||||||
|
|
||||||
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
return ModelCheckpoint(
|
return ModelCheckpoint(
|
||||||
dirpath=str(checkpoint_dir),
|
dirpath=str(checkpoint_dir),
|
||||||
|
|||||||
@ -14,8 +14,12 @@ from soundevent import data
|
|||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||||
from batdetect2.targets import build_targets, iterate_encoded_sound_events
|
from batdetect2.targets import (
|
||||||
from batdetect2.targets.types import TargetProtocol
|
build_roi_mapping,
|
||||||
|
build_targets,
|
||||||
|
iterate_encoded_sound_events,
|
||||||
|
)
|
||||||
|
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||||
from batdetect2.train.types import ClipLabeller, Heatmaps
|
from batdetect2.train.types import ClipLabeller, Heatmaps
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -42,6 +46,7 @@ class LabelConfig(BaseConfig):
|
|||||||
|
|
||||||
def build_clip_labeler(
|
def build_clip_labeler(
|
||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
|
roi_mapper: ROIMapperProtocol | None = None,
|
||||||
min_freq: float = MIN_FREQ,
|
min_freq: float = MIN_FREQ,
|
||||||
max_freq: float = MAX_FREQ,
|
max_freq: float = MAX_FREQ,
|
||||||
config: LabelConfig | None = None,
|
config: LabelConfig | None = None,
|
||||||
@ -53,12 +58,13 @@ def build_clip_labeler(
|
|||||||
lambda: config.to_yaml_string(),
|
lambda: config.to_yaml_string(),
|
||||||
)
|
)
|
||||||
|
|
||||||
if targets is None:
|
targets = targets or build_targets()
|
||||||
targets = build_targets()
|
roi_mapper = roi_mapper or build_roi_mapping()
|
||||||
|
|
||||||
return partial(
|
return partial(
|
||||||
generate_heatmaps,
|
generate_heatmaps,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
min_freq=min_freq,
|
min_freq=min_freq,
|
||||||
max_freq=max_freq,
|
max_freq=max_freq,
|
||||||
target_sigma=config.sigma,
|
target_sigma=config.sigma,
|
||||||
@ -73,6 +79,7 @@ def generate_heatmaps(
|
|||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
spec: torch.Tensor,
|
spec: torch.Tensor,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
|
roi_mapper: ROIMapperProtocol,
|
||||||
min_freq: float,
|
min_freq: float,
|
||||||
max_freq: float,
|
max_freq: float,
|
||||||
target_sigma: float = 3.0,
|
target_sigma: float = 3.0,
|
||||||
@ -89,7 +96,7 @@ def generate_heatmaps(
|
|||||||
height = spec.shape[-2]
|
height = spec.shape[-2]
|
||||||
width = spec.shape[-1]
|
width = spec.shape[-1]
|
||||||
num_classes = len(targets.class_names)
|
num_classes = len(targets.class_names)
|
||||||
num_dims = len(targets.dimension_names)
|
num_dims = len(roi_mapper.dimension_names)
|
||||||
clip = clip_annotation.clip
|
clip = clip_annotation.clip
|
||||||
|
|
||||||
# Initialize heatmaps
|
# Initialize heatmaps
|
||||||
@ -109,6 +116,7 @@ def generate_heatmaps(
|
|||||||
for class_name, (time, frequency), size in iterate_encoded_sound_events(
|
for class_name, (time, frequency), size in iterate_encoded_sound_events(
|
||||||
clip_annotation.sound_events,
|
clip_annotation.sound_events,
|
||||||
targets,
|
targets,
|
||||||
|
roi_mapper,
|
||||||
):
|
):
|
||||||
time_index = map_to_pixels(time, width, clip.start_time, clip.end_time)
|
time_index = map_to_pixels(time, width, clip.start_time, clip.end_time)
|
||||||
freq_index = map_to_pixels(frequency, height, min_freq, max_freq)
|
freq_index = map_to_pixels(frequency, height, min_freq, max_freq)
|
||||||
|
|||||||
@ -6,23 +6,24 @@ from lightning import Trainer, seed_everything
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
|
||||||
from batdetect2.audio.types import AudioLoader
|
from batdetect2.evaluate import EvaluatorProtocol, build_evaluator
|
||||||
from batdetect2.evaluate import build_evaluator
|
|
||||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
|
||||||
from batdetect2.logging import (
|
from batdetect2.logging import (
|
||||||
LoggerConfig,
|
LoggerConfig,
|
||||||
TensorBoardLoggerConfig,
|
TensorBoardLoggerConfig,
|
||||||
build_logger,
|
build_logger,
|
||||||
)
|
)
|
||||||
from batdetect2.models import Model, ModelConfig, build_model
|
from batdetect2.models import Model, ModelConfig, build_model
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.targets import (
|
||||||
from batdetect2.targets import build_targets
|
ROIMapperProtocol,
|
||||||
from batdetect2.targets.types import TargetProtocol
|
TargetProtocol,
|
||||||
from batdetect2.train import TrainingConfig
|
build_roi_mapping,
|
||||||
|
build_targets,
|
||||||
|
)
|
||||||
from batdetect2.train.callbacks import ValidationMetrics
|
from batdetect2.train.callbacks import ValidationMetrics
|
||||||
from batdetect2.train.checkpoints import build_checkpoint_callback
|
from batdetect2.train.checkpoints import build_checkpoint_callback
|
||||||
|
from batdetect2.train.config import TrainingConfig
|
||||||
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.train.lightning import build_training_module
|
from batdetect2.train.lightning import build_training_module
|
||||||
@ -39,6 +40,7 @@ def run_train(
|
|||||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||||
model: Model | None = None,
|
model: Model | None = None,
|
||||||
targets: Optional["TargetProtocol"] = None,
|
targets: Optional["TargetProtocol"] = None,
|
||||||
|
roi_mapper: Optional["ROIMapperProtocol"] = None,
|
||||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||||
audio_loader: Optional["AudioLoader"] = None,
|
audio_loader: Optional["AudioLoader"] = None,
|
||||||
labeller: Optional["ClipLabeller"] = None,
|
labeller: Optional["ClipLabeller"] = None,
|
||||||
@ -69,8 +71,15 @@ def run_train(
|
|||||||
if model is not None:
|
if model is not None:
|
||||||
targets = targets or model.targets
|
targets = targets or model.targets
|
||||||
|
|
||||||
|
if roi_mapper is None and targets is model.targets:
|
||||||
|
roi_mapper = model.roi_mapper
|
||||||
|
|
||||||
targets = targets or build_targets(config=model_config.targets)
|
targets = targets or build_targets(config=model_config.targets)
|
||||||
|
|
||||||
|
roi_mapper = roi_mapper or build_roi_mapping(
|
||||||
|
config=model_config.targets.roi
|
||||||
|
)
|
||||||
|
|
||||||
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
||||||
|
|
||||||
preprocessor = preprocessor or build_preprocessor(
|
preprocessor = preprocessor or build_preprocessor(
|
||||||
@ -80,6 +89,7 @@ def run_train(
|
|||||||
|
|
||||||
labeller = labeller or build_clip_labeler(
|
labeller = labeller or build_clip_labeler(
|
||||||
targets,
|
targets,
|
||||||
|
roi_mapper,
|
||||||
min_freq=preprocessor.min_freq,
|
min_freq=preprocessor.min_freq,
|
||||||
max_freq=preprocessor.max_freq,
|
max_freq=preprocessor.max_freq,
|
||||||
config=train_config.labels,
|
config=train_config.labels,
|
||||||
@ -119,6 +129,7 @@ def run_train(
|
|||||||
evaluator=build_evaluator(
|
evaluator=build_evaluator(
|
||||||
train_config.validation,
|
train_config.validation,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
),
|
),
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
num_epochs=num_epochs,
|
num_epochs=num_epochs,
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from hypothesis import given
|
from hypothesis import given, settings
|
||||||
from hypothesis import strategies as st
|
from hypothesis import strategies as st
|
||||||
|
|
||||||
from batdetect2.detector import parameters
|
from batdetect2.detector import parameters
|
||||||
@ -9,6 +9,7 @@ from batdetect2.utils import audio_utils, detector_utils
|
|||||||
|
|
||||||
|
|
||||||
@given(duration=st.floats(min_value=0.1, max_value=1))
|
@given(duration=st.floats(min_value=0.1, max_value=1))
|
||||||
|
@settings(deadline=None)
|
||||||
def test_can_compute_correct_spectrogram_width(duration: float):
|
def test_can_compute_correct_spectrogram_width(duration: float):
|
||||||
samplerate = parameters.TARGET_SAMPLERATE_HZ
|
samplerate = parameters.TARGET_SAMPLERATE_HZ
|
||||||
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
|
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
|
||||||
@ -87,6 +88,7 @@ def test_pad_audio_without_fixed_size(duration: float):
|
|||||||
|
|
||||||
|
|
||||||
@given(duration=st.floats(min_value=0.1, max_value=2))
|
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||||
|
@settings(deadline=None)
|
||||||
def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor(
|
def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor(
|
||||||
duration: float,
|
duration: float,
|
||||||
):
|
):
|
||||||
|
|||||||
@ -8,15 +8,6 @@ from click.testing import CliRunner
|
|||||||
from batdetect2.cli import cli
|
from batdetect2.cli import cli
|
||||||
|
|
||||||
|
|
||||||
def test_cli_detect_help() -> None:
|
|
||||||
"""User story: get usage help for legacy detect command."""
|
|
||||||
|
|
||||||
result = CliRunner().invoke(cli, ["detect", "--help"])
|
|
||||||
|
|
||||||
assert result.exit_code == 0
|
|
||||||
assert "Detect bat calls in files in AUDIO_DIR" in result.output
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_detect_command_on_test_audio(tmp_path: Path) -> None:
|
def test_cli_detect_command_on_test_audio(tmp_path: Path) -> None:
|
||||||
"""User story: run legacy detect on example audio directory."""
|
"""User story: run legacy detect on example audio directory."""
|
||||||
|
|
||||||
|
|||||||
@ -41,6 +41,22 @@ def test_build_detector_custom_config():
|
|||||||
assert model.backbone.encoder.in_channels == 2
|
assert model.backbone.encoder.in_channels == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_detector_custom_size_channels():
|
||||||
|
num_classes = 3
|
||||||
|
num_sizes = 4
|
||||||
|
config = UNetBackboneConfig(in_channels=1, input_height=128)
|
||||||
|
|
||||||
|
model = build_detector(
|
||||||
|
num_classes=num_classes,
|
||||||
|
num_sizes=num_sizes,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
dummy = torch.randn(1, 1, 128, 64)
|
||||||
|
output = model(dummy)
|
||||||
|
assert output.size_preds.shape[1] == num_sizes
|
||||||
|
|
||||||
|
|
||||||
def test_detector_forward_pass_shapes(dummy_spectrogram):
|
def test_detector_forward_pass_shapes(dummy_spectrogram):
|
||||||
"""Test that the forward pass produces correctly shaped outputs."""
|
"""Test that the forward pass produces correctly shaped outputs."""
|
||||||
num_classes = 4
|
num_classes = 4
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from soundevent.geometry import compute_bounds
|
|||||||
from batdetect2.models.types import ModelOutput
|
from batdetect2.models.types import ModelOutput
|
||||||
from batdetect2.outputs import build_output_transform
|
from batdetect2.outputs import build_output_transform
|
||||||
from batdetect2.postprocess import build_postprocessor
|
from batdetect2.postprocess import build_postprocessor
|
||||||
|
from batdetect2.targets import build_roi_mapping
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
|
|
||||||
@ -37,7 +38,9 @@ def test_annotation_roundtrip_through_postprocess_and_output_transform(
|
|||||||
width = int(duration * sample_preprocessor.output_samplerate)
|
width = int(duration * sample_preprocessor.output_samplerate)
|
||||||
spec = torch.zeros((1, height, width), dtype=torch.float32)
|
spec = torch.zeros((1, height, width), dtype=torch.float32)
|
||||||
|
|
||||||
labeler = build_clip_labeler(targets=sample_targets)
|
roi_mapper = build_roi_mapping()
|
||||||
|
|
||||||
|
labeler = build_clip_labeler(targets=sample_targets, roi_mapper=roi_mapper)
|
||||||
heatmaps = labeler(clip_annotation, spec)
|
heatmaps = labeler(clip_annotation, spec)
|
||||||
|
|
||||||
output = ModelOutput(
|
output = ModelOutput(
|
||||||
@ -51,7 +54,10 @@ def test_annotation_roundtrip_through_postprocess_and_output_transform(
|
|||||||
clip_detection_tensors = postprocessor(output)
|
clip_detection_tensors = postprocessor(output)
|
||||||
assert len(clip_detection_tensors) == 1
|
assert len(clip_detection_tensors) == 1
|
||||||
|
|
||||||
transform = build_output_transform(targets=sample_targets)
|
transform = build_output_transform(
|
||||||
|
targets=sample_targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
|
)
|
||||||
clip_detections = transform.to_clip_detections(
|
clip_detections = transform.to_clip_detections(
|
||||||
detections=clip_detection_tensors[0],
|
detections=clip_detection_tensors[0],
|
||||||
clip=clip,
|
clip=clip,
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from batdetect2.postprocess.types import (
|
|||||||
ClipDetectionsTensor,
|
ClipDetectionsTensor,
|
||||||
Detection,
|
Detection,
|
||||||
)
|
)
|
||||||
|
from batdetect2.targets import TargetConfig, build_roi_mapping
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
@ -27,9 +28,22 @@ def _mock_clip_detections_tensor() -> ClipDetectionsTensor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_roi_mapper(targets: TargetProtocol):
|
||||||
|
config_obj = getattr(targets, "config", None)
|
||||||
|
target_config = (
|
||||||
|
config_obj if isinstance(config_obj, TargetConfig) else None
|
||||||
|
)
|
||||||
|
return build_roi_mapping(
|
||||||
|
config=(target_config.roi if target_config is not None else None),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_shift_time_to_clip_start(sample_targets: TargetProtocol):
|
def test_shift_time_to_clip_start(sample_targets: TargetProtocol):
|
||||||
raw = _mock_clip_detections_tensor()
|
raw = _mock_clip_detections_tensor()
|
||||||
transform = build_output_transform(targets=sample_targets)
|
transform = build_output_transform(
|
||||||
|
targets=sample_targets,
|
||||||
|
roi_mapper=_build_roi_mapper(sample_targets),
|
||||||
|
)
|
||||||
|
|
||||||
transformed = transform.to_detections(raw, start_time=2.5)
|
transformed = transform.to_detections(raw, start_time=2.5)
|
||||||
start_time, _, end_time, _ = compute_bounds(transformed[0].geometry)
|
start_time, _, end_time, _ = compute_bounds(transformed[0].geometry)
|
||||||
@ -43,7 +57,10 @@ def test_to_clip_detections_shifts_by_clip_start(
|
|||||||
sample_targets: TargetProtocol,
|
sample_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
|
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
|
||||||
transform = build_output_transform(targets=sample_targets)
|
transform = build_output_transform(
|
||||||
|
targets=sample_targets,
|
||||||
|
roi_mapper=_build_roi_mapper(sample_targets),
|
||||||
|
)
|
||||||
raw = _mock_clip_detections_tensor()
|
raw = _mock_clip_detections_tensor()
|
||||||
shifted = transform.to_clip_detections(detections=raw, clip=clip)
|
shifted = transform.to_clip_detections(detections=raw, clip=clip)
|
||||||
start_time, _, end_time, _ = compute_bounds(shifted.detections[0].geometry)
|
start_time, _, end_time, _ = compute_bounds(shifted.detections[0].geometry)
|
||||||
@ -90,6 +107,7 @@ def test_detection_and_clip_transforms_applied_in_order(
|
|||||||
|
|
||||||
transform = OutputTransform(
|
transform = OutputTransform(
|
||||||
targets=sample_targets,
|
targets=sample_targets,
|
||||||
|
roi_mapper=_build_roi_mapper(sample_targets),
|
||||||
detection_transform_steps=[boost_score, keep_high_score],
|
detection_transform_steps=[boost_score, keep_high_score],
|
||||||
clip_transform_steps=[tag_clip_transform],
|
clip_transform_steps=[tag_clip_transform],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
@ -22,8 +24,10 @@ from batdetect2.targets.rois import (
|
|||||||
AnchorBBoxMapperConfig,
|
AnchorBBoxMapperConfig,
|
||||||
PeakEnergyBBoxMapper,
|
PeakEnergyBBoxMapper,
|
||||||
PeakEnergyBBoxMapperConfig,
|
PeakEnergyBBoxMapperConfig,
|
||||||
|
ROIMappingConfig,
|
||||||
_build_bounding_box,
|
_build_bounding_box,
|
||||||
build_roi_mapper,
|
build_roi_mapper,
|
||||||
|
build_roi_mapping,
|
||||||
get_peak_energy_coordinates,
|
get_peak_energy_coordinates,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -630,3 +634,43 @@ def test_build_roi_mapper_raises_error_for_unknown_name():
|
|||||||
# Then
|
# Then
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
build_roi_mapper(DummyConfig()) # type: ignore
|
build_roi_mapper(DummyConfig()) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_roi_mapping_applies_class_override():
|
||||||
|
config = ROIMappingConfig(
|
||||||
|
default=AnchorBBoxMapperConfig(anchor="bottom-left"),
|
||||||
|
overrides={
|
||||||
|
"myomyo": AnchorBBoxMapperConfig(anchor="top-left"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
mapper = build_roi_mapping(config=config)
|
||||||
|
|
||||||
|
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||||
|
sound_event = data.SoundEvent(
|
||||||
|
recording=data.Recording(
|
||||||
|
path=Path("x.wav"),
|
||||||
|
samplerate=256_000,
|
||||||
|
channels=1,
|
||||||
|
duration=1.0,
|
||||||
|
),
|
||||||
|
geometry=geometry,
|
||||||
|
)
|
||||||
|
|
||||||
|
default_position, _ = mapper.encode(sound_event, class_name="pippip")
|
||||||
|
override_position, _ = mapper.encode(sound_event, class_name="myomyo")
|
||||||
|
|
||||||
|
assert default_position == pytest.approx((0.1, 12_000))
|
||||||
|
assert override_position == pytest.approx((0.1, 18_000))
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_roi_mapping_rejects_dimension_mismatch():
|
||||||
|
config = ROIMappingConfig(
|
||||||
|
default=AnchorBBoxMapperConfig(),
|
||||||
|
overrides={
|
||||||
|
"myomyo": PeakEnergyBBoxMapperConfig(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="same dimension order"):
|
||||||
|
build_roi_mapping(config=config)
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
from soundevent import data, terms
|
from soundevent import data, terms
|
||||||
|
|
||||||
from batdetect2.targets import TargetConfig, build_targets
|
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
|
||||||
|
|
||||||
|
|
||||||
def test_can_override_default_roi_mapper_per_class(
|
def test_can_override_default_roi_mapper_per_class(
|
||||||
@ -32,18 +33,21 @@ def test_can_override_default_roi_mapper_per_class(
|
|||||||
tags:
|
tags:
|
||||||
- key: species
|
- key: species
|
||||||
value: Myotis myotis
|
value: Myotis myotis
|
||||||
roi:
|
|
||||||
name: anchor_bbox
|
|
||||||
anchor: top-left
|
|
||||||
|
|
||||||
roi:
|
roi:
|
||||||
name: anchor_bbox
|
default:
|
||||||
anchor: bottom-left
|
name: anchor_bbox
|
||||||
|
anchor: bottom-left
|
||||||
|
overrides:
|
||||||
|
myomyo:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: top-left
|
||||||
"""
|
"""
|
||||||
config_path = create_temp_yaml(yaml_content)
|
config_path = create_temp_yaml(yaml_content)
|
||||||
|
|
||||||
config = TargetConfig.load(config_path)
|
config = TargetConfig.load(config_path)
|
||||||
targets = build_targets(config)
|
targets = build_targets(config)
|
||||||
|
roi_mapper = build_roi_mapping(config=config.roi)
|
||||||
|
|
||||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||||
|
|
||||||
@ -60,8 +64,17 @@ def test_can_override_default_roi_mapper_per_class(
|
|||||||
tags=[data.Tag(term=species, value="Myotis myotis")],
|
tags=[data.Tag(term=species, value="Myotis myotis")],
|
||||||
)
|
)
|
||||||
|
|
||||||
(time1, freq1), _ = targets.encode_roi(se1)
|
class_name1 = targets.encode_class(se1)
|
||||||
(time2, freq2), _ = targets.encode_roi(se2)
|
class_name2 = targets.encode_class(se2)
|
||||||
|
|
||||||
|
(time1, freq1), _ = roi_mapper.encode(
|
||||||
|
se1.sound_event,
|
||||||
|
class_name=class_name1,
|
||||||
|
)
|
||||||
|
(time2, freq2), _ = roi_mapper.encode(
|
||||||
|
se2.sound_event,
|
||||||
|
class_name=class_name2,
|
||||||
|
)
|
||||||
|
|
||||||
assert time1 == time2 == 0.1
|
assert time1 == time2 == 0.1
|
||||||
assert freq1 == 12_000
|
assert freq1 == 12_000
|
||||||
@ -95,18 +108,21 @@ def test_roi_is_recovered_roundtrip_even_with_overriders(
|
|||||||
tags:
|
tags:
|
||||||
- key: species
|
- key: species
|
||||||
value: Myotis myotis
|
value: Myotis myotis
|
||||||
roi:
|
|
||||||
name: anchor_bbox
|
|
||||||
anchor: top-left
|
|
||||||
|
|
||||||
roi:
|
roi:
|
||||||
name: anchor_bbox
|
default:
|
||||||
anchor: bottom-left
|
name: anchor_bbox
|
||||||
|
anchor: bottom-left
|
||||||
|
overrides:
|
||||||
|
myomyo:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: top-left
|
||||||
"""
|
"""
|
||||||
config_path = create_temp_yaml(yaml_content)
|
config_path = create_temp_yaml(yaml_content)
|
||||||
|
|
||||||
config = TargetConfig.load(config_path)
|
config = TargetConfig.load(config_path)
|
||||||
targets = build_targets(config)
|
targets = build_targets(config)
|
||||||
|
roi_mapper = build_roi_mapping(config=config.roi)
|
||||||
|
|
||||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||||
|
|
||||||
@ -122,14 +138,14 @@ def test_roi_is_recovered_roundtrip_even_with_overriders(
|
|||||||
tags=[data.Tag(term=species, value="Myotis myotis")],
|
tags=[data.Tag(term=species, value="Myotis myotis")],
|
||||||
)
|
)
|
||||||
|
|
||||||
position1, size1 = targets.encode_roi(se1)
|
position1, size1 = roi_mapper.encode(se1.sound_event, class_name="pippip")
|
||||||
position2, size2 = targets.encode_roi(se2)
|
position2, size2 = roi_mapper.encode(se2.sound_event, class_name="myomyo")
|
||||||
|
|
||||||
class_name1 = targets.encode_class(se1)
|
class_name1 = targets.encode_class(se1)
|
||||||
class_name2 = targets.encode_class(se2)
|
class_name2 = targets.encode_class(se2)
|
||||||
|
|
||||||
recovered1 = targets.decode_roi(position1, size1, class_name=class_name1)
|
recovered1 = roi_mapper.decode(position1, size1, class_name=class_name1)
|
||||||
recovered2 = targets.decode_roi(position2, size2, class_name=class_name2)
|
recovered2 = roi_mapper.decode(position2, size2, class_name=class_name2)
|
||||||
|
|
||||||
assert recovered1 == geometry
|
assert recovered1 == geometry
|
||||||
assert recovered2 == geometry
|
assert recovered2 == geometry
|
||||||
|
|||||||
@ -42,28 +42,6 @@ def test_train_saves_checkpoint_in_requested_experiment_run_dir(
|
|||||||
assert checkpoints
|
assert checkpoints
|
||||||
|
|
||||||
|
|
||||||
def test_train_without_validation_does_not_save_default_monitored_checkpoint(
|
|
||||||
tmp_path: Path,
|
|
||||||
example_annotations: list[data.ClipAnnotation],
|
|
||||||
) -> None:
|
|
||||||
config = _build_fast_train_config()
|
|
||||||
|
|
||||||
run_train(
|
|
||||||
train_annotations=example_annotations[:1],
|
|
||||||
val_annotations=None,
|
|
||||||
train_config=config.train,
|
|
||||||
model_config=config.model,
|
|
||||||
audio_config=config.audio,
|
|
||||||
num_epochs=1,
|
|
||||||
train_workers=0,
|
|
||||||
val_workers=0,
|
|
||||||
checkpoint_dir=tmp_path,
|
|
||||||
seed=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert not list(tmp_path.rglob("*.ckpt"))
|
|
||||||
|
|
||||||
|
|
||||||
def test_train_without_validation_can_still_save_last_checkpoint(
|
def test_train_without_validation_can_still_save_last_checkpoint(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
example_annotations: list[data.ClipAnnotation],
|
example_annotations: list[data.ClipAnnotation],
|
||||||
|
|||||||
@ -3,8 +3,8 @@ from pathlib import Path
|
|||||||
import torch
|
import torch
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.targets import TargetConfig, build_targets
|
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
|
||||||
from batdetect2.targets.rois import AnchorBBoxMapperConfig
|
from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMappingConfig
|
||||||
from batdetect2.train.labels import generate_heatmaps
|
from batdetect2.train.labels import generate_heatmaps
|
||||||
|
|
||||||
recording = data.Recording(
|
recording = data.Recording(
|
||||||
@ -30,14 +30,17 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
|||||||
):
|
):
|
||||||
config = sample_target_config.model_copy(
|
config = sample_target_config.model_copy(
|
||||||
update=dict(
|
update=dict(
|
||||||
roi=AnchorBBoxMapperConfig(
|
roi=ROIMappingConfig(
|
||||||
time_scale=1,
|
default=AnchorBBoxMapperConfig(
|
||||||
frequency_scale=1,
|
time_scale=1,
|
||||||
|
frequency_scale=1,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
targets = build_targets(config)
|
targets = build_targets(config)
|
||||||
|
roi_mapper = build_roi_mapping(config=config.roi)
|
||||||
|
|
||||||
clip_annotation = data.ClipAnnotation(
|
clip_annotation = data.ClipAnnotation(
|
||||||
clip=clip,
|
clip=clip,
|
||||||
@ -60,6 +63,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
|||||||
min_freq=0,
|
min_freq=0,
|
||||||
max_freq=100,
|
max_freq=100,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
)
|
)
|
||||||
pippip_index = targets.class_names.index("pippip")
|
pippip_index = targets.class_names.index("pippip")
|
||||||
myomyo_index = targets.class_names.index("myomyo")
|
myomyo_index = targets.class_names.index("myomyo")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user