mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Allow building new head
This commit is contained in:
parent
8e35956007
commit
ebe7e134e9
@ -1,5 +1,5 @@
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
from typing import Literal, Sequence, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -19,7 +19,8 @@ from batdetect2.evaluate import (
|
||||
)
|
||||
from batdetect2.inference import process_file_list, run_batch_inference
|
||||
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||
from batdetect2.models import Model, build_model
|
||||
from batdetect2.models import Model, build_model, build_model_with_new_targets
|
||||
from batdetect2.models.detectors import Detector
|
||||
from batdetect2.outputs import (
|
||||
OutputFormatConfig,
|
||||
OutputFormatterProtocol,
|
||||
@ -109,6 +110,46 @@ class BatDetect2API:
|
||||
)
|
||||
return self
|
||||
|
||||
def finetune(
|
||||
self,
|
||||
train_annotations: Sequence[data.ClipAnnotation],
|
||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||
trainable: Literal[
|
||||
"all", "heads", "classifier_head", "bbox_head"
|
||||
] = "heads",
|
||||
train_workers: int = 0,
|
||||
val_workers: int = 0,
|
||||
checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
|
||||
log_dir: Path | None = DEFAULT_LOGS_DIR,
|
||||
experiment_name: str | None = None,
|
||||
num_epochs: int | None = None,
|
||||
run_name: str | None = None,
|
||||
seed: int | None = None,
|
||||
) -> "BatDetect2API":
|
||||
"""Fine-tune the model with trainable-parameter selection."""
|
||||
|
||||
self._set_trainable_parameters(trainable)
|
||||
|
||||
run_train(
|
||||
train_annotations=train_annotations,
|
||||
val_annotations=val_annotations,
|
||||
model=self.model,
|
||||
targets=self.targets,
|
||||
model_config=self.config.model,
|
||||
train_config=self.config.train,
|
||||
preprocessor=self.preprocessor,
|
||||
audio_loader=self.audio_loader,
|
||||
train_workers=train_workers,
|
||||
val_workers=val_workers,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
log_dir=log_dir,
|
||||
experiment_name=experiment_name,
|
||||
num_epochs=num_epochs,
|
||||
run_name=run_name,
|
||||
seed=seed,
|
||||
)
|
||||
return self
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
test_annotations: Sequence[data.ClipAnnotation],
|
||||
@ -410,6 +451,7 @@ class BatDetect2API:
|
||||
cls,
|
||||
path: data.PathLike,
|
||||
config: BatDetect2Config | None = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
) -> "BatDetect2API":
|
||||
from batdetect2.audio import AudioConfig
|
||||
|
||||
@ -423,7 +465,12 @@ class BatDetect2API:
|
||||
)
|
||||
config = merge_configs(base, config) if config else base
|
||||
|
||||
targets = build_targets(config=config.model.targets)
|
||||
if targets is None:
|
||||
targets = build_targets(config=config.model.targets)
|
||||
else:
|
||||
target_config = getattr(targets, "config", None)
|
||||
if target_config is not None:
|
||||
config.model.targets = target_config
|
||||
|
||||
audio_loader = build_audio_loader(config=config.audio)
|
||||
|
||||
@ -452,6 +499,16 @@ class BatDetect2API:
|
||||
transform=output_transform,
|
||||
)
|
||||
|
||||
targets_changed = targets is not None or (
|
||||
config.model.targets.model_dump(mode="json")
|
||||
!= model_config.targets.model_dump(mode="json")
|
||||
)
|
||||
if targets_changed:
|
||||
model = build_model_with_new_targets(
|
||||
model=model,
|
||||
targets=targets,
|
||||
)
|
||||
|
||||
model.preprocessor = preprocessor
|
||||
model.postprocessor = postprocessor
|
||||
model.targets = targets
|
||||
@ -467,3 +524,25 @@ class BatDetect2API:
|
||||
formatter=formatter,
|
||||
output_transform=output_transform,
|
||||
)
|
||||
|
||||
def _set_trainable_parameters(
|
||||
self,
|
||||
trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
|
||||
) -> None:
|
||||
detector = cast(Detector, self.model.detector)
|
||||
|
||||
for parameter in detector.parameters():
|
||||
parameter.requires_grad = False
|
||||
|
||||
if trainable == "all":
|
||||
for parameter in detector.parameters():
|
||||
parameter.requires_grad = True
|
||||
return
|
||||
|
||||
if trainable in {"heads", "classifier_head"}:
|
||||
for parameter in detector.classifier_head.parameters():
|
||||
parameter.requires_grad = True
|
||||
|
||||
if trainable in {"heads", "bbox_head"}:
|
||||
for parameter in detector.bbox_head.parameters():
|
||||
parameter.requires_grad = True
|
||||
|
||||
@ -100,6 +100,7 @@ __all__ = [
|
||||
"Model",
|
||||
"ModelConfig",
|
||||
"build_model",
|
||||
"build_model_with_new_targets",
|
||||
]
|
||||
|
||||
|
||||
@ -274,3 +275,21 @@ def build_model(
|
||||
preprocessor=preprocessor,
|
||||
targets=targets,
|
||||
)
|
||||
|
||||
|
||||
def build_model_with_new_targets(
|
||||
model: Model,
|
||||
targets: TargetProtocol,
|
||||
) -> Model:
|
||||
"""Build a new model with a different target set."""
|
||||
detector = build_detector(
|
||||
num_classes=len(targets.class_names),
|
||||
backbone=model.detector.backbone,
|
||||
)
|
||||
|
||||
return Model(
|
||||
detector=detector,
|
||||
postprocessor=model.postprocessor,
|
||||
preprocessor=model.preprocessor,
|
||||
targets=targets,
|
||||
)
|
||||
|
||||
@ -135,7 +135,9 @@ class Detector(DetectionModel):
|
||||
|
||||
|
||||
def build_detector(
|
||||
num_classes: int, config: BackboneConfig | None = None
|
||||
num_classes: int,
|
||||
config: BackboneConfig | None = None,
|
||||
backbone: BackboneModel | None = None,
|
||||
) -> DetectionModel:
|
||||
"""Build a complete BatDetect2 detection model.
|
||||
|
||||
@ -165,13 +167,14 @@ def build_detector(
|
||||
If ``num_classes`` is not positive, or if the backbone
|
||||
configuration is invalid.
|
||||
"""
|
||||
config = config or UNetBackboneConfig()
|
||||
if backbone is None:
|
||||
config = config or UNetBackboneConfig()
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building model with config: \n{}",
|
||||
lambda: config.to_yaml_string(), # type: ignore
|
||||
)
|
||||
backbone = build_backbone(config=config)
|
||||
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building model with config: \n{}",
|
||||
lambda: config.to_yaml_string(), # type: ignore
|
||||
)
|
||||
backbone = build_backbone(config=config)
|
||||
classifier_head = ClassifierHead(
|
||||
num_classes=num_classes,
|
||||
in_channels=backbone.out_channels,
|
||||
|
||||
@ -82,5 +82,9 @@ class EncoderDecoderModel(BackboneModel):
|
||||
|
||||
|
||||
class DetectionModel(ABC, torch.nn.Module):
|
||||
backbone: BackboneModel
|
||||
classifier_head: torch.nn.Module
|
||||
bbox_head: torch.nn.Module
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput: ...
|
||||
|
||||
@ -68,8 +68,17 @@ class TrainingModule(L.LightningModule):
|
||||
return outputs
|
||||
|
||||
def configure_optimizers(self):
|
||||
trainable_parameters = [
|
||||
parameter
|
||||
for parameter in self.parameters()
|
||||
if parameter.requires_grad
|
||||
]
|
||||
|
||||
if not trainable_parameters:
|
||||
raise ValueError("No trainable parameters available.")
|
||||
|
||||
optimizer = build_optimizer(
|
||||
self.parameters(),
|
||||
trainable_parameters,
|
||||
config=self.train_config.optimizer,
|
||||
)
|
||||
scheduler = build_scheduler(
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import lightning as L
|
||||
import numpy as np
|
||||
@ -8,6 +9,9 @@ from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.models.detectors import Detector
|
||||
from batdetect2.models.heads import ClassifierHead
|
||||
from batdetect2.train import load_model_from_checkpoint
|
||||
from batdetect2.train.lightning import build_training_module
|
||||
|
||||
|
||||
@ -213,6 +217,77 @@ def test_user_can_load_checkpoint_and_finetune(
|
||||
assert checkpoints
|
||||
|
||||
|
||||
def test_user_can_load_checkpoint_with_new_targets(
|
||||
tmp_path: Path,
|
||||
sample_targets,
|
||||
) -> None:
|
||||
"""User story: start from checkpoint with a new target definition."""
|
||||
|
||||
module = build_training_module(model_config=BatDetect2Config().model)
|
||||
trainer = L.Trainer(enable_checkpointing=False, logger=False)
|
||||
checkpoint_path = tmp_path / "base_transfer.ckpt"
|
||||
trainer.strategy.connect(module)
|
||||
trainer.save_checkpoint(checkpoint_path)
|
||||
|
||||
source_model, _ = load_model_from_checkpoint(checkpoint_path)
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
checkpoint_path,
|
||||
targets=sample_targets,
|
||||
)
|
||||
source_detector = cast(Detector, source_model.detector)
|
||||
detector = cast(Detector, api.model.detector)
|
||||
classifier_head = cast(ClassifierHead, detector.classifier_head)
|
||||
|
||||
assert api.targets is sample_targets
|
||||
assert detector.num_classes == len(sample_targets.class_names)
|
||||
assert (
|
||||
classifier_head.classifier.out_channels
|
||||
== len(sample_targets.class_names) + 1
|
||||
)
|
||||
|
||||
source_backbone = source_detector.backbone.state_dict()
|
||||
target_backbone = detector.backbone.state_dict()
|
||||
assert source_backbone
|
||||
for key, value in source_backbone.items():
|
||||
assert key in target_backbone
|
||||
torch.testing.assert_close(target_backbone[key], value)
|
||||
|
||||
|
||||
def test_user_can_finetune_only_heads(
|
||||
tmp_path: Path,
|
||||
example_annotations,
|
||||
) -> None:
|
||||
"""User story: fine-tune only prediction heads."""
|
||||
|
||||
api = BatDetect2API.from_config(BatDetect2Config())
|
||||
finetune_dir = tmp_path / "heads_only"
|
||||
|
||||
api.finetune(
|
||||
train_annotations=example_annotations[:1],
|
||||
val_annotations=example_annotations[:1],
|
||||
trainable="heads",
|
||||
train_workers=0,
|
||||
val_workers=0,
|
||||
checkpoint_dir=finetune_dir,
|
||||
log_dir=tmp_path / "logs",
|
||||
num_epochs=1,
|
||||
seed=0,
|
||||
)
|
||||
detector = cast(Detector, api.model.detector)
|
||||
|
||||
backbone_params = list(detector.backbone.parameters())
|
||||
classifier_params = list(detector.classifier_head.parameters())
|
||||
bbox_params = list(detector.bbox_head.parameters())
|
||||
|
||||
assert backbone_params
|
||||
assert classifier_params
|
||||
assert bbox_params
|
||||
assert all(not parameter.requires_grad for parameter in backbone_params)
|
||||
assert all(parameter.requires_grad for parameter in classifier_params)
|
||||
assert all(parameter.requires_grad for parameter in bbox_params)
|
||||
assert list(finetune_dir.rglob("*.ckpt"))
|
||||
|
||||
|
||||
def test_user_can_evaluate_small_dataset_and_get_metrics(
|
||||
api_v2: BatDetect2API,
|
||||
example_annotations,
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@ -57,10 +56,10 @@ def dummy_targets() -> TargetProtocol:
|
||||
|
||||
def encode_class(
|
||||
self, sound_event: data.SoundEventAnnotation
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
return "bat"
|
||||
|
||||
def decode_class(self, class_label: str) -> List[data.Tag]:
|
||||
def decode_class(self, class_label: str) -> list[data.Tag]:
|
||||
return tag_map.get(class_label.lower(), [])
|
||||
|
||||
def encode_roi(self, sound_event: data.SoundEventAnnotation):
|
||||
@ -70,7 +69,7 @@ def dummy_targets() -> TargetProtocol:
|
||||
self,
|
||||
position,
|
||||
size: np.ndarray,
|
||||
class_name: Optional[str] = None,
|
||||
class_name: str | None = None,
|
||||
):
|
||||
time, freq = position
|
||||
width, height = size
|
||||
@ -210,7 +209,7 @@ def empty_detection_dataset() -> xr.Dataset:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_raw_predictions() -> List[Detection]:
|
||||
def sample_raw_predictions() -> list[Detection]:
|
||||
"""Manually crafted RawPrediction objects using the actual type."""
|
||||
|
||||
pred1_classes = xr.DataArray(
|
||||
@ -282,7 +281,7 @@ def sample_raw_predictions() -> List[Detection]:
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_basic(
|
||||
sample_raw_predictions: List[Detection],
|
||||
sample_raw_predictions: list[Detection],
|
||||
sample_recording: data.Recording,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
@ -325,7 +324,7 @@ def test_convert_raw_to_sound_event_basic(
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_thresholding(
|
||||
sample_raw_predictions: List[Detection],
|
||||
sample_raw_predictions: list[Detection],
|
||||
sample_recording: data.Recording,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
@ -353,7 +352,7 @@ def test_convert_raw_to_sound_event_thresholding(
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_no_threshold(
|
||||
sample_raw_predictions: List[Detection],
|
||||
sample_raw_predictions: list[Detection],
|
||||
sample_recording: data.Recording,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
@ -381,7 +380,7 @@ def test_convert_raw_to_sound_event_no_threshold(
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_top_class(
|
||||
sample_raw_predictions: List[Detection],
|
||||
sample_raw_predictions: list[Detection],
|
||||
sample_recording: data.Recording,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
@ -408,7 +407,7 @@ def test_convert_raw_to_sound_event_top_class(
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_all_below_threshold(
|
||||
sample_raw_predictions: List[Detection],
|
||||
sample_raw_predictions: list[Detection],
|
||||
sample_recording: data.Recording,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
@ -434,7 +433,7 @@ def test_convert_raw_to_sound_event_all_below_threshold(
|
||||
|
||||
|
||||
def test_convert_raw_list_to_clip_basic(
|
||||
sample_raw_predictions: List[Detection],
|
||||
sample_raw_predictions: list[Detection],
|
||||
sample_clip: data.Clip,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
@ -486,7 +485,7 @@ def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets):
|
||||
|
||||
|
||||
def test_convert_raw_list_to_clip_passes_args(
|
||||
sample_raw_predictions: List[Detection],
|
||||
sample_raw_predictions: list[Detection],
|
||||
sample_clip: data.Clip,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user