Changed the model API so that they always return the features

This commit is contained in:
Santiago Martinez 2023-02-25 20:10:44 +00:00
parent 0eecf54a94
commit f0b0f28379
4 changed files with 8 additions and 15 deletions

View File

@ -42,10 +42,7 @@ class ModelOutput(NamedTuple):
pred_class_un_norm: torch.Tensor
"""Tensor with predicted class probabilities before softmax."""
pred_emb: Optional[torch.Tensor]
"""Tensor with embeddings."""
features: Optional[torch.Tensor]
features: torch.Tensor
"""Tensor with intermediate features."""
@ -207,8 +204,7 @@ class Net2DFast(nn.Module):
pred_size=F.relu(self.conv_size_op(x), inplace=True),
pred_class=comb,
pred_class_un_norm=cls,
pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
features=x if return_feats else None,
features=x,
)
@ -315,8 +311,7 @@ class Net2DFastNoAttn(nn.Module):
pred_size=F.relu(self.conv_size_op(x), inplace=True),
pred_class=comb,
pred_class_un_norm=cls,
pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
features=x if return_feats else None,
features=x,
)
@ -428,5 +423,5 @@ class Net2DFastNoCoordConv(nn.Module):
pred_class=comb,
pred_class_un_norm=cls,
pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
features=x if return_feats else None,
features=x,
)

View File

@ -121,7 +121,7 @@ def run_nms(
the features. Each element of the lists corresponds to one
element of the batch.
"""
pred_det, pred_size, pred_class, _, _, features = outputs
pred_det, pred_size, pred_class, _, features = outputs
pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"])
freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[

View File

@ -696,7 +696,7 @@ def _process_spectrogram(
) -> Tuple[List[Annotation], List[np.ndarray]]:
# evaluate model
with torch.no_grad():
outputs = model(spec, return_feats=config["cnn_features"])
outputs = model(spec)
# run non-max suppression
pred_nms_list, features = pp.run_nms(

View File

@ -174,8 +174,7 @@ def test_process_spectrogram_with_model():
assert features is not None
assert isinstance(features, list)
# By default will not return cnn features
assert len(features) == 0
assert len(features) == 1
def test_process_audio_with_model():
@ -205,8 +204,7 @@ def test_process_audio_with_model():
assert features is not None
assert isinstance(features, list)
# By default will not return cnn features
assert len(features) == 0
assert len(features) == 1
assert spec is not None
assert isinstance(spec, torch.Tensor)