Clean up the detector head

This commit is contained in:
mbsantiago 2025-08-18 16:07:51 +01:00
parent 98f83e8b34
commit d4f249366e
3 changed files with 11 additions and 14 deletions

View File

@ -88,7 +88,9 @@ dev = [
"python-lsp-server>=1.13.0",
]
dvclive = ["dvclive>=3.48.2"]
mlflow = ["mlflow>=3.1.1"]
mlflow = [
"mlflow>=3.1.1",
]
[tool.ruff]
line-length = 79

View File

@ -18,7 +18,7 @@ the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
import torch
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
@ -39,8 +39,6 @@ class Detector(DetectionModel):
the `classifier_head`).
classifier_head : ClassifierHead
The prediction head responsible for generating class probabilities.
detector_head : DetectorHead
The prediction head responsible for generating detection probabilities.
bbox_head : BBoxHead
The prediction head responsible for generating bounding box size
predictions.
@ -52,7 +50,6 @@ class Detector(DetectionModel):
self,
backbone: BackboneModel,
classifier_head: ClassifierHead,
detector_head: DetectorHead,
bbox_head: BBoxHead,
):
"""Initialize the Detector model.
@ -68,8 +65,6 @@ class Detector(DetectionModel):
classifier_head : ClassifierHead
An initialized classification head module. The number of classes
is inferred from this head.
detector_head : DetectorHead
An initialized detection head module.
bbox_head : BBoxHead
An initialized bounding box size prediction head module.
@ -83,7 +78,6 @@ class Detector(DetectionModel):
self.backbone = backbone
self.num_classes = classifier_head.num_classes
self.classifier_head = classifier_head
self.detector_head = detector_head
self.bbox_head = bbox_head
def forward(self, spec: torch.Tensor) -> ModelOutput:
@ -115,7 +109,7 @@ class Detector(DetectionModel):
"""
features = self.backbone(spec)
classification = self.classifier_head(features)
detection = classification.sum(dim=1, keep_dim=True)
detection = classification.sum(dim=1, keepdim=True)
size_preds = self.bbox_head(features)
return ModelOutput(
detection_probs=detection,
@ -159,15 +153,11 @@ def build_detector(num_classes: int, backbone: BackboneModel) -> Detector:
num_classes=num_classes,
in_channels=backbone.out_channels,
)
detector_head = DetectorHead(
in_channels=backbone.out_channels,
)
bbox_head = BBoxHead(
in_channels=backbone.out_channels,
)
return Detector(
backbone=backbone,
classifier_head=classifier_head,
detector_head=detector_head,
bbox_head=bbox_head,
)

View File

@ -232,7 +232,12 @@ def _get_image_plotter(logger: Logger):
def plot_figure(name, figure, step):
image = _convert_figure_to_image(figure)
return logger.experiment.log_image(image, key=name, step=step)
return logger.experiment.log_image(
run_id=logger.run_id,
image=image,
key=name,
step=step,
)
return plot_figure