diff --git a/pyproject.toml b/pyproject.toml index c147cf4..d6993fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,9 @@ dev = [ "python-lsp-server>=1.13.0", ] dvclive = ["dvclive>=3.48.2"] -mlflow = ["mlflow>=3.1.1"] +mlflow = [ + "mlflow>=3.1.1", +] [tool.ruff] line-length = 79 diff --git a/src/batdetect2/models/detectors.py b/src/batdetect2/models/detectors.py index 854f71e..1b94f0a 100644 --- a/src/batdetect2/models/detectors.py +++ b/src/batdetect2/models/detectors.py @@ -18,7 +18,7 @@ the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively. import torch -from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead +from batdetect2.models.heads import BBoxHead, ClassifierHead from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput @@ -39,8 +39,6 @@ class Detector(DetectionModel): the `classifier_head`). classifier_head : ClassifierHead The prediction head responsible for generating class probabilities. - detector_head : DetectorHead - The prediction head responsible for generating detection probabilities. bbox_head : BBoxHead The prediction head responsible for generating bounding box size predictions. @@ -52,7 +50,6 @@ class Detector(DetectionModel): self, backbone: BackboneModel, classifier_head: ClassifierHead, - detector_head: DetectorHead, bbox_head: BBoxHead, ): """Initialize the Detector model. @@ -68,8 +65,6 @@ class Detector(DetectionModel): classifier_head : ClassifierHead An initialized classification head module. The number of classes is inferred from this head. - detector_head : DetectorHead - An initialized detection head module. bbox_head : BBoxHead An initialized bounding box size prediction head module. @@ -83,7 +78,6 @@ class Detector(DetectionModel): self.backbone = backbone self.num_classes = classifier_head.num_classes self.classifier_head = classifier_head - self.detector_head = detector_head self.bbox_head = bbox_head def forward(self, spec: torch.Tensor) -> ModelOutput: @@ -115,7 +109,7 @@ class Detector(DetectionModel): """ features = self.backbone(spec) classification = self.classifier_head(features) - detection = classification.sum(dim=1, keep_dim=True) + detection = classification.sum(dim=1, keepdim=True) size_preds = self.bbox_head(features) return ModelOutput( detection_probs=detection, @@ -159,15 +153,11 @@ def build_detector(num_classes: int, backbone: BackboneModel) -> Detector: num_classes=num_classes, in_channels=backbone.out_channels, ) - detector_head = DetectorHead( - in_channels=backbone.out_channels, - ) bbox_head = BBoxHead( in_channels=backbone.out_channels, ) return Detector( backbone=backbone, classifier_head=classifier_head, - detector_head=detector_head, bbox_head=bbox_head, ) diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index 99d4967..554152d 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -232,7 +232,12 @@ def _get_image_plotter(logger: Logger): def plot_figure(name, figure, step): image = _convert_figure_to_image(figure) - return logger.experiment.log_image(image, key=name, step=step) + return logger.experiment.log_image( + run_id=logger.run_id, + image=image, + key=name, + step=step, + ) return plot_figure