mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Allow building new head
This commit is contained in:
parent
8e35956007
commit
ebe7e134e9
@ -1,5 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Sequence
|
from typing import Literal, Sequence, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -19,7 +19,8 @@ from batdetect2.evaluate import (
|
|||||||
)
|
)
|
||||||
from batdetect2.inference import process_file_list, run_batch_inference
|
from batdetect2.inference import process_file_list, run_batch_inference
|
||||||
from batdetect2.logging import DEFAULT_LOGS_DIR
|
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||||
from batdetect2.models import Model, build_model
|
from batdetect2.models import Model, build_model, build_model_with_new_targets
|
||||||
|
from batdetect2.models.detectors import Detector
|
||||||
from batdetect2.outputs import (
|
from batdetect2.outputs import (
|
||||||
OutputFormatConfig,
|
OutputFormatConfig,
|
||||||
OutputFormatterProtocol,
|
OutputFormatterProtocol,
|
||||||
@ -109,6 +110,46 @@ class BatDetect2API:
|
|||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def finetune(
    self,
    train_annotations: Sequence[data.ClipAnnotation],
    val_annotations: Sequence[data.ClipAnnotation] | None = None,
    trainable: Literal[
        "all", "heads", "classifier_head", "bbox_head"
    ] = "heads",
    train_workers: int = 0,
    val_workers: int = 0,
    checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
    log_dir: Path | None = DEFAULT_LOGS_DIR,
    experiment_name: str | None = None,
    num_epochs: int | None = None,
    run_name: str | None = None,
    seed: int | None = None,
) -> "BatDetect2API":
    """Continue training the wrapped model on a chosen subset of weights.

    ``trainable`` is handed to ``_set_trainable_parameters`` before
    training starts; every other argument is forwarded unchanged to
    ``run_train`` together with the model, targets, preprocessor, audio
    loader and configs held by this API object.

    Nothing in this method restores the ``requires_grad`` flags
    afterwards, so the selection made here is still in effect on
    ``self.model`` once the call returns.

    Returns
    -------
    BatDetect2API
        This same instance, so calls can be chained.
    """
    # Select what may be updated before the training run is started.
    self._set_trainable_parameters(trainable)

    run_train(
        # Components and configuration owned by this API object.
        model=self.model,
        targets=self.targets,
        preprocessor=self.preprocessor,
        audio_loader=self.audio_loader,
        model_config=self.config.model,
        train_config=self.config.train,
        # Data and loader settings supplied by the caller.
        train_annotations=train_annotations,
        val_annotations=val_annotations,
        train_workers=train_workers,
        val_workers=val_workers,
        # Run bookkeeping supplied by the caller.
        checkpoint_dir=checkpoint_dir,
        log_dir=log_dir,
        experiment_name=experiment_name,
        run_name=run_name,
        num_epochs=num_epochs,
        seed=seed,
    )
    return self
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
test_annotations: Sequence[data.ClipAnnotation],
|
test_annotations: Sequence[data.ClipAnnotation],
|
||||||
@ -410,6 +451,7 @@ class BatDetect2API:
|
|||||||
cls,
|
cls,
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
config: BatDetect2Config | None = None,
|
config: BatDetect2Config | None = None,
|
||||||
|
targets: TargetProtocol | None = None,
|
||||||
) -> "BatDetect2API":
|
) -> "BatDetect2API":
|
||||||
from batdetect2.audio import AudioConfig
|
from batdetect2.audio import AudioConfig
|
||||||
|
|
||||||
@ -423,7 +465,12 @@ class BatDetect2API:
|
|||||||
)
|
)
|
||||||
config = merge_configs(base, config) if config else base
|
config = merge_configs(base, config) if config else base
|
||||||
|
|
||||||
|
if targets is None:
|
||||||
targets = build_targets(config=config.model.targets)
|
targets = build_targets(config=config.model.targets)
|
||||||
|
else:
|
||||||
|
target_config = getattr(targets, "config", None)
|
||||||
|
if target_config is not None:
|
||||||
|
config.model.targets = target_config
|
||||||
|
|
||||||
audio_loader = build_audio_loader(config=config.audio)
|
audio_loader = build_audio_loader(config=config.audio)
|
||||||
|
|
||||||
@ -452,6 +499,16 @@ class BatDetect2API:
|
|||||||
transform=output_transform,
|
transform=output_transform,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
targets_changed = targets is not None or (
|
||||||
|
config.model.targets.model_dump(mode="json")
|
||||||
|
!= model_config.targets.model_dump(mode="json")
|
||||||
|
)
|
||||||
|
if targets_changed:
|
||||||
|
model = build_model_with_new_targets(
|
||||||
|
model=model,
|
||||||
|
targets=targets,
|
||||||
|
)
|
||||||
|
|
||||||
model.preprocessor = preprocessor
|
model.preprocessor = preprocessor
|
||||||
model.postprocessor = postprocessor
|
model.postprocessor = postprocessor
|
||||||
model.targets = targets
|
model.targets = targets
|
||||||
@ -467,3 +524,25 @@ class BatDetect2API:
|
|||||||
formatter=formatter,
|
formatter=formatter,
|
||||||
output_transform=output_transform,
|
output_transform=output_transform,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _set_trainable_parameters(
    self,
    trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
) -> None:
    """Mark which detector parameters receive gradients.

    Parameters
    ----------
    trainable
        ``"all"`` enables gradients for every detector parameter.
        ``"classifier_head"`` and ``"bbox_head"`` enable only the named
        head; ``"heads"`` enables both. Everything else in the detector
        is frozen.

    Raises
    ------
    ValueError
        If ``trainable`` is not one of the supported options. The check
        runs before any flag is modified, so a typo cannot silently
        leave the whole detector frozen.
    """
    valid_options = ("all", "heads", "classifier_head", "bbox_head")
    if trainable not in valid_options:
        raise ValueError(
            f"Unknown trainable option {trainable!r}; "
            f"expected one of {valid_options}."
        )

    detector = cast(Detector, self.model.detector)

    # Single pass over the detector: everything on for "all", otherwise
    # everything off so only the selected heads are re-enabled below.
    train_everything = trainable == "all"
    detector.requires_grad_(train_everything)
    if train_everything:
        return

    if trainable in {"heads", "classifier_head"}:
        detector.classifier_head.requires_grad_(True)

    if trainable in {"heads", "bbox_head"}:
        detector.bbox_head.requires_grad_(True)
