mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Clean up the detector head
This commit is contained in:
parent
98f83e8b34
commit
d4f249366e
@ -88,7 +88,9 @@ dev = [
|
||||
"python-lsp-server>=1.13.0",
|
||||
]
|
||||
dvclive = ["dvclive>=3.48.2"]
|
||||
mlflow = ["mlflow>=3.1.1"]
|
||||
mlflow = [
|
||||
"mlflow>=3.1.1",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 79
|
||||
|
||||
@ -18,7 +18,7 @@ the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
|
||||
|
||||
import torch
|
||||
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
|
||||
|
||||
|
||||
@ -39,8 +39,6 @@ class Detector(DetectionModel):
|
||||
the `classifier_head`).
|
||||
classifier_head : ClassifierHead
|
||||
The prediction head responsible for generating class probabilities.
|
||||
detector_head : DetectorHead
|
||||
The prediction head responsible for generating detection probabilities.
|
||||
bbox_head : BBoxHead
|
||||
The prediction head responsible for generating bounding box size
|
||||
predictions.
|
||||
@ -52,7 +50,6 @@ class Detector(DetectionModel):
|
||||
self,
|
||||
backbone: BackboneModel,
|
||||
classifier_head: ClassifierHead,
|
||||
detector_head: DetectorHead,
|
||||
bbox_head: BBoxHead,
|
||||
):
|
||||
"""Initialize the Detector model.
|
||||
@ -68,8 +65,6 @@ class Detector(DetectionModel):
|
||||
classifier_head : ClassifierHead
|
||||
An initialized classification head module. The number of classes
|
||||
is inferred from this head.
|
||||
detector_head : DetectorHead
|
||||
An initialized detection head module.
|
||||
bbox_head : BBoxHead
|
||||
An initialized bounding box size prediction head module.
|
||||
|
||||
@ -83,7 +78,6 @@ class Detector(DetectionModel):
|
||||
self.backbone = backbone
|
||||
self.num_classes = classifier_head.num_classes
|
||||
self.classifier_head = classifier_head
|
||||
self.detector_head = detector_head
|
||||
self.bbox_head = bbox_head
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||
@ -115,7 +109,7 @@ class Detector(DetectionModel):
|
||||
"""
|
||||
features = self.backbone(spec)
|
||||
classification = self.classifier_head(features)
|
||||
detection = classification.sum(dim=1, keep_dim=True)
|
||||
detection = classification.sum(dim=1, keepdim=True)
|
||||
size_preds = self.bbox_head(features)
|
||||
return ModelOutput(
|
||||
detection_probs=detection,
|
||||
@ -159,15 +153,11 @@ def build_detector(num_classes: int, backbone: BackboneModel) -> Detector:
|
||||
num_classes=num_classes,
|
||||
in_channels=backbone.out_channels,
|
||||
)
|
||||
detector_head = DetectorHead(
|
||||
in_channels=backbone.out_channels,
|
||||
)
|
||||
bbox_head = BBoxHead(
|
||||
in_channels=backbone.out_channels,
|
||||
)
|
||||
return Detector(
|
||||
backbone=backbone,
|
||||
classifier_head=classifier_head,
|
||||
detector_head=detector_head,
|
||||
bbox_head=bbox_head,
|
||||
)
|
||||
|
||||
@ -232,7 +232,12 @@ def _get_image_plotter(logger: Logger):
|
||||
|
||||
def plot_figure(name, figure, step):
|
||||
image = _convert_figure_to_image(figure)
|
||||
return logger.experiment.log_image(image, key=name, step=step)
|
||||
return logger.experiment.log_image(
|
||||
run_id=logger.run_id,
|
||||
image=image,
|
||||
key=name,
|
||||
step=step,
|
||||
)
|
||||
|
||||
return plot_figure
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user