mirror of https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00

Compare commits: 3 commits (4fd2e84773 ... ad5293e0d0)

| Author | SHA1 | Date |
|---|---|---|
|  | ad5293e0d0 |  |
|  | 01e7a5df25 |  |
|  | 6d70140bc9 |  |
@@ -89,18 +89,9 @@ def annotation_to_sound_event(
         uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
         sound_event=sound_event,
         tags=[
-            data.Tag(
-                key=label_key,  # type: ignore
-                value=annotation.label,
-            ),
-            data.Tag(
-                key=event_key,  # type: ignore
-                value=annotation.event,
-            ),
-            data.Tag(
-                key=individual_key,  # type: ignore
-                value=str(annotation.individual),
-            ),
+            data.Tag(key=label_key, value=annotation.label),
+            data.Tag(key=event_key, value=annotation.event),
+            data.Tag(key=individual_key, value=str(annotation.individual)),
         ],
     )
@@ -121,12 +112,7 @@ def file_annotation_to_clip(
     recording = data.Recording.from_file(
         full_path,
         time_expansion=file_annotation.time_exp,
-        tags=[
-            data.Tag(
-                key=label_key,  # type: ignore
-                value=file_annotation.label,
-            )
-        ],
+        tags=[data.Tag(key=label_key, value=file_annotation.label)],
     )

     return data.Clip(
@@ -153,12 +139,7 @@ def file_annotation_to_clip_annotation(
         uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
         clip=clip,
         notes=notes,
-        tags=[
-            data.Tag(
-                key=label_key,  # type: ignore
-                value=file_annotation.label,
-            )
-        ],
+        tags=[data.Tag(key=label_key, value=file_annotation.label)],
         sound_events=[
             annotation_to_sound_event(
                 annotation,
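These three hunks are the same mechanical refactor: the multi-line `data.Tag(...)` constructions, each carrying a `# type: ignore`, collapse into the compact keyword form. The surrounding context also shows the converters deriving deterministic IDs with `uuid.uuid5`, so re-running a conversion reproduces the same annotation UUIDs. A minimal sketch of that idiom (the namespace value below is hypothetical; batdetect2 defines its own NAMESPACE constant):

    import uuid

    # Hypothetical namespace, for illustration only.
    NAMESPACE = uuid.UUID("12345678-1234-5678-1234-567812345678")

    sound_event_uuid = uuid.uuid4()
    a = uuid.uuid5(NAMESPACE, f"{sound_event_uuid}_annotation")
    b = uuid.uuid5(NAMESPACE, f"{sound_event_uuid}_annotation")
    assert a == b  # same inputs always map to the same UUID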
@@ -57,6 +57,7 @@ class MatchConfig(BaseConfig):
     affinity_threshold: float = 0.0
     time_buffer: float = 0.005
     frequency_buffer: float = 1_000
+    ignore_start_end: float = 0.01


 def _to_bbox(geometry: data.Geometry) -> data.BoundingBox:
@@ -273,6 +274,17 @@ def greedy_match(
         yield None, target_idx, 0


+def _is_in_bounds(
+    geometry: data.Geometry,
+    clip: data.Clip,
+    buffer: float,
+) -> bool:
+    start_time = compute_bounds(geometry)[0]
+    return (start_time >= clip.start_time + buffer) and (
+        start_time <= clip.end_time - buffer
+    )
+
+
 def match_sound_events_and_raw_predictions(
     clip_annotation: data.ClipAnnotation,
     raw_predictions: List[RawPrediction],
@@ -286,14 +298,29 @@ def match_sound_events_and_raw_predictions(
         for sound_event_annotation in clip_annotation.sound_events
         if targets.filter(sound_event_annotation)
         and sound_event_annotation.sound_event.geometry is not None
+        and _is_in_bounds(
+            sound_event_annotation.sound_event.geometry,
+            clip=clip_annotation.clip,
+            buffer=config.ignore_start_end,
+        )
     ]

-    target_geometries: List[data.Geometry] = [  # type: ignore
+    target_geometries: List[data.Geometry] = [
         sound_event_annotation.sound_event.geometry
         for sound_event_annotation in target_sound_events
         if sound_event_annotation.sound_event.geometry is not None
     ]

+    raw_predictions = [
+        raw_prediction
+        for raw_prediction in raw_predictions
+        if _is_in_bounds(
+            raw_prediction.geometry,
+            clip=clip_annotation.clip,
+            buffer=config.ignore_start_end,
+        )
+    ]
+
     predicted_geometries = [
         raw_prediction.geometry for raw_prediction in raw_predictions
    ]
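The new `ignore_start_end` buffer (default 0.01 s, i.e. 10 ms) drops any event whose start time falls too close to a clip edge, and it is applied symmetrically to ground-truth annotations and raw predictions, so neither side is penalised for border effects during matching. A standalone sketch of the bounds check, with plain floats standing in for the soundevent types:

    def is_in_bounds(start_time: float, clip_start: float, clip_end: float, buffer: float) -> bool:
        # Keep only events whose start lies at least `buffer` seconds
        # away from both clip edges, mirroring _is_in_bounds above.
        return clip_start + buffer <= start_time <= clip_end - buffer

    # A 5 s clip with the default 0.01 s buffer:
    assert is_in_bounds(0.5, 0.0, 5.0, 0.01)
    assert not is_in_bounds(0.005, 0.0, 5.0, 0.01)  # too close to the start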
@@ -225,7 +225,7 @@ class ConvBlock(nn.Module):
             kernel_size=kernel_size,
             padding=pad_size,
         )
-        self.conv_bn = nn.BatchNorm2d(out_channels)
+        self.batch_norm = nn.BatchNorm2d(out_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply Conv -> BN -> ReLU.
@@ -240,7 +240,7 @@ class ConvBlock(nn.Module):
         torch.Tensor
             Output tensor, shape `(B, C_out, H, W)`.
         """
-        return F.relu_(self.conv_bn(self.conv(x)))
+        return F.relu_(self.batch_norm(self.conv(x)))


 class VerticalConv(nn.Module):
@@ -364,7 +364,7 @@ class FreqCoordConvDownBlock(nn.Module):
             padding=pad_size,
             stride=1,
         )
-        self.conv_bn = nn.BatchNorm2d(out_channels)
+        self.batch_norm = nn.BatchNorm2d(out_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply CoordF -> Conv -> MaxPool -> BN -> ReLU.
@@ -383,7 +383,7 @@ class FreqCoordConvDownBlock(nn.Module):
         freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
         x = torch.cat((x, freq_info), 1)
         x = F.max_pool2d(self.conv(x), 2, 2)
-        x = F.relu(self.conv_bn(x), inplace=True)
+        x = F.relu(self.batch_norm(x), inplace=True)
         return x


@@ -438,7 +438,7 @@ class StandardConvDownBlock(nn.Module):
             padding=pad_size,
             stride=1,
         )
-        self.conv_bn = nn.BatchNorm2d(out_channels)
+        self.batch_norm = nn.BatchNorm2d(out_channels)

     def forward(self, x):
         """Apply Conv -> MaxPool -> BN -> ReLU.
@@ -454,7 +454,7 @@ class StandardConvDownBlock(nn.Module):
             Output tensor, shape `(B, C_out, H/2, W/2)`.
         """
         x = F.max_pool2d(self.conv(x), 2, 2)
-        return F.relu(self.conv_bn(x), inplace=True)
+        return F.relu(self.batch_norm(x), inplace=True)


 class FreqCoordConvUpConfig(BaseConfig):
@@ -534,7 +534,7 @@ class FreqCoordConvUpBlock(nn.Module):
             kernel_size=kernel_size,
             padding=pad_size,
         )
-        self.conv_bn = nn.BatchNorm2d(out_channels)
+        self.batch_norm = nn.BatchNorm2d(out_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply Interpolate -> Concat Coords -> Conv -> BN -> ReLU.
@@ -562,7 +562,7 @@ class FreqCoordConvUpBlock(nn.Module):
         freq_info = self.coords.repeat(op.shape[0], 1, 1, op.shape[3])
         op = torch.cat((op, freq_info), 1)
         op = self.conv(op)
-        op = F.relu(self.conv_bn(op), inplace=True)
+        op = F.relu(self.batch_norm(op), inplace=True)
         return op


@@ -625,7 +625,7 @@ class StandardConvUpBlock(nn.Module):
             kernel_size=kernel_size,
             padding=pad_size,
         )
-        self.conv_bn = nn.BatchNorm2d(out_channels)
+        self.batch_norm = nn.BatchNorm2d(out_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply Interpolate -> Conv -> BN -> ReLU.
@@ -650,7 +650,7 @@ class StandardConvUpBlock(nn.Module):
             align_corners=False,
         )
         op = self.conv(op)
-        op = F.relu(self.conv_bn(op), inplace=True)
+        op = F.relu(self.batch_norm(op), inplace=True)
         return op
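All of these hunks are a single rename, `self.conv_bn` to `self.batch_norm`, across the convolutional blocks. Renaming a registered submodule changes the keys under which its parameters appear in the module's `state_dict`, so checkpoints saved before this commit will not load into the renamed modules without remapping. A hedged sketch of the remap one would need, assuming the file stores a flat state dict (the path is hypothetical):

    import torch

    state = torch.load("old_checkpoint.pt", map_location="cpu")  # hypothetical file
    remapped = {
        key.replace("conv_bn.", "batch_norm."): value
        for key, value in state.items()
    }
    # model.load_state_dict(remapped)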
@@ -32,9 +32,12 @@ def plot_spectrogram(
     max_freq: Optional[float] = None,
     ax: Optional[axes.Axes] = None,
     figsize: Optional[Tuple[int, int]] = None,
+    add_colorbar: bool = False,
+    colorbar_kwargs: Optional[dict] = None,
+    vmin: Optional[float] = None,
+    vmax: Optional[float] = None,
     cmap="gray",
 ) -> axes.Axes:

     if isinstance(spec, torch.Tensor):
         spec = spec.numpy()
@@ -54,10 +57,16 @@ def plot_spectrogram(
     if max_freq is None:
         max_freq = spec.shape[-2]

-    ax.pcolormesh(
+    mappable = ax.pcolormesh(
         np.linspace(start_time, end_time, spec.shape[-1] + 1, endpoint=True),
         np.linspace(min_freq, max_freq, spec.shape[-2] + 1, endpoint=True),
         spec,
         cmap=cmap,
+        vmin=vmin,
+        vmax=vmax,
     )

+    if add_colorbar:
+        plt.colorbar(mappable, ax=ax, **(colorbar_kwargs or {}))
+
     return ax
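With the new keywords, callers can pin the color scale and attach a colorbar in one call. A usage sketch (the import path is assumed, and the spectrogram array is random data for illustration):

    import numpy as np
    from batdetect2.plot import plot_spectrogram  # module path assumed

    spec = np.random.rand(128, 512)  # (freq bins, time bins), for illustration
    ax = plot_spectrogram(
        spec,
        add_colorbar=True,
        vmin=0.0,  # a fixed scale makes plots comparable across clips
        vmax=1.0,
    )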
@@ -136,7 +136,7 @@ def plot_class_examples(
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
             continue

     return fig
@@ -51,7 +51,7 @@ __all__ = [
 DEFAULT_DETECTION_THRESHOLD = 0.01


-TOP_K_PER_SEC = 200
+TOP_K_PER_SEC = 100


 class PostprocessConfig(BaseConfig):
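Assuming the constant bounds the number of detection peaks kept per second of audio, as its name suggests, halving it from 200 to 100 halves the candidate pool before score thresholding: a 5 s clip now contributes at most 5 × 100 = 500 raw detections instead of 1,000.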
@@ -206,11 +206,13 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
         if clips is None:
             return detections

+        width = output.detection_probs.shape[-1]
+        duration = width / self.samplerate
         return [
             map_detection_to_clip(
                 detection,
                 start_time=clip.start_time,
-                end_time=clip.end_time,
+                end_time=clip.start_time + duration,
                 min_freq=self.min_freq,
                 max_freq=self.max_freq,
             )
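This hunk changes how detections are mapped back onto clip time: the effective end time is now derived from the width of the detection heatmap (`width / self.samplerate`) rather than the clip's nominal `end_time`. When audio has been padded or trimmed to the model's fixed input size the two can disagree, and the previous code stretched detection coordinates to fit the wrong span. As a worked example under assumed numbers: a heatmap 1,024 columns wide at an output rate of 200 columns per second covers 1024 / 200 = 5.12 s, so detections map onto [start_time, start_time + 5.12] regardless of the clip's recorded end time.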
@@ -220,9 +222,9 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):

 def get_raw_predictions(
     output: ModelOutput,
-    clips: List[data.Clip],
     targets: TargetProtocol,
     postprocessor: PostprocessorProtocol,
+    clips: Optional[List[data.Clip]] = None,
 ) -> List[List[RawPrediction]]:
     """Extract intermediate RawPrediction objects for a batch.
@@ -259,9 +261,9 @@ def get_sound_event_predictions(
 ) -> List[List[BatDetect2Prediction]]:
     raw_predictions = get_raw_predictions(
         output,
-        clips,
         targets=targets,
         postprocessor=postprocessor,
+        clips=clips,
     )
     return [
         [
@@ -308,9 +310,9 @@ def get_predictions(
     """
     raw_predictions = get_raw_predictions(
         output,
-        clips,
         targets=targets,
         postprocessor=postprocessor,
+        clips=clips,
     )
     return [
         convert_raw_predictions_to_clip_prediction(
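After this change `clips` is an optional keyword argument, so callers that only need grid-level predictions can omit it, and the two internal call sites pass it by name. A hedged call sketch, with `output`, `targets`, `postprocessor`, and `clips` assumed to be built elsewhere:

    raw = get_raw_predictions(
        output,
        targets=targets,
        postprocessor=postprocessor,
    )  # clips now defaults to None

    # Passing clips still maps detections onto absolute clip times:
    raw = get_raw_predictions(
        output,
        targets=targets,
        postprocessor=postprocessor,
        clips=clips,
    )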
@@ -28,12 +28,17 @@ from batdetect2.targets.rois import (
     ROITargetMapper,
     build_roi_mapper,
 )
-from batdetect2.targets.terms import call_type, individual
+from batdetect2.targets.terms import (
+    call_type,
+    data_source,
+    generic_class,
+    individual,
+)
 from batdetect2.typing.targets import Position, Size, TargetProtocol

 __all__ = [
-    "DEFAULT_TARGET_CONFIG",
     "AnchorBBoxMapperConfig",
+    "DEFAULT_TARGET_CONFIG",
     "ROITargetMapper",
     "SoundEventDecoder",
     "SoundEventEncoder",
@@ -44,6 +49,8 @@ __all__ = [
     "build_sound_event_decoder",
     "build_sound_event_encoder",
     "call_type",
+    "data_source",
+    "generic_class",
     "get_class_names_from_config",
     "individual",
     "load_target_config",
@@ -14,7 +14,7 @@ from batdetect2.data.conditions import (
     SoundEventConditionConfig,
     build_sound_event_condition,
 )
-from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig
+from batdetect2.targets.rois import ROIMapperConfig
 from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder

 __all__ = [
@@ -140,7 +140,6 @@ DEFAULT_CLASSES = [
     TargetClassConfig(
         name="rhihip",
         tags=[data.Tag(key="class", value="Rhinolophus hipposideros")],
-        roi=AnchorBBoxMapperConfig(anchor="top-left"),
     ),
     TargetClassConfig(
         name="nyclei",
@@ -149,7 +148,6 @@ DEFAULT_CLASSES = [
     TargetClassConfig(
         name="rhifer",
         tags=[data.Tag(key="class", value="Rhinolophus ferrumequinum")],
-        roi=AnchorBBoxMapperConfig(anchor="top-left"),
     ),
     TargetClassConfig(
         name="pleaur",
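With the explicit `roi=AnchorBBoxMapperConfig(anchor="top-left")` overrides removed, the rhihip and rhifer classes fall back to the module's default ROI mapper, which is also why the `AnchorBBoxMapperConfig` import disappears from this file.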
@@ -6,6 +6,7 @@ __all__ = [
     "call_type",
     "individual",
     "data_source",
+    "generic_class",
 ]


 # The default key used to reference the 'generic_class' term.
@@ -52,7 +52,7 @@ class ValLoaderConfig(BaseConfig):
     num_workers: int = 0

     clipping_strategy: ClipConfig = Field(
-        default_factory=lambda: RandomClipConfig()
+        default_factory=lambda: PaddedClipConfig()
     )
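Changing the validation default from `RandomClipConfig` to `PaddedClipConfig` presumably swaps random cropping for deterministic padded clips, so validation metrics should no longer vary with the cropping seed and become comparable across runs.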
@@ -14,7 +14,8 @@ from loguru import logger
 from soundevent import data

 from batdetect2.configs import BaseConfig, load_config
-from batdetect2.targets import iterate_encoded_sound_events
+from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
+from batdetect2.targets import build_targets, iterate_encoded_sound_events
 from batdetect2.typing import (
     ClipLabeller,
     Heatmaps,
@@ -45,9 +46,9 @@ class LabelConfig(BaseConfig):


 def build_clip_labeler(
-    targets: TargetProtocol,
-    min_freq: float,
-    max_freq: float,
+    targets: Optional[TargetProtocol] = None,
+    min_freq: float = MIN_FREQ,
+    max_freq: float = MAX_FREQ,
     config: Optional[LabelConfig] = None,
 ) -> ClipLabeller:
     """Construct the final clip labelling function."""
@@ -56,6 +57,10 @@ def build_clip_labeler(
         "Building clip labeler with config: \n{}",
         lambda: config.to_yaml_string(),
     )

+    if targets is None:
+        targets = build_targets()
+
     return partial(
         generate_heatmaps,
         targets=targets,
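Every parameter of `build_clip_labeler` now has a default: `targets` falls back to `build_targets()` and the frequency range to the `MIN_FREQ`/`MAX_FREQ` preprocessing constants. A minimal usage sketch (module path assumed):

    from batdetect2.train.labels import build_clip_labeler  # module path assumed

    # Equivalent to build_clip_labeler(targets=build_targets(),
    # min_freq=MIN_FREQ, max_freq=MAX_FREQ):
    labeller = build_clip_labeler()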
@@ -226,9 +226,9 @@ def build_trainer(

 def build_train_loader(
     clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: AudioLoader,
-    labeller: ClipLabeller,
-    preprocessor: PreprocessorProtocol,
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
     config: Optional[TrainLoaderConfig] = None,
     num_workers: Optional[int] = None,
 ) -> DataLoader:
@@ -260,9 +260,9 @@ def build_train_loader(

 def build_val_loader(
     clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: AudioLoader,
-    labeller: ClipLabeller,
-    preprocessor: PreprocessorProtocol,
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
     config: Optional[ValLoaderConfig] = None,
     num_workers: Optional[int] = None,
 ):
@@ -293,9 +293,9 @@ def build_val_loader(

 def build_train_dataset(
     clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: AudioLoader,
-    labeller: ClipLabeller,
-    preprocessor: PreprocessorProtocol,
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
     config: Optional[TrainLoaderConfig] = None,
 ) -> TrainingDataset:
     logger.info("Building training dataset...")
@@ -303,6 +303,18 @@ def build_train_dataset(

     clipper = build_clipper(config=config.clipping_strategy)

+    if audio_loader is None:
+        audio_loader = build_audio_loader()
+
+    if preprocessor is None:
+        preprocessor = build_preprocessor()
+
+    if labeller is None:
+        labeller = build_clip_labeler(
+            min_freq=preprocessor.min_freq,
+            max_freq=preprocessor.max_freq,
+        )
+
     random_example_source = RandomAudioSource(
         clip_annotations,
         audio_loader=audio_loader,
@@ -332,14 +344,26 @@ def build_train_dataset(

 def build_val_dataset(
     clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: AudioLoader,
-    labeller: ClipLabeller,
-    preprocessor: PreprocessorProtocol,
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
     config: Optional[ValLoaderConfig] = None,
 ) -> ValidationDataset:
     logger.info("Building validation dataset...")
     config = config or ValLoaderConfig()

+    if audio_loader is None:
+        audio_loader = build_audio_loader()
+
+    if preprocessor is None:
+        preprocessor = build_preprocessor()
+
+    if labeller is None:
+        labeller = build_clip_labeler(
+            min_freq=preprocessor.min_freq,
+            max_freq=preprocessor.max_freq,
+        )
+
     clipper = build_clipper(config.clipping_strategy)
     return ValidationDataset(
         clip_annotations,
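The same defaulting pattern runs through all four builders, so a dataset can now be assembled from annotations alone; the audio loader, preprocessor, and labeller are constructed on demand. A hedged sketch (module path assumed, `clip_annotations` loaded elsewhere):

    from batdetect2.train.dataset import (  # module path assumed
        build_train_dataset,
        build_val_dataset,
    )

    # clip_annotations: Sequence[data.ClipAnnotation], loaded elsewhere.
    train_dataset = build_train_dataset(clip_annotations)
    val_dataset = build_val_dataset(clip_annotations)

    # Components remain injectable when custom behaviour is needed:
    # build_train_dataset(clip_annotations, preprocessor=my_preprocessor)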
@@ -47,29 +47,7 @@ class GeometryDecoder(Protocol):


 class RawPrediction(NamedTuple):
-    """Intermediate representation of a single detected sound event.
-
-    Holds extracted information about a detection after initial processing
-    (like peak finding, coordinate remapping, geometry recovery) but before
-    final class decoding and conversion into a `SoundEventPrediction`. This
-    can be useful for evaluation or simpler data handling formats.
-
-    Attributes
-    ----------
-    geometry: data.Geometry
-        The recovered estimated geometry of the detected sound event.
-        Usually a bounding box.
-    detection_score : float
-        The confidence score associated with this detection, typically from
-        the detection heatmap peak.
-    class_scores : xr.DataArray
-        An xarray DataArray containing the predicted probabilities or scores
-        for each target class at the detection location. Indexed by a
-        'category' coordinate containing class names.
-    features : xr.DataArray
-        An xarray DataArray containing extracted feature vectors at the
-        detection location. Indexed by a 'feature' coordinate.
-    """
+    """Intermediate representation of a single detected sound event."""

     geometry: data.Geometry
     detection_score: float