Compare commits

..

3 Commits

Author       SHA1         Message                              Date
mbsantiago   ad5293e0d0   Add FileNotFoundError to plotting    2025-09-13 20:05:42 +01:00
mbsantiago   01e7a5df25   Add ignore at ends when evaluating   2025-09-13 19:03:40 +01:00
mbsantiago   6d70140bc9   Default to normal anchor             2025-09-13 13:56:47 +01:00
13 changed files with 120 additions and 88 deletions

View File

@@ -89,18 +89,9 @@ def annotation_to_sound_event(
         uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
         sound_event=sound_event,
         tags=[
-            data.Tag(
-                key=label_key,  # type: ignore
-                value=annotation.label,
-            ),
-            data.Tag(
-                key=event_key,  # type: ignore
-                value=annotation.event,
-            ),
-            data.Tag(
-                key=individual_key,  # type: ignore
-                value=str(annotation.individual),
-            ),
+            data.Tag(key=label_key, value=annotation.label),
+            data.Tag(key=event_key, value=annotation.event),
+            data.Tag(key=individual_key, value=str(annotation.individual)),
         ],
     )
@@ -121,12 +112,7 @@ def file_annotation_to_clip(
     recording = data.Recording.from_file(
         full_path,
         time_expansion=file_annotation.time_exp,
-        tags=[
-            data.Tag(
-                key=label_key,  # type: ignore
-                value=file_annotation.label,
-            )
-        ],
+        tags=[data.Tag(key=label_key, value=file_annotation.label)],
     )
 
     return data.Clip(
@@ -153,12 +139,7 @@ def file_annotation_to_clip_annotation(
         uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
         clip=clip,
         notes=notes,
-        tags=[
-            data.Tag(
-                key=label_key,  # type: ignore
-                value=file_annotation.label,
-            )
-        ],
+        tags=[data.Tag(key=label_key, value=file_annotation.label)],
         sound_events=[
             annotation_to_sound_event(
                 annotation,

View File

@@ -57,6 +57,7 @@ class MatchConfig(BaseConfig):
     affinity_threshold: float = 0.0
     time_buffer: float = 0.005
     frequency_buffer: float = 1_000
+    ignore_start_end: float = 0.01
 
 
 def _to_bbox(geometry: data.Geometry) -> data.BoundingBox:
@@ -273,6 +274,17 @@ def greedy_match(
         yield None, target_idx, 0
 
 
+def _is_in_bounds(
+    geometry: data.Geometry,
+    clip: data.Clip,
+    buffer: float,
+) -> bool:
+    start_time = compute_bounds(geometry)[0]
+    return (start_time >= clip.start_time + buffer) and (
+        start_time <= clip.end_time - buffer
+    )
+
+
 def match_sound_events_and_raw_predictions(
     clip_annotation: data.ClipAnnotation,
     raw_predictions: List[RawPrediction],
@@ -286,14 +298,29 @@ def match_sound_events_and_raw_predictions(
         for sound_event_annotation in clip_annotation.sound_events
         if targets.filter(sound_event_annotation)
         and sound_event_annotation.sound_event.geometry is not None
+        and _is_in_bounds(
+            sound_event_annotation.sound_event.geometry,
+            clip=clip_annotation.clip,
+            buffer=config.ignore_start_end,
+        )
     ]
 
-    target_geometries: List[data.Geometry] = [  # type: ignore
+    target_geometries: List[data.Geometry] = [
         sound_event_annotation.sound_event.geometry
         for sound_event_annotation in target_sound_events
        if sound_event_annotation.sound_event.geometry is not None
     ]
 
+    raw_predictions = [
+        raw_prediction
+        for raw_prediction in raw_predictions
+        if _is_in_bounds(
+            raw_prediction.geometry,
+            clip=clip_annotation.clip,
+            buffer=config.ignore_start_end,
+        )
+    ]
+
     predicted_geometries = [
         raw_prediction.geometry for raw_prediction in raw_predictions
     ]
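
Note: this is the "Add ignore at ends when evaluating" change. With the new ignore_start_end buffer (default 0.01 s), both ground-truth events and raw predictions whose start time falls within 10 ms of either clip edge are dropped before matching, so boundary artifacts count as neither misses nor false positives. A minimal standalone sketch of the filter logic (not the repository's actual helper):

    def is_in_bounds(start_time: float, clip_start: float, clip_end: float,
                     buffer: float = 0.01) -> bool:
        # Keep an event only if its start lies at least `buffer` seconds
        # inside both clip edges.
        return clip_start + buffer <= start_time <= clip_end - buffer

    assert is_in_bounds(0.5, 0.0, 3.0)        # well inside: kept
    assert not is_in_bounds(0.005, 0.0, 3.0)  # within 10 ms of the start: dropped
    assert not is_in_bounds(2.995, 0.0, 3.0)  # within 10 ms of the end: dropped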

View File

@@ -225,7 +225,7 @@ class ConvBlock(nn.Module):
             kernel_size=kernel_size,
             padding=pad_size,
         )
-        self.conv_bn = nn.BatchNorm2d(out_channels)
+        self.batch_norm = nn.BatchNorm2d(out_channels)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply Conv -> BN -> ReLU.
@@ -240,7 +240,7 @@ class ConvBlock(nn.Module):
         torch.Tensor
             Output tensor, shape `(B, C_out, H, W)`.
         """
-        return F.relu_(self.conv_bn(self.conv(x)))
+        return F.relu_(self.batch_norm(self.conv(x)))
 
 
 class VerticalConv(nn.Module):
@@ -364,7 +364,7 @@ class FreqCoordConvDownBlock(nn.Module):
             padding=pad_size,
             stride=1,
         )
-        self.conv_bn = nn.BatchNorm2d(out_channels)
+        self.batch_norm = nn.BatchNorm2d(out_channels)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply CoordF -> Conv -> MaxPool -> BN -> ReLU.
@@ -383,7 +383,7 @@ class FreqCoordConvDownBlock(nn.Module):
         freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
         x = torch.cat((x, freq_info), 1)
         x = F.max_pool2d(self.conv(x), 2, 2)
-        x = F.relu(self.conv_bn(x), inplace=True)
+        x = F.relu(self.batch_norm(x), inplace=True)
         return x
 
 
@@ -438,7 +438,7 @@ class StandardConvDownBlock(nn.Module):
             padding=pad_size,
             stride=1,
         )
-        self.conv_bn = nn.BatchNorm2d(out_channels)
+        self.batch_norm = nn.BatchNorm2d(out_channels)
 
     def forward(self, x):
         """Apply Conv -> MaxPool -> BN -> ReLU.
@@ -454,7 +454,7 @@ class StandardConvDownBlock(nn.Module):
             Output tensor, shape `(B, C_out, H/2, W/2)`.
         """
         x = F.max_pool2d(self.conv(x), 2, 2)
-        return F.relu(self.conv_bn(x), inplace=True)
+        return F.relu(self.batch_norm(x), inplace=True)
 
 
 class FreqCoordConvUpConfig(BaseConfig):
@@ -534,7 +534,7 @@ class FreqCoordConvUpBlock(nn.Module):
             kernel_size=kernel_size,
             padding=pad_size,
         )
-        self.conv_bn = nn.BatchNorm2d(out_channels)
+        self.batch_norm = nn.BatchNorm2d(out_channels)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply Interpolate -> Concat Coords -> Conv -> BN -> ReLU.
@@ -562,7 +562,7 @@ class FreqCoordConvUpBlock(nn.Module):
         freq_info = self.coords.repeat(op.shape[0], 1, 1, op.shape[3])
         op = torch.cat((op, freq_info), 1)
         op = self.conv(op)
-        op = F.relu(self.conv_bn(op), inplace=True)
+        op = F.relu(self.batch_norm(op), inplace=True)
         return op
 
 
@@ -625,7 +625,7 @@ class StandardConvUpBlock(nn.Module):
             kernel_size=kernel_size,
             padding=pad_size,
         )
-        self.conv_bn = nn.BatchNorm2d(out_channels)
+        self.batch_norm = nn.BatchNorm2d(out_channels)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply Interpolate -> Conv -> BN -> ReLU.
@@ -650,7 +650,7 @@ class StandardConvUpBlock(nn.Module):
             align_corners=False,
         )
         op = self.conv(op)
-        op = F.relu(self.conv_bn(op), inplace=True)
+        op = F.relu(self.batch_norm(op), inplace=True)
         return op
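
Note: the conv_bn -> batch_norm rename is purely cosmetic at runtime, but renaming a registered submodule attribute changes the model's state_dict keys ("...conv_bn.weight" becomes "...batch_norm.weight"), so checkpoints saved before this change will not load directly. A hypothetical migration helper (not part of the repository):

    def migrate_state_dict(old: dict) -> dict:
        # Rewrite pre-rename keys such as "encoder.0.conv_bn.weight"
        # to the new "encoder.0.batch_norm.weight" naming.
        return {k.replace("conv_bn.", "batch_norm."): v for k, v in old.items()}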

View File

@@ -32,9 +32,12 @@ def plot_spectrogram(
     max_freq: Optional[float] = None,
     ax: Optional[axes.Axes] = None,
     figsize: Optional[Tuple[int, int]] = None,
+    add_colorbar: bool = False,
+    colorbar_kwargs: Optional[dict] = None,
+    vmin: Optional[float] = None,
+    vmax: Optional[float] = None,
     cmap="gray",
 ) -> axes.Axes:
     if isinstance(spec, torch.Tensor):
         spec = spec.numpy()
@@ -54,10 +57,16 @@ def plot_spectrogram(
     if max_freq is None:
         max_freq = spec.shape[-2]
 
-    ax.pcolormesh(
+    mappable = ax.pcolormesh(
         np.linspace(start_time, end_time, spec.shape[-1] + 1, endpoint=True),
         np.linspace(min_freq, max_freq, spec.shape[-2] + 1, endpoint=True),
         spec,
         cmap=cmap,
+        vmin=vmin,
+        vmax=vmax,
     )
 
+    if add_colorbar:
+        plt.colorbar(mappable, ax=ax, **(colorbar_kwargs or {}))
+
     return ax
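
A hedged usage sketch of the extended signature, using a synthetic spectrogram and only argument names visible in this diff (plot_spectrogram is assumed imported from the plotting module shown above):

    import numpy as np

    spec = np.random.rand(128, 256)  # (freq bins, time bins), synthetic
    ax = plot_spectrogram(
        spec,
        vmin=0.0,
        vmax=1.0,  # pin the color scale so multiple plots are comparable
        add_colorbar=True,
        colorbar_kwargs={"label": "amplitude"},
    )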

View File

@@ -136,7 +136,7 @@ def plot_class_examples(
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
             continue
 
     return fig
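
This is the change behind the "Add FileNotFoundError to plotting" commit: an example whose recording is missing on disk is now skipped like any other unplottable example instead of aborting the whole figure. The skip-on-error pattern, as a standalone sketch:

    def collect_plottable(paths):
        plottable = []
        for path in paths:
            try:
                with open(path, "rb"):  # stand-in for loading the recording
                    pass
                plottable.append(path)
            except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
                continue  # skip this example, keep building the figure
        return plottable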

View File

@@ -51,7 +51,7 @@ __all__ = [
 
 DEFAULT_DETECTION_THRESHOLD = 0.01
-TOP_K_PER_SEC = 200
+TOP_K_PER_SEC = 100
 
 
 class PostprocessConfig(BaseConfig):
@@ -206,11 +206,13 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
         if clips is None:
             return detections
 
+        width = output.detection_probs.shape[-1]
+        duration = width / self.samplerate
         return [
             map_detection_to_clip(
                 detection,
                 start_time=clip.start_time,
-                end_time=clip.end_time,
+                end_time=clip.start_time + duration,
                 min_freq=self.min_freq,
                 max_freq=self.max_freq,
             )
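
The end of the mapped time axis is now derived from the model output itself rather than taken from the clip, so a clip whose requested span exceeds the audio the model actually saw (e.g. a padded final clip) no longer stretches detection times. Illustrative arithmetic, assuming self.samplerate here means the heatmap's frames-per-second rate:

    width = 1024                    # time bins in the detection heatmap
    samplerate = 512                # heatmap frames per second (assumed meaning)
    duration = width / samplerate   # 2.0 s actually covered by the output
    # end_time = clip.start_time + duration, not clip.end_time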
@@ -220,9 +222,9 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
 
 def get_raw_predictions(
     output: ModelOutput,
-    clips: List[data.Clip],
     targets: TargetProtocol,
     postprocessor: PostprocessorProtocol,
+    clips: Optional[List[data.Clip]] = None,
 ) -> List[List[RawPrediction]]:
     """Extract intermediate RawPrediction objects for a batch.
@@ -259,9 +261,9 @@ def get_sound_event_predictions(
 ) -> List[List[BatDetect2Prediction]]:
     raw_predictions = get_raw_predictions(
         output,
-        clips,
         targets=targets,
         postprocessor=postprocessor,
+        clips=clips,
     )
     return [
         [
@@ -308,9 +310,9 @@ def get_predictions(
     """
     raw_predictions = get_raw_predictions(
         output,
-        clips,
         targets=targets,
         postprocessor=postprocessor,
+        clips=clips,
     )
     return [
         convert_raw_predictions_to_clip_prediction(
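
clips is now an optional keyword argument, and judging by the "if clips is None: return detections" branch in the Postprocessor hunk above, omitting it presumably skips the remapping onto clip coordinates. A usage sketch:

    # With remapping onto the source clips:
    raw = get_raw_predictions(
        output,
        targets=targets,
        postprocessor=postprocessor,
        clips=clips,
    )

    # Without clips: predictions stay in the model's output frame.
    raw = get_raw_predictions(output, targets=targets, postprocessor=postprocessor)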

View File

@@ -28,12 +28,17 @@ from batdetect2.targets.rois import (
     ROITargetMapper,
     build_roi_mapper,
 )
-from batdetect2.targets.terms import call_type, individual
+from batdetect2.targets.terms import (
+    call_type,
+    data_source,
+    generic_class,
+    individual,
+)
 from batdetect2.typing.targets import Position, Size, TargetProtocol
 
 __all__ = [
-    "DEFAULT_TARGET_CONFIG",
     "AnchorBBoxMapperConfig",
+    "DEFAULT_TARGET_CONFIG",
     "ROITargetMapper",
     "SoundEventDecoder",
     "SoundEventEncoder",
@@ -44,6 +49,8 @@ __all__ = [
     "build_sound_event_decoder",
     "build_sound_event_encoder",
     "call_type",
+    "data_source",
+    "generic_class",
     "get_class_names_from_config",
     "individual",
     "load_target_config",

View File

@@ -14,7 +14,7 @@ from batdetect2.data.conditions import (
     SoundEventConditionConfig,
     build_sound_event_condition,
 )
-from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig
+from batdetect2.targets.rois import ROIMapperConfig
 from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
 
 __all__ = [
@@ -140,7 +140,6 @@ DEFAULT_CLASSES = [
     TargetClassConfig(
         name="rhihip",
         tags=[data.Tag(key="class", value="Rhinolophus hipposideros")],
-        roi=AnchorBBoxMapperConfig(anchor="top-left"),
     ),
     TargetClassConfig(
         name="nyclei",
@@ -149,7 +148,6 @@ DEFAULT_CLASSES = [
     TargetClassConfig(
         name="rhifer",
         tags=[data.Tag(key="class", value="Rhinolophus ferrumequinum")],
-        roi=AnchorBBoxMapperConfig(anchor="top-left"),
     ),
     TargetClassConfig(
         name="pleaur",

View File

@@ -6,6 +6,7 @@ __all__ = [
     "call_type",
     "individual",
     "data_source",
+    "generic_class",
 ]
 
 # The default key used to reference the 'generic_class' term.

View File

@@ -52,7 +52,7 @@ class ValLoaderConfig(BaseConfig):
     num_workers: int = 0
 
     clipping_strategy: ClipConfig = Field(
-        default_factory=lambda: RandomClipConfig()
+        default_factory=lambda: PaddedClipConfig()
     )
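
Validation clips are now padded rather than randomly cropped by default, which presumably makes validation metrics deterministic across epochs. Random clipping can still be opted into explicitly (sketch, assuming RandomClipConfig remains importable):

    config = ValLoaderConfig(clipping_strategy=RandomClipConfig())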

View File

@@ -14,7 +14,8 @@ from loguru import logger
 from soundevent import data
 
 from batdetect2.configs import BaseConfig, load_config
-from batdetect2.targets import iterate_encoded_sound_events
+from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
+from batdetect2.targets import build_targets, iterate_encoded_sound_events
 from batdetect2.typing import (
     ClipLabeller,
     Heatmaps,
@@ -45,9 +46,9 @@ class LabelConfig(BaseConfig):
 
 def build_clip_labeler(
-    targets: TargetProtocol,
-    min_freq: float,
-    max_freq: float,
+    targets: Optional[TargetProtocol] = None,
+    min_freq: float = MIN_FREQ,
+    max_freq: float = MAX_FREQ,
     config: Optional[LabelConfig] = None,
 ) -> ClipLabeller:
     """Construct the final clip labelling function."""
@@ -56,6 +57,10 @@ def build_clip_labeler(
         "Building clip labeler with config: \n{}",
         lambda: config.to_yaml_string(),
     )
+
+    if targets is None:
+        targets = build_targets()
+
     return partial(
         generate_heatmaps,
         targets=targets,
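
Every argument of build_clip_labeler now has a default: build_targets() for the targets and the preprocessing module's MIN_FREQ/MAX_FREQ for the frequency bounds. Usage sketch (my_targets is a placeholder):

    labeller = build_clip_labeler()  # all defaults
    labeller = build_clip_labeler(targets=my_targets, max_freq=120_000)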

View File

@@ -226,9 +226,9 @@ def build_trainer(
 
 def build_train_loader(
     clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: AudioLoader,
-    labeller: ClipLabeller,
-    preprocessor: PreprocessorProtocol,
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
     config: Optional[TrainLoaderConfig] = None,
     num_workers: Optional[int] = None,
 ) -> DataLoader:
@@ -260,9 +260,9 @@ def build_train_loader(
 
 def build_val_loader(
     clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: AudioLoader,
-    labeller: ClipLabeller,
-    preprocessor: PreprocessorProtocol,
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
     config: Optional[ValLoaderConfig] = None,
     num_workers: Optional[int] = None,
 ):
@@ -293,9 +293,9 @@ def build_val_loader(
 
 def build_train_dataset(
     clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: AudioLoader,
-    labeller: ClipLabeller,
-    preprocessor: PreprocessorProtocol,
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
     config: Optional[TrainLoaderConfig] = None,
 ) -> TrainingDataset:
     logger.info("Building training dataset...")
@@ -303,6 +303,18 @@ def build_train_dataset(
     clipper = build_clipper(config=config.clipping_strategy)
 
+    if audio_loader is None:
+        audio_loader = build_audio_loader()
+
+    if preprocessor is None:
+        preprocessor = build_preprocessor()
+
+    if labeller is None:
+        labeller = build_clip_labeler(
+            min_freq=preprocessor.min_freq,
+            max_freq=preprocessor.max_freq,
+        )
+
     random_example_source = RandomAudioSource(
         clip_annotations,
         audio_loader=audio_loader,
@@ -332,14 +344,26 @@ def build_val_dataset(
 
 def build_val_dataset(
     clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: AudioLoader,
-    labeller: ClipLabeller,
-    preprocessor: PreprocessorProtocol,
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
     config: Optional[ValLoaderConfig] = None,
 ) -> ValidationDataset:
     logger.info("Building validation dataset...")
 
     config = config or ValLoaderConfig()
 
+    if audio_loader is None:
+        audio_loader = build_audio_loader()
+
+    if preprocessor is None:
+        preprocessor = build_preprocessor()
+
+    if labeller is None:
+        labeller = build_clip_labeler(
+            min_freq=preprocessor.min_freq,
+            max_freq=preprocessor.max_freq,
+        )
+
     clipper = build_clipper(config.clipping_strategy)
 
     return ValidationDataset(
         clip_annotations,
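
Together with the labeller change above, the dataset builders now construct missing collaborators on the fly (build_audio_loader(), build_preprocessor(), build_clip_labeler(...)), and the loader builders presumably delegate to them. The minimal call shrinks to just the annotations (sketch, assuming clip_annotations is already loaded):

    train_ds = build_train_dataset(clip_annotations)
    val_ds = build_val_dataset(clip_annotations)
    train_loader = build_train_loader(clip_annotations, num_workers=4)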

View File

@@ -47,29 +47,7 @@ class GeometryDecoder(Protocol):
 
 class RawPrediction(NamedTuple):
-    """Intermediate representation of a single detected sound event.
-
-    Holds extracted information about a detection after initial processing
-    (like peak finding, coordinate remapping, geometry recovery) but before
-    final class decoding and conversion into a `SoundEventPrediction`. This
-    can be useful for evaluation or simpler data handling formats.
-
-    Attributes
-    ----------
-    geometry : data.Geometry
-        The recovered estimated geometry of the detected sound event.
-        Usually a bounding box.
-    detection_score : float
-        The confidence score associated with this detection, typically from
-        the detection heatmap peak.
-    class_scores : xr.DataArray
-        An xarray DataArray containing the predicted probabilities or scores
-        for each target class at the detection location. Indexed by a
-        'category' coordinate containing class names.
-    features : xr.DataArray
-        An xarray DataArray containing extracted feature vectors at the
-        detection location. Indexed by a 'feature' coordinate.
-    """
+    """Intermediate representation of a single detected sound event."""
 
     geometry: data.Geometry
     detection_score: float