Compare commits

..

2 Commits

Author SHA1 Message Date
mbsantiago
ff754a1269 Tweaks of augmentation config 2025-08-27 18:23:38 +01:00
mbsantiago
ed76ec24b6 Plot anchor points 2025-08-27 18:13:40 +01:00
6 changed files with 45 additions and 10 deletions

View File

@ -4,7 +4,9 @@ from matplotlib.axes import Axes
from soundevent import data, plot from soundevent import data, plot
from batdetect2.plotting.clips import plot_clip from batdetect2.plotting.clips import plot_clip
from batdetect2.plotting.common import create_ax
from batdetect2.typing.preprocess import PreprocessorProtocol from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
__all__ = [ __all__ = [
"plot_clip_annotation", "plot_clip_annotation",
@ -43,3 +45,31 @@ def plot_clip_annotation(
facecolor="none" if not fill else None, facecolor="none" if not fill else None,
) )
return ax return ax
def plot_anchor_points(
    clip_annotation: data.ClipAnnotation,
    targets: TargetProtocol,
    figsize: Optional[Tuple[int, int]] = None,
    ax: Optional[Axes] = None,
    size: int = 1,
    color: str = "red",
    marker: str = "x",
    alpha: float = 1,
) -> Axes:
    """Scatter the encoded anchor position of each retained sound event.

    Every sound event in ``clip_annotation`` that passes ``targets.filter``
    is run through ``targets.transform`` and then ``targets.encode_roi``;
    the first element returned by ``encode_roi`` is used as the point to
    plot.

    Parameters
    ----------
    clip_annotation
        Annotated clip whose ``sound_events`` are inspected.
    targets
        Object providing ``filter``, ``transform`` and ``encode_roi``.
    figsize
        Forwarded to ``create_ax`` together with ``ax``.
    ax
        Axes to draw on; forwarded to ``create_ax``.
    size, color, marker, alpha
        Passed to ``Axes.scatter`` as ``s``, ``c``, ``marker`` and
        ``alpha`` respectively.

    Returns
    -------
    Axes
        The axes returned by ``create_ax``, with the points drawn on it
        (unchanged when no sound event passes the filter).
    """
    ax = create_ax(ax=ax, figsize=figsize)

    positions = []
    for sound_event in clip_annotation.sound_events:
        if not targets.filter(sound_event):
            continue

        transformed = targets.transform(sound_event)
        # encode_roi returns a pair; only the position part is plotted.
        # NOTE(review): presumably position is (time, frequency) — confirm.
        position, _ = targets.encode_roi(transformed)
        positions.append(position)

    # Nothing to draw: unpacking ``zip(*[])`` into two names would raise
    # ValueError, so return the (possibly freshly created) axes as-is.
    if not positions:
        return ax

    X, Y = zip(*positions)
    ax.scatter(X, Y, s=size, c=color, marker=marker, alpha=alpha)
    return ax

View File

@ -1,6 +1,6 @@
"""General plotting utilities.""" """General plotting utilities."""
from typing import Optional, Tuple from typing import Optional, Tuple, Union
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -25,7 +25,7 @@ def create_ax(
def plot_spectrogram( def plot_spectrogram(
spec: torch.Tensor, spec: Union[torch.Tensor, np.ndarray],
start_time: float, start_time: float,
end_time: float, end_time: float,
min_freq: float, min_freq: float,
@ -34,12 +34,15 @@ def plot_spectrogram(
figsize: Optional[Tuple[int, int]] = None, figsize: Optional[Tuple[int, int]] = None,
cmap="gray", cmap="gray",
) -> axes.Axes: ) -> axes.Axes:
if isinstance(spec, torch.Tensor):
spec = spec.numpy()
ax = create_ax(ax=ax, figsize=figsize) ax = create_ax(ax=ax, figsize=figsize)
ax.pcolormesh( ax.pcolormesh(
np.linspace(start_time, end_time, spec.shape[-1], endpoint=False), np.linspace(start_time, end_time, spec.shape[-1] + 1, endpoint=True),
np.linspace(min_freq, max_freq, spec.shape[-2], endpoint=False), np.linspace(min_freq, max_freq, spec.shape[-2] + 1, endpoint=True),
spec.numpy(), spec,
cmap=cmap, cmap=cmap,
) )
return ax return ax

View File

@ -400,6 +400,8 @@ AugmentationConfig = Annotated[
class AugmentationsConfig(BaseConfig): class AugmentationsConfig(BaseConfig):
"""Configuration for a sequence of data augmentations.""" """Configuration for a sequence of data augmentations."""
enabled: bool = True
steps: List[AugmentationConfig] = Field(default_factory=list) steps: List[AugmentationConfig] = Field(default_factory=list)

View File

@ -67,8 +67,8 @@ class TrainingConfig(BaseConfig):
t_max: int = 100 t_max: int = 100
dataloaders: LoadersConfig = Field(default_factory=LoadersConfig) dataloaders: LoadersConfig = Field(default_factory=LoadersConfig)
loss: LossConfig = Field(default_factory=LossConfig) loss: LossConfig = Field(default_factory=LossConfig)
augmentations: Optional[AugmentationsConfig] = Field( augmentations: AugmentationsConfig = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
) )
cliping: ClipingConfig = Field(default_factory=ClipingConfig) cliping: ClipingConfig = Field(default_factory=ClipingConfig)
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig) trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)

View File

@ -62,7 +62,7 @@ class LabelConfig(BaseConfig):
diffuse targets. diffuse targets.
""" """
sigma: float = 3.0 sigma: float = 2.0
def build_clip_labeler( def build_clip_labeler(
@ -174,7 +174,7 @@ def generate_clip_label(
def map_to_pixels(x, size, min_val, max_val) -> int: def map_to_pixels(x, size, min_val, max_val) -> int:
return int(np.floor(np.interp(x, [min_val, max_val], [0, size]))) return int(np.interp(x, [min_val, max_val], [0, size]))
def generate_heatmaps( def generate_heatmaps(

View File

@ -239,7 +239,7 @@ def build_train_dataset(
clipper=clipper, clipper=clipper,
) )
if config.augmentations and config.augmentations.steps: if config.augmentations.enabled and config.augmentations.steps:
augmentations = build_augmentations( augmentations = build_augmentations(
preprocessor, preprocessor,
config=config.augmentations, config=config.augmentations,