Clean up the detector head

This commit is contained in:
mbsantiago 2025-08-18 16:07:51 +01:00
parent 98f83e8b34
commit d4f249366e
3 changed files with 11 additions and 14 deletions

View File

@ -88,7 +88,9 @@ dev = [
"python-lsp-server>=1.13.0", "python-lsp-server>=1.13.0",
] ]
dvclive = ["dvclive>=3.48.2"] dvclive = ["dvclive>=3.48.2"]
mlflow = ["mlflow>=3.1.1"] mlflow = [
"mlflow>=3.1.1",
]
[tool.ruff] [tool.ruff]
line-length = 79 line-length = 79

View File

@ -18,7 +18,7 @@ the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
import torch import torch
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
@ -39,8 +39,6 @@ class Detector(DetectionModel):
the `classifier_head`). the `classifier_head`).
classifier_head : ClassifierHead classifier_head : ClassifierHead
The prediction head responsible for generating class probabilities. The prediction head responsible for generating class probabilities.
detector_head : DetectorHead
The prediction head responsible for generating detection probabilities.
bbox_head : BBoxHead bbox_head : BBoxHead
The prediction head responsible for generating bounding box size The prediction head responsible for generating bounding box size
predictions. predictions.
@ -52,7 +50,6 @@ class Detector(DetectionModel):
self, self,
backbone: BackboneModel, backbone: BackboneModel,
classifier_head: ClassifierHead, classifier_head: ClassifierHead,
detector_head: DetectorHead,
bbox_head: BBoxHead, bbox_head: BBoxHead,
): ):
"""Initialize the Detector model. """Initialize the Detector model.
@ -68,8 +65,6 @@ class Detector(DetectionModel):
classifier_head : ClassifierHead classifier_head : ClassifierHead
An initialized classification head module. The number of classes An initialized classification head module. The number of classes
is inferred from this head. is inferred from this head.
detector_head : DetectorHead
An initialized detection head module.
bbox_head : BBoxHead bbox_head : BBoxHead
An initialized bounding box size prediction head module. An initialized bounding box size prediction head module.
@ -83,7 +78,6 @@ class Detector(DetectionModel):
self.backbone = backbone self.backbone = backbone
self.num_classes = classifier_head.num_classes self.num_classes = classifier_head.num_classes
self.classifier_head = classifier_head self.classifier_head = classifier_head
self.detector_head = detector_head
self.bbox_head = bbox_head self.bbox_head = bbox_head
def forward(self, spec: torch.Tensor) -> ModelOutput: def forward(self, spec: torch.Tensor) -> ModelOutput:
@ -115,7 +109,7 @@ class Detector(DetectionModel):
""" """
features = self.backbone(spec) features = self.backbone(spec)
classification = self.classifier_head(features) classification = self.classifier_head(features)
detection = classification.sum(dim=1, keep_dim=True) detection = classification.sum(dim=1, keepdim=True)
size_preds = self.bbox_head(features) size_preds = self.bbox_head(features)
return ModelOutput( return ModelOutput(
detection_probs=detection, detection_probs=detection,
@ -159,15 +153,11 @@ def build_detector(num_classes: int, backbone: BackboneModel) -> Detector:
num_classes=num_classes, num_classes=num_classes,
in_channels=backbone.out_channels, in_channels=backbone.out_channels,
) )
detector_head = DetectorHead(
in_channels=backbone.out_channels,
)
bbox_head = BBoxHead( bbox_head = BBoxHead(
in_channels=backbone.out_channels, in_channels=backbone.out_channels,
) )
return Detector( return Detector(
backbone=backbone, backbone=backbone,
classifier_head=classifier_head, classifier_head=classifier_head,
detector_head=detector_head,
bbox_head=bbox_head, bbox_head=bbox_head,
) )

View File

@ -232,7 +232,12 @@ def _get_image_plotter(logger: Logger):
def plot_figure(name, figure, step): def plot_figure(name, figure, step):
image = _convert_figure_to_image(figure) image = _convert_figure_to_image(figure)
return logger.experiment.log_image(image, key=name, step=step) return logger.experiment.log_image(
run_id=logger.run_id,
image=image,
key=name,
step=step,
)
return plot_figure return plot_figure