mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Clean up the detector head
This commit is contained in:
parent
98f83e8b34
commit
d4f249366e
@ -88,7 +88,9 @@ dev = [
|
|||||||
"python-lsp-server>=1.13.0",
|
"python-lsp-server>=1.13.0",
|
||||||
]
|
]
|
||||||
dvclive = ["dvclive>=3.48.2"]
|
dvclive = ["dvclive>=3.48.2"]
|
||||||
mlflow = ["mlflow>=3.1.1"]
|
mlflow = [
|
||||||
|
"mlflow>=3.1.1",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 79
|
line-length = 79
|
||||||
|
|||||||
@ -18,7 +18,7 @@ the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||||
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
|
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
|
||||||
|
|
||||||
|
|
||||||
@ -39,8 +39,6 @@ class Detector(DetectionModel):
|
|||||||
the `classifier_head`).
|
the `classifier_head`).
|
||||||
classifier_head : ClassifierHead
|
classifier_head : ClassifierHead
|
||||||
The prediction head responsible for generating class probabilities.
|
The prediction head responsible for generating class probabilities.
|
||||||
detector_head : DetectorHead
|
|
||||||
The prediction head responsible for generating detection probabilities.
|
|
||||||
bbox_head : BBoxHead
|
bbox_head : BBoxHead
|
||||||
The prediction head responsible for generating bounding box size
|
The prediction head responsible for generating bounding box size
|
||||||
predictions.
|
predictions.
|
||||||
@ -52,7 +50,6 @@ class Detector(DetectionModel):
|
|||||||
self,
|
self,
|
||||||
backbone: BackboneModel,
|
backbone: BackboneModel,
|
||||||
classifier_head: ClassifierHead,
|
classifier_head: ClassifierHead,
|
||||||
detector_head: DetectorHead,
|
|
||||||
bbox_head: BBoxHead,
|
bbox_head: BBoxHead,
|
||||||
):
|
):
|
||||||
"""Initialize the Detector model.
|
"""Initialize the Detector model.
|
||||||
@ -68,8 +65,6 @@ class Detector(DetectionModel):
|
|||||||
classifier_head : ClassifierHead
|
classifier_head : ClassifierHead
|
||||||
An initialized classification head module. The number of classes
|
An initialized classification head module. The number of classes
|
||||||
is inferred from this head.
|
is inferred from this head.
|
||||||
detector_head : DetectorHead
|
|
||||||
An initialized detection head module.
|
|
||||||
bbox_head : BBoxHead
|
bbox_head : BBoxHead
|
||||||
An initialized bounding box size prediction head module.
|
An initialized bounding box size prediction head module.
|
||||||
|
|
||||||
@ -83,7 +78,6 @@ class Detector(DetectionModel):
|
|||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.num_classes = classifier_head.num_classes
|
self.num_classes = classifier_head.num_classes
|
||||||
self.classifier_head = classifier_head
|
self.classifier_head = classifier_head
|
||||||
self.detector_head = detector_head
|
|
||||||
self.bbox_head = bbox_head
|
self.bbox_head = bbox_head
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||||
@ -115,7 +109,7 @@ class Detector(DetectionModel):
|
|||||||
"""
|
"""
|
||||||
features = self.backbone(spec)
|
features = self.backbone(spec)
|
||||||
classification = self.classifier_head(features)
|
classification = self.classifier_head(features)
|
||||||
detection = classification.sum(dim=1, keep_dim=True)
|
detection = classification.sum(dim=1, keepdim=True)
|
||||||
size_preds = self.bbox_head(features)
|
size_preds = self.bbox_head(features)
|
||||||
return ModelOutput(
|
return ModelOutput(
|
||||||
detection_probs=detection,
|
detection_probs=detection,
|
||||||
@ -159,15 +153,11 @@ def build_detector(num_classes: int, backbone: BackboneModel) -> Detector:
|
|||||||
num_classes=num_classes,
|
num_classes=num_classes,
|
||||||
in_channels=backbone.out_channels,
|
in_channels=backbone.out_channels,
|
||||||
)
|
)
|
||||||
detector_head = DetectorHead(
|
|
||||||
in_channels=backbone.out_channels,
|
|
||||||
)
|
|
||||||
bbox_head = BBoxHead(
|
bbox_head = BBoxHead(
|
||||||
in_channels=backbone.out_channels,
|
in_channels=backbone.out_channels,
|
||||||
)
|
)
|
||||||
return Detector(
|
return Detector(
|
||||||
backbone=backbone,
|
backbone=backbone,
|
||||||
classifier_head=classifier_head,
|
classifier_head=classifier_head,
|
||||||
detector_head=detector_head,
|
|
||||||
bbox_head=bbox_head,
|
bbox_head=bbox_head,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -232,7 +232,12 @@ def _get_image_plotter(logger: Logger):
|
|||||||
|
|
||||||
def plot_figure(name, figure, step):
|
def plot_figure(name, figure, step):
|
||||||
image = _convert_figure_to_image(figure)
|
image = _convert_figure_to_image(figure)
|
||||||
return logger.experiment.log_image(image, key=name, step=step)
|
return logger.experiment.log_image(
|
||||||
|
run_id=logger.run_id,
|
||||||
|
image=image,
|
||||||
|
key=name,
|
||||||
|
step=step,
|
||||||
|
)
|
||||||
|
|
||||||
return plot_figure
|
return plot_figure
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user