|
||||||
|
|||||||
@ -100,6 +100,7 @@ __all__ = [
|
|||||||
"Model",
|
"Model",
|
||||||
"ModelConfig",
|
"ModelConfig",
|
||||||
"build_model",
|
"build_model",
|
||||||
|
"build_model_with_new_targets",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -274,3 +275,21 @@ def build_model(
|
|||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_model_with_new_targets(
    model: Model,
    targets: TargetProtocol,
) -> Model:
    """Derive a model for a different target set from an existing one.

    A detector is built with ``build_detector`` for
    ``len(targets.class_names)`` classes. The backbone of ``model`` is
    handed over as the same object (it is not copied here), and the
    pre- and postprocessor of ``model`` are carried over to the result.

    Parameters
    ----------
    model
        Model whose backbone, preprocessor and postprocessor are reused.
    targets
        Target definition the returned model is built for.

    Returns
    -------
    Model
        A new ``Model`` wrapping the freshly built detector.
    """
    class_count = len(targets.class_names)
    shared_backbone = model.detector.backbone

    new_detector = build_detector(
        num_classes=class_count,
        backbone=shared_backbone,
    )

    return Model(
        detector=new_detector,
        postprocessor=model.postprocessor,
        preprocessor=model.preprocessor,
        targets=targets,
    )
|
||||||
|
|||||||
@ -135,7 +135,9 @@ class Detector(DetectionModel):
|
|||||||
|
|
||||||
|
|
||||||
def build_detector(
|
def build_detector(
|
||||||
num_classes: int, config: BackboneConfig | None = None
|
num_classes: int,
|
||||||
|
config: BackboneConfig | None = None,
|
||||||
|
backbone: BackboneModel | None = None,
|
||||||
) -> DetectionModel:
|
) -> DetectionModel:
|
||||||
"""Build a complete BatDetect2 detection model.
|
"""Build a complete BatDetect2 detection model.
|
||||||
|
|
||||||
@ -165,13 +167,14 @@ def build_detector(
|
|||||||
If ``num_classes`` is not positive, or if the backbone
|
If ``num_classes`` is not positive, or if the backbone
|
||||||
configuration is invalid.
|
configuration is invalid.
|
||||||
"""
|
"""
|
||||||
|
if backbone is None:
|
||||||
config = config or UNetBackboneConfig()
|
config = config or UNetBackboneConfig()
|
||||||
|
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Building model with config: \n{}",
|
"Building model with config: \n{}",
|
||||||
lambda: config.to_yaml_string(), # type: ignore
|
lambda: config.to_yaml_string(), # type: ignore
|
||||||
)
|
)
|
||||||
backbone = build_backbone(config=config)
|
backbone = build_backbone(config=config)
|
||||||
|
|
||||||
classifier_head = ClassifierHead(
|
classifier_head = ClassifierHead(
|
||||||
num_classes=num_classes,
|
num_classes=num_classes,
|
||||||
in_channels=backbone.out_channels,
|
in_channels=backbone.out_channels,
|
||||||
|
|||||||
@ -82,5 +82,9 @@ class EncoderDecoderModel(BackboneModel):
|
|||||||
|
|
||||||
|
|
||||||
class DetectionModel(ABC, torch.nn.Module):
|
class DetectionModel(ABC, torch.nn.Module):
|
||||||
|
backbone: BackboneModel
|
||||||
|
classifier_head: torch.nn.Module
|
||||||
|
bbox_head: torch.nn.Module
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput: ...
|
def forward(self, spec: torch.Tensor) -> ModelOutput: ...
|
||||||
|
|||||||
@ -68,8 +68,17 @@ class TrainingModule(L.LightningModule):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
|
trainable_parameters = [
|
||||||
|
parameter
|
||||||
|
for parameter in self.parameters()
|
||||||
|
if parameter.requires_grad
|
||||||
|
]
|
||||||
|
|
||||||
|
if not trainable_parameters:
|
||||||
|
raise ValueError("No trainable parameters available.")
|
||||||
|
|
||||||
optimizer = build_optimizer(
|
optimizer = build_optimizer(
|
||||||
self.parameters(),
|
trainable_parameters,
|
||||||
config=self.train_config.optimizer,
|
config=self.train_config.optimizer,
|
||||||
)
|
)
|
||||||
scheduler = build_scheduler(
|
scheduler = build_scheduler(
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import lightning as L
|
import lightning as L
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -8,6 +9,9 @@ from soundevent.geometry import compute_bounds
|
|||||||
|
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
|
from batdetect2.models.detectors import Detector
|
||||||
|
from batdetect2.models.heads import ClassifierHead
|
||||||
|
from batdetect2.train import load_model_from_checkpoint
|
||||||
from batdetect2.train.lightning import build_training_module
|
from batdetect2.train.lightning import build_training_module
|
||||||
|
|
||||||
|
|
||||||
@ -213,6 +217,77 @@ def test_user_can_load_checkpoint_and_finetune(
|
|||||||
assert checkpoints
|
assert checkpoints
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_can_load_checkpoint_with_new_targets(
    tmp_path: Path,
    sample_targets,
) -> None:
    """User story: reuse a checkpoint while swapping in new targets."""
    # Write a checkpoint for a default-configured training module.
    checkpoint_path = tmp_path / "base_transfer.ckpt"
    module = build_training_module(model_config=BatDetect2Config().model)
    trainer = L.Trainer(enable_checkpointing=False, logger=False)
    trainer.strategy.connect(module)
    trainer.save_checkpoint(checkpoint_path)

    # Load it twice: once untouched as the reference, once through the
    # API with the replacement targets.
    reference_model, _ = load_model_from_checkpoint(checkpoint_path)
    api = BatDetect2API.from_checkpoint(
        checkpoint_path,
        targets=sample_targets,
    )
    reference_detector = cast(Detector, reference_model.detector)
    new_detector = cast(Detector, api.model.detector)
    new_classifier = cast(ClassifierHead, new_detector.classifier_head)

    # The API must expose the supplied targets and size the detector and
    # classifier output to match them.
    assert api.targets is sample_targets
    expected_classes = len(sample_targets.class_names)
    assert new_detector.num_classes == expected_classes
    # One output channel more than the number of classes is expected
    # (presumably a background channel; confirm in ClassifierHead).
    assert new_classifier.classifier.out_channels == expected_classes + 1

    # Every backbone tensor from the checkpoint must be present and
    # numerically unchanged in the model built for the new targets.
    reference_weights = reference_detector.backbone.state_dict()
    transferred_weights = new_detector.backbone.state_dict()
    assert reference_weights
    for name, tensor in reference_weights.items():
        assert name in transferred_weights
        torch.testing.assert_close(transferred_weights[name], tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_can_finetune_only_heads(
    tmp_path: Path,
    example_annotations,
) -> None:
    """User story: train the prediction heads and leave the backbone frozen."""
    api = BatDetect2API.from_config(BatDetect2Config())
    checkpoint_root = tmp_path / "heads_only"

    # One epoch on a single annotated clip keeps the run short.
    api.finetune(
        train_annotations=example_annotations[:1],
        val_annotations=example_annotations[:1],
        trainable="heads",
        train_workers=0,
        val_workers=0,
        checkpoint_dir=checkpoint_root,
        log_dir=tmp_path / "logs",
        num_epochs=1,
        seed=0,
    )

    detector = cast(Detector, api.model.detector)
    backbone_weights = list(detector.backbone.parameters())
    classifier_weights = list(detector.classifier_head.parameters())
    bbox_weights = list(detector.bbox_head.parameters())

    # Non-empty checks first, so the all()/any() assertions below cannot
    # pass vacuously on an empty parameter list.
    assert backbone_weights
    assert classifier_weights
    assert bbox_weights

    assert not any(weight.requires_grad for weight in backbone_weights)
    assert all(weight.requires_grad for weight in classifier_weights)
    assert all(weight.requires_grad for weight in bbox_weights)

    # The run must have written at least one checkpoint file.
    assert list(checkpoint_root.rglob("*.ckpt"))
|
||||||
|
|
||||||
|
|
||||||
def test_user_can_evaluate_small_dataset_and_get_metrics(
|
def test_user_can_evaluate_small_dataset_and_get_metrics(
|
||||||
api_v2: BatDetect2API,
|
api_v2: BatDetect2API,
|
||||||
example_annotations,
|
example_annotations,
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@ -57,10 +56,10 @@ def dummy_targets() -> TargetProtocol:
|
|||||||
|
|
||||||
def encode_class(
|
def encode_class(
|
||||||
self, sound_event: data.SoundEventAnnotation
|
self, sound_event: data.SoundEventAnnotation
|
||||||
) -> Optional[str]:
|
) -> str | None:
|
||||||
return "bat"
|
return "bat"
|
||||||
|
|
||||||
def decode_class(self, class_label: str) -> List[data.Tag]:
|
def decode_class(self, class_label: str) -> list[data.Tag]:
|
||||||
return tag_map.get(class_label.lower(), [])
|
return tag_map.get(class_label.lower(), [])
|
||||||
|
|
||||||
def encode_roi(self, sound_event: data.SoundEventAnnotation):
|
def encode_roi(self, sound_event: data.SoundEventAnnotation):
|
||||||
@ -70,7 +69,7 @@ def dummy_targets() -> TargetProtocol:
|
|||||||
self,
|
self,
|
||||||
position,
|
position,
|
||||||
size: np.ndarray,
|
size: np.ndarray,
|
||||||
class_name: Optional[str] = None,
|
class_name: str | None = None,
|
||||||
):
|
):
|
||||||
time, freq = position
|
time, freq = position
|
||||||
width, height = size
|
width, height = size
|
||||||
@ -210,7 +209,7 @@ def empty_detection_dataset() -> xr.Dataset:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_raw_predictions() -> List[Detection]:
|
def sample_raw_predictions() -> list[Detection]:
|
||||||
"""Manually crafted RawPrediction objects using the actual type."""
|
"""Manually crafted RawPrediction objects using the actual type."""
|
||||||
|
|
||||||
pred1_classes = xr.DataArray(
|
pred1_classes = xr.DataArray(
|
||||||
@ -282,7 +281,7 @@ def sample_raw_predictions() -> List[Detection]:
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_to_sound_event_basic(
|
def test_convert_raw_to_sound_event_basic(
|
||||||
sample_raw_predictions: List[Detection],
|
sample_raw_predictions: list[Detection],
|
||||||
sample_recording: data.Recording,
|
sample_recording: data.Recording,
|
||||||
dummy_targets: TargetProtocol,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
@ -325,7 +324,7 @@ def test_convert_raw_to_sound_event_basic(
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_to_sound_event_thresholding(
|
def test_convert_raw_to_sound_event_thresholding(
|
||||||
sample_raw_predictions: List[Detection],
|
sample_raw_predictions: list[Detection],
|
||||||
sample_recording: data.Recording,
|
sample_recording: data.Recording,
|
||||||
dummy_targets: TargetProtocol,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
@ -353,7 +352,7 @@ def test_convert_raw_to_sound_event_thresholding(
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_to_sound_event_no_threshold(
|
def test_convert_raw_to_sound_event_no_threshold(
|
||||||
sample_raw_predictions: List[Detection],
|
sample_raw_predictions: list[Detection],
|
||||||
sample_recording: data.Recording,
|
sample_recording: data.Recording,
|
||||||
dummy_targets: TargetProtocol,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
@ -381,7 +380,7 @@ def test_convert_raw_to_sound_event_no_threshold(
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_to_sound_event_top_class(
|
def test_convert_raw_to_sound_event_top_class(
|
||||||
sample_raw_predictions: List[Detection],
|
sample_raw_predictions: list[Detection],
|
||||||
sample_recording: data.Recording,
|
sample_recording: data.Recording,
|
||||||
dummy_targets: TargetProtocol,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
@ -408,7 +407,7 @@ def test_convert_raw_to_sound_event_top_class(
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_to_sound_event_all_below_threshold(
|
def test_convert_raw_to_sound_event_all_below_threshold(
|
||||||
sample_raw_predictions: List[Detection],
|
sample_raw_predictions: list[Detection],
|
||||||
sample_recording: data.Recording,
|
sample_recording: data.Recording,
|
||||||
dummy_targets: TargetProtocol,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
@ -434,7 +433,7 @@ def test_convert_raw_to_sound_event_all_below_threshold(
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_list_to_clip_basic(
|
def test_convert_raw_list_to_clip_basic(
|
||||||
sample_raw_predictions: List[Detection],
|
sample_raw_predictions: list[Detection],
|
||||||
sample_clip: data.Clip,
|
sample_clip: data.Clip,
|
||||||
dummy_targets: TargetProtocol,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
@ -486,7 +485,7 @@ def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets):
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_list_to_clip_passes_args(
|
def test_convert_raw_list_to_clip_passes_args(
|
||||||
sample_raw_predictions: List[Detection],
|
sample_raw_predictions: list[Detection],
|
||||||
sample_clip: data.Clip,
|
sample_clip: data.Clip,
|
||||||
dummy_targets: TargetProtocol,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user