mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Fix plotting after update
This commit is contained in:
parent
3043230d4f
commit
d25efdad10
@ -1,4 +1,5 @@
|
|||||||
from collections.abc import Callable, Iterable, Mapping
|
from collections.abc import Callable, Iterable, Mapping
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Literal, Optional, Tuple
|
from typing import List, Literal, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -340,3 +341,36 @@ def match_predictions_and_annotations(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClassExamples:
|
||||||
|
false_positives: List[MatchEvaluation] = field(default_factory=list)
|
||||||
|
false_negatives: List[MatchEvaluation] = field(default_factory=list)
|
||||||
|
true_positives: List[MatchEvaluation] = field(default_factory=list)
|
||||||
|
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def group_matches(matches: List[MatchEvaluation]) -> ClassExamples:
|
||||||
|
class_examples = ClassExamples()
|
||||||
|
|
||||||
|
for match in matches:
|
||||||
|
gt_class = match.gt_class
|
||||||
|
pred_class = match.pred_class
|
||||||
|
|
||||||
|
if pred_class is None:
|
||||||
|
class_examples.false_negatives.append(match)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if gt_class is None:
|
||||||
|
class_examples.false_positives.append(match)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if gt_class != pred_class:
|
||||||
|
class_examples.cross_triggers.append(match)
|
||||||
|
class_examples.cross_triggers.append(match)
|
||||||
|
continue
|
||||||
|
|
||||||
|
class_examples.true_positives.append(match)
|
||||||
|
|
||||||
|
return class_examples
|
||||||
|
|||||||
@ -37,6 +37,10 @@ def plot_clip(
|
|||||||
|
|
||||||
plot_spectrogram(
|
plot_spectrogram(
|
||||||
spec,
|
spec,
|
||||||
|
start_time=clip.start_time,
|
||||||
|
end_time=clip.end_time,
|
||||||
|
min_freq=preprocessor.min_freq,
|
||||||
|
max_freq=preprocessor.max_freq,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
cmap=spec_cmap,
|
cmap=spec_cmap,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from matplotlib import axes
|
from matplotlib import axes
|
||||||
|
|
||||||
@ -25,10 +26,20 @@ def create_ax(
|
|||||||
|
|
||||||
def plot_spectrogram(
|
def plot_spectrogram(
|
||||||
spec: torch.Tensor,
|
spec: torch.Tensor,
|
||||||
|
start_time: float,
|
||||||
|
end_time: float,
|
||||||
|
min_freq: float,
|
||||||
|
max_freq: float,
|
||||||
ax: Optional[axes.Axes] = None,
|
ax: Optional[axes.Axes] = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
cmap="gray",
|
cmap="gray",
|
||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
ax = create_ax(ax=ax, figsize=figsize)
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
ax.pcolormesh(spec.numpy(), cmap=cmap)
|
|
||||||
|
ax.pcolormesh(
|
||||||
|
np.linspace(start_time, end_time, spec.shape[-1], endpoint=False),
|
||||||
|
np.linspace(min_freq, max_freq, spec.shape[-2], endpoint=False),
|
||||||
|
spec.numpy(),
|
||||||
|
cmap=cmap,
|
||||||
|
)
|
||||||
return ax
|
return ax
|
||||||
|
|||||||
@ -100,7 +100,7 @@ def plot_class_examples(
|
|||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
)
|
)
|
||||||
except (ValueError, AssertionError):
|
except (ValueError, AssertionError, RuntimeError):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for index, match in enumerate(false_positives[:n_examples]):
|
for index, match in enumerate(false_positives[:n_examples]):
|
||||||
@ -112,7 +112,7 @@ def plot_class_examples(
|
|||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
)
|
)
|
||||||
except (ValueError, AssertionError):
|
except (ValueError, AssertionError, RuntimeError):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for index, match in enumerate(false_negatives[:n_examples]):
|
for index, match in enumerate(false_negatives[:n_examples]):
|
||||||
@ -124,7 +124,7 @@ def plot_class_examples(
|
|||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
)
|
)
|
||||||
except (ValueError, AssertionError):
|
except (ValueError, AssertionError, RuntimeError):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for index, match in enumerate(cross_triggers[:n_examples]):
|
for index, match in enumerate(cross_triggers[:n_examples]):
|
||||||
@ -136,7 +136,7 @@ def plot_class_examples(
|
|||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
)
|
)
|
||||||
except (ValueError, AssertionError):
|
except (ValueError, AssertionError, RuntimeError):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user