From f0b0f2837936983ae46573830aec43e4748231b0 Mon Sep 17 00:00:00 2001 From: Santiago Martinez Date: Sat, 25 Feb 2023 20:10:44 +0000 Subject: [PATCH] Changed the model API so that they always return the features --- bat_detect/detector/models.py | 13 ++++--------- bat_detect/detector/post_process.py | 2 +- bat_detect/utils/detector_utils.py | 2 +- tests/test_bat_detect.py | 6 ++---- 4 files changed, 8 insertions(+), 15 deletions(-) diff --git a/bat_detect/detector/models.py b/bat_detect/detector/models.py index b9a18dd..47ec728 100644 --- a/bat_detect/detector/models.py +++ b/bat_detect/detector/models.py @@ -42,10 +42,7 @@ class ModelOutput(NamedTuple): pred_class_un_norm: torch.Tensor """Tensor with predicted class probabilities before softmax.""" - pred_emb: Optional[torch.Tensor] - """Tensor with embeddings.""" - - features: Optional[torch.Tensor] + features: torch.Tensor """Tensor with intermediate features.""" @@ -207,8 +204,7 @@ class Net2DFast(nn.Module): pred_size=F.relu(self.conv_size_op(x), inplace=True), pred_class=comb, pred_class_un_norm=cls, - pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None, - features=x if return_feats else None, + features=x, ) @@ -315,8 +311,7 @@ class Net2DFastNoAttn(nn.Module): pred_size=F.relu(self.conv_size_op(x), inplace=True), pred_class=comb, pred_class_un_norm=cls, - pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None, - features=x if return_feats else None, + features=x, ) @@ -428,5 +423,5 @@ class Net2DFastNoCoordConv(nn.Module): pred_class=comb, pred_class_un_norm=cls, pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None, - features=x if return_feats else None, + features=x, ) diff --git a/bat_detect/detector/post_process.py b/bat_detect/detector/post_process.py index eb8cfbc..763aeca 100644 --- a/bat_detect/detector/post_process.py +++ b/bat_detect/detector/post_process.py @@ -121,7 +121,7 @@ def run_nms( the features. Each element of the lists corresponds to one element of the batch. """ - pred_det, pred_size, pred_class, _, _, features = outputs + pred_det, pred_size, pred_class, _, features = outputs pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"]) freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[ diff --git a/bat_detect/utils/detector_utils.py b/bat_detect/utils/detector_utils.py index 9e34ea2..f1a8c00 100644 --- a/bat_detect/utils/detector_utils.py +++ b/bat_detect/utils/detector_utils.py @@ -696,7 +696,7 @@ def _process_spectrogram( ) -> Tuple[List[Annotation], List[np.ndarray]]: # evaluate model with torch.no_grad(): - outputs = model(spec, return_feats=config["cnn_features"]) + outputs = model(spec) # run non-max suppression pred_nms_list, features = pp.run_nms( diff --git a/tests/test_bat_detect.py b/tests/test_bat_detect.py index 6440261..902c2df 100644 --- a/tests/test_bat_detect.py +++ b/tests/test_bat_detect.py @@ -174,8 +174,7 @@ def test_process_spectrogram_with_model(): assert features is not None assert isinstance(features, list) - # By default will not return cnn features - assert len(features) == 0 + assert len(features) == 1 def test_process_audio_with_model(): @@ -205,8 +204,7 @@ def test_process_audio_with_model(): assert features is not None assert isinstance(features, list) - # By default will not return cnn features - assert len(features) == 0 + assert len(features) == 1 assert spec is not None assert isinstance(spec, torch.Tensor)