Compare commits

..

No commits in common. "ff754a1269d689f113f05c5cdcb08b2cb8b6df87" and "d25efdad10f9927f0364bfcea629f32090270e3c" have entirely different histories.

6 changed files with 10 additions and 45 deletions

View File

@ -4,9 +4,7 @@ from matplotlib.axes import Axes
from soundevent import data, plot
from batdetect2.plotting.clips import plot_clip
from batdetect2.plotting.common import create_ax
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"plot_clip_annotation",
@ -45,31 +43,3 @@ def plot_clip_annotation(
facecolor="none" if not fill else None,
)
return ax
def plot_anchor_points(
clip_annotation: data.ClipAnnotation,
targets: TargetProtocol,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
size: int = 1,
color: str = "red",
marker: str = "x",
alpha: float = 1,
) -> Axes:
ax = create_ax(ax=ax, figsize=figsize)
positions = []
for sound_event in clip_annotation.sound_events:
if not targets.filter(sound_event):
continue
sound_event = targets.transform(sound_event)
position, _ = targets.encode_roi(sound_event)
positions.append(position)
X, Y = zip(*positions)
ax.scatter(X, Y, s=size, c=color, marker=marker, alpha=alpha)
return ax

View File

@ -1,6 +1,6 @@
"""General plotting utilities."""
from typing import Optional, Tuple, Union
from typing import Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
@ -25,7 +25,7 @@ def create_ax(
def plot_spectrogram(
spec: Union[torch.Tensor, np.ndarray],
spec: torch.Tensor,
start_time: float,
end_time: float,
min_freq: float,
@ -34,15 +34,12 @@ def plot_spectrogram(
figsize: Optional[Tuple[int, int]] = None,
cmap="gray",
) -> axes.Axes:
if isinstance(spec, torch.Tensor):
spec = spec.numpy()
ax = create_ax(ax=ax, figsize=figsize)
ax.pcolormesh(
np.linspace(start_time, end_time, spec.shape[-1] + 1, endpoint=True),
np.linspace(min_freq, max_freq, spec.shape[-2] + 1, endpoint=True),
spec,
np.linspace(start_time, end_time, spec.shape[-1], endpoint=False),
np.linspace(min_freq, max_freq, spec.shape[-2], endpoint=False),
spec.numpy(),
cmap=cmap,
)
return ax

View File

@ -400,8 +400,6 @@ AugmentationConfig = Annotated[
class AugmentationsConfig(BaseConfig):
"""Configuration for a sequence of data augmentations."""
enabled: bool = True
steps: List[AugmentationConfig] = Field(default_factory=list)

View File

@ -67,8 +67,8 @@ class TrainingConfig(BaseConfig):
t_max: int = 100
dataloaders: LoadersConfig = Field(default_factory=LoadersConfig)
loss: LossConfig = Field(default_factory=LossConfig)
augmentations: AugmentationsConfig = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
augmentations: Optional[AugmentationsConfig] = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
)
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)

View File

@ -62,7 +62,7 @@ class LabelConfig(BaseConfig):
diffuse targets.
"""
sigma: float = 2.0
sigma: float = 3.0
def build_clip_labeler(
@ -174,7 +174,7 @@ def generate_clip_label(
def map_to_pixels(x, size, min_val, max_val) -> int:
return int(np.interp(x, [min_val, max_val], [0, size]))
return int(np.floor(np.interp(x, [min_val, max_val], [0, size])))
def generate_heatmaps(

View File

@ -239,7 +239,7 @@ def build_train_dataset(
clipper=clipper,
)
if config.augmentations.enabled and config.augmentations.steps:
if config.augmentations and config.augmentations.steps:
augmentations = build_augmentations(
preprocessor,
config=config.augmentations,