Changed the model API so that the models always return the features

This commit is contained in:
Santiago Martinez 2023-02-25 20:10:44 +00:00
parent 0eecf54a94
commit f0b0f28379
4 changed files with 8 additions and 15 deletions

View File

@ -42,10 +42,7 @@ class ModelOutput(NamedTuple):
pred_class_un_norm: torch.Tensor pred_class_un_norm: torch.Tensor
"""Tensor with predicted class probabilities before softmax.""" """Tensor with predicted class probabilities before softmax."""
pred_emb: Optional[torch.Tensor] features: torch.Tensor
"""Tensor with embeddings."""
features: Optional[torch.Tensor]
"""Tensor with intermediate features.""" """Tensor with intermediate features."""
@ -207,8 +204,7 @@ class Net2DFast(nn.Module):
pred_size=F.relu(self.conv_size_op(x), inplace=True), pred_size=F.relu(self.conv_size_op(x), inplace=True),
pred_class=comb, pred_class=comb,
pred_class_un_norm=cls, pred_class_un_norm=cls,
pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None, features=x,
features=x if return_feats else None,
) )
@ -315,8 +311,7 @@ class Net2DFastNoAttn(nn.Module):
pred_size=F.relu(self.conv_size_op(x), inplace=True), pred_size=F.relu(self.conv_size_op(x), inplace=True),
pred_class=comb, pred_class=comb,
pred_class_un_norm=cls, pred_class_un_norm=cls,
pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None, features=x,
features=x if return_feats else None,
) )
@ -428,5 +423,5 @@ class Net2DFastNoCoordConv(nn.Module):
pred_class=comb, pred_class=comb,
pred_class_un_norm=cls, pred_class_un_norm=cls,
pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None, pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
features=x if return_feats else None, features=x,
) )

View File

@ -121,7 +121,7 @@ def run_nms(
the features. Each element of the lists corresponds to one the features. Each element of the lists corresponds to one
element of the batch. element of the batch.
""" """
pred_det, pred_size, pred_class, _, _, features = outputs pred_det, pred_size, pred_class, _, features = outputs
pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"]) pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"])
freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[ freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[

View File

@ -696,7 +696,7 @@ def _process_spectrogram(
) -> Tuple[List[Annotation], List[np.ndarray]]: ) -> Tuple[List[Annotation], List[np.ndarray]]:
# evaluate model # evaluate model
with torch.no_grad(): with torch.no_grad():
outputs = model(spec, return_feats=config["cnn_features"]) outputs = model(spec)
# run non-max suppression # run non-max suppression
pred_nms_list, features = pp.run_nms( pred_nms_list, features = pp.run_nms(

View File

@ -174,8 +174,7 @@ def test_process_spectrogram_with_model():
assert features is not None assert features is not None
assert isinstance(features, list) assert isinstance(features, list)
# By default will not return cnn features assert len(features) == 1
assert len(features) == 0
def test_process_audio_with_model(): def test_process_audio_with_model():
@ -205,8 +204,7 @@ def test_process_audio_with_model():
assert features is not None assert features is not None
assert isinstance(features, list) assert isinstance(features, list)
# By default will not return cnn features assert len(features) == 1
assert len(features) == 0
assert spec is not None assert spec is not None
assert isinstance(spec, torch.Tensor) assert isinstance(spec, torch.Tensor)