mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Changed the model API so that they always return the features
This commit is contained in:
parent
0eecf54a94
commit
f0b0f28379
@ -42,10 +42,7 @@ class ModelOutput(NamedTuple):
|
||||
pred_class_un_norm: torch.Tensor
|
||||
"""Tensor with predicted class probabilities before softmax."""
|
||||
|
||||
pred_emb: Optional[torch.Tensor]
|
||||
"""Tensor with embeddings."""
|
||||
|
||||
features: Optional[torch.Tensor]
|
||||
features: torch.Tensor
|
||||
"""Tensor with intermediate features."""
|
||||
|
||||
|
||||
@ -207,8 +204,7 @@ class Net2DFast(nn.Module):
|
||||
pred_size=F.relu(self.conv_size_op(x), inplace=True),
|
||||
pred_class=comb,
|
||||
pred_class_un_norm=cls,
|
||||
pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
|
||||
features=x if return_feats else None,
|
||||
features=x,
|
||||
)
|
||||
|
||||
|
||||
@ -315,8 +311,7 @@ class Net2DFastNoAttn(nn.Module):
|
||||
pred_size=F.relu(self.conv_size_op(x), inplace=True),
|
||||
pred_class=comb,
|
||||
pred_class_un_norm=cls,
|
||||
pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
|
||||
features=x if return_feats else None,
|
||||
features=x,
|
||||
)
|
||||
|
||||
|
||||
@ -428,5 +423,5 @@ class Net2DFastNoCoordConv(nn.Module):
|
||||
pred_class=comb,
|
||||
pred_class_un_norm=cls,
|
||||
pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
|
||||
features=x if return_feats else None,
|
||||
features=x,
|
||||
)
|
||||
|
@ -121,7 +121,7 @@ def run_nms(
|
||||
the features. Each element of the lists corresponds to one
|
||||
element of the batch.
|
||||
"""
|
||||
pred_det, pred_size, pred_class, _, _, features = outputs
|
||||
pred_det, pred_size, pred_class, _, features = outputs
|
||||
|
||||
pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"])
|
||||
freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[
|
||||
|
@ -696,7 +696,7 @@ def _process_spectrogram(
|
||||
) -> Tuple[List[Annotation], List[np.ndarray]]:
|
||||
# evaluate model
|
||||
with torch.no_grad():
|
||||
outputs = model(spec, return_feats=config["cnn_features"])
|
||||
outputs = model(spec)
|
||||
|
||||
# run non-max suppression
|
||||
pred_nms_list, features = pp.run_nms(
|
||||
|
@ -174,8 +174,7 @@ def test_process_spectrogram_with_model():
|
||||
|
||||
assert features is not None
|
||||
assert isinstance(features, list)
|
||||
# By default will not return cnn features
|
||||
assert len(features) == 0
|
||||
assert len(features) == 1
|
||||
|
||||
|
||||
def test_process_audio_with_model():
|
||||
@ -205,8 +204,7 @@ def test_process_audio_with_model():
|
||||
|
||||
assert features is not None
|
||||
assert isinstance(features, list)
|
||||
# By default will not return cnn features
|
||||
assert len(features) == 0
|
||||
assert len(features) == 1
|
||||
|
||||
assert spec is not None
|
||||
assert isinstance(spec, torch.Tensor)
|
||||
|
Loading…
Reference in New Issue
Block a user