mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Changed the model API so that they always return the features
This commit is contained in:
parent
0eecf54a94
commit
f0b0f28379
@ -42,10 +42,7 @@ class ModelOutput(NamedTuple):
|
|||||||
pred_class_un_norm: torch.Tensor
|
pred_class_un_norm: torch.Tensor
|
||||||
"""Tensor with predicted class probabilities before softmax."""
|
"""Tensor with predicted class probabilities before softmax."""
|
||||||
|
|
||||||
pred_emb: Optional[torch.Tensor]
|
features: torch.Tensor
|
||||||
"""Tensor with embeddings."""
|
|
||||||
|
|
||||||
features: Optional[torch.Tensor]
|
|
||||||
"""Tensor with intermediate features."""
|
"""Tensor with intermediate features."""
|
||||||
|
|
||||||
|
|
||||||
@ -207,8 +204,7 @@ class Net2DFast(nn.Module):
|
|||||||
pred_size=F.relu(self.conv_size_op(x), inplace=True),
|
pred_size=F.relu(self.conv_size_op(x), inplace=True),
|
||||||
pred_class=comb,
|
pred_class=comb,
|
||||||
pred_class_un_norm=cls,
|
pred_class_un_norm=cls,
|
||||||
pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
|
features=x,
|
||||||
features=x if return_feats else None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -315,8 +311,7 @@ class Net2DFastNoAttn(nn.Module):
|
|||||||
pred_size=F.relu(self.conv_size_op(x), inplace=True),
|
pred_size=F.relu(self.conv_size_op(x), inplace=True),
|
||||||
pred_class=comb,
|
pred_class=comb,
|
||||||
pred_class_un_norm=cls,
|
pred_class_un_norm=cls,
|
||||||
pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
|
features=x,
|
||||||
features=x if return_feats else None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -428,5 +423,5 @@ class Net2DFastNoCoordConv(nn.Module):
|
|||||||
pred_class=comb,
|
pred_class=comb,
|
||||||
pred_class_un_norm=cls,
|
pred_class_un_norm=cls,
|
||||||
pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
|
pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
|
||||||
features=x if return_feats else None,
|
features=x,
|
||||||
)
|
)
|
||||||
|
@ -121,7 +121,7 @@ def run_nms(
|
|||||||
the features. Each element of the lists corresponds to one
|
the features. Each element of the lists corresponds to one
|
||||||
element of the batch.
|
element of the batch.
|
||||||
"""
|
"""
|
||||||
pred_det, pred_size, pred_class, _, _, features = outputs
|
pred_det, pred_size, pred_class, _, features = outputs
|
||||||
|
|
||||||
pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"])
|
pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"])
|
||||||
freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[
|
freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[
|
||||||
|
@ -696,7 +696,7 @@ def _process_spectrogram(
|
|||||||
) -> Tuple[List[Annotation], List[np.ndarray]]:
|
) -> Tuple[List[Annotation], List[np.ndarray]]:
|
||||||
# evaluate model
|
# evaluate model
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(spec, return_feats=config["cnn_features"])
|
outputs = model(spec)
|
||||||
|
|
||||||
# run non-max suppression
|
# run non-max suppression
|
||||||
pred_nms_list, features = pp.run_nms(
|
pred_nms_list, features = pp.run_nms(
|
||||||
|
@ -174,8 +174,7 @@ def test_process_spectrogram_with_model():
|
|||||||
|
|
||||||
assert features is not None
|
assert features is not None
|
||||||
assert isinstance(features, list)
|
assert isinstance(features, list)
|
||||||
# By default will not return cnn features
|
assert len(features) == 1
|
||||||
assert len(features) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_audio_with_model():
|
def test_process_audio_with_model():
|
||||||
@ -205,8 +204,7 @@ def test_process_audio_with_model():
|
|||||||
|
|
||||||
assert features is not None
|
assert features is not None
|
||||||
assert isinstance(features, list)
|
assert isinstance(features, list)
|
||||||
# By default will not return cnn features
|
assert len(features) == 1
|
||||||
assert len(features) == 0
|
|
||||||
|
|
||||||
assert spec is not None
|
assert spec is not None
|
||||||
assert isinstance(spec, torch.Tensor)
|
assert isinstance(spec, torch.Tensor)
|
||||||
|
Loading…
Reference in New Issue
Block a user