mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00
Default to normal anchor
This commit is contained in:
parent
4fd2e84773
commit
6d70140bc9
@ -225,7 +225,7 @@ class ConvBlock(nn.Module):
|
||||
kernel_size=kernel_size,
|
||||
padding=pad_size,
|
||||
)
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
self.batch_norm = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply Conv -> BN -> ReLU.
|
||||
@ -240,7 +240,7 @@ class ConvBlock(nn.Module):
|
||||
torch.Tensor
|
||||
Output tensor, shape `(B, C_out, H, W)`.
|
||||
"""
|
||||
return F.relu_(self.conv_bn(self.conv(x)))
|
||||
return F.relu_(self.batch_norm(self.conv(x)))
|
||||
|
||||
|
||||
class VerticalConv(nn.Module):
|
||||
@ -364,7 +364,7 @@ class FreqCoordConvDownBlock(nn.Module):
|
||||
padding=pad_size,
|
||||
stride=1,
|
||||
)
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
self.batch_norm = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply CoordF -> Conv -> MaxPool -> BN -> ReLU.
|
||||
@ -383,7 +383,7 @@ class FreqCoordConvDownBlock(nn.Module):
|
||||
freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
|
||||
x = torch.cat((x, freq_info), 1)
|
||||
x = F.max_pool2d(self.conv(x), 2, 2)
|
||||
x = F.relu(self.conv_bn(x), inplace=True)
|
||||
x = F.relu(self.batch_norm(x), inplace=True)
|
||||
return x
|
||||
|
||||
|
||||
@ -438,7 +438,7 @@ class StandardConvDownBlock(nn.Module):
|
||||
padding=pad_size,
|
||||
stride=1,
|
||||
)
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
self.batch_norm = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
"""Apply Conv -> MaxPool -> BN -> ReLU.
|
||||
@ -454,7 +454,7 @@ class StandardConvDownBlock(nn.Module):
|
||||
Output tensor, shape `(B, C_out, H/2, W/2)`.
|
||||
"""
|
||||
x = F.max_pool2d(self.conv(x), 2, 2)
|
||||
return F.relu(self.conv_bn(x), inplace=True)
|
||||
return F.relu(self.batch_norm(x), inplace=True)
|
||||
|
||||
|
||||
class FreqCoordConvUpConfig(BaseConfig):
|
||||
@ -534,7 +534,7 @@ class FreqCoordConvUpBlock(nn.Module):
|
||||
kernel_size=kernel_size,
|
||||
padding=pad_size,
|
||||
)
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
self.batch_norm = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply Interpolate -> Concat Coords -> Conv -> BN -> ReLU.
|
||||
@ -562,7 +562,7 @@ class FreqCoordConvUpBlock(nn.Module):
|
||||
freq_info = self.coords.repeat(op.shape[0], 1, 1, op.shape[3])
|
||||
op = torch.cat((op, freq_info), 1)
|
||||
op = self.conv(op)
|
||||
op = F.relu(self.conv_bn(op), inplace=True)
|
||||
op = F.relu(self.batch_norm(op), inplace=True)
|
||||
return op
|
||||
|
||||
|
||||
@ -625,7 +625,7 @@ class StandardConvUpBlock(nn.Module):
|
||||
kernel_size=kernel_size,
|
||||
padding=pad_size,
|
||||
)
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
self.batch_norm = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply Interpolate -> Conv -> BN -> ReLU.
|
||||
@ -650,7 +650,7 @@ class StandardConvUpBlock(nn.Module):
|
||||
align_corners=False,
|
||||
)
|
||||
op = self.conv(op)
|
||||
op = F.relu(self.conv_bn(op), inplace=True)
|
||||
op = F.relu(self.batch_norm(op), inplace=True)
|
||||
return op
|
||||
|
||||
|
||||
|
||||
@ -51,7 +51,7 @@ __all__ = [
|
||||
DEFAULT_DETECTION_THRESHOLD = 0.01
|
||||
|
||||
|
||||
TOP_K_PER_SEC = 200
|
||||
TOP_K_PER_SEC = 100
|
||||
|
||||
|
||||
class PostprocessConfig(BaseConfig):
|
||||
@ -206,11 +206,13 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
||||
if clips is None:
|
||||
return detections
|
||||
|
||||
width = output.detection_probs.shape[-1]
|
||||
duration = width / self.samplerate
|
||||
return [
|
||||
map_detection_to_clip(
|
||||
detection,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
end_time=clip.start_time + duration,
|
||||
min_freq=self.min_freq,
|
||||
max_freq=self.max_freq,
|
||||
)
|
||||
@ -220,9 +222,9 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
||||
|
||||
def get_raw_predictions(
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
targets: TargetProtocol,
|
||||
postprocessor: PostprocessorProtocol,
|
||||
clips: Optional[List[data.Clip]] = None,
|
||||
) -> List[List[RawPrediction]]:
|
||||
"""Extract intermediate RawPrediction objects for a batch.
|
||||
|
||||
@ -259,9 +261,9 @@ def get_sound_event_predictions(
|
||||
) -> List[List[BatDetect2Prediction]]:
|
||||
raw_predictions = get_raw_predictions(
|
||||
output,
|
||||
clips,
|
||||
targets=targets,
|
||||
postprocessor=postprocessor,
|
||||
clips=clips,
|
||||
)
|
||||
return [
|
||||
[
|
||||
@ -308,9 +310,9 @@ def get_predictions(
|
||||
"""
|
||||
raw_predictions = get_raw_predictions(
|
||||
output,
|
||||
clips,
|
||||
targets=targets,
|
||||
postprocessor=postprocessor,
|
||||
clips=clips,
|
||||
)
|
||||
return [
|
||||
convert_raw_predictions_to_clip_prediction(
|
||||
|
||||
@ -14,7 +14,7 @@ from batdetect2.data.conditions import (
|
||||
SoundEventConditionConfig,
|
||||
build_sound_event_condition,
|
||||
)
|
||||
from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig
|
||||
from batdetect2.targets.rois import ROIMapperConfig
|
||||
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
|
||||
|
||||
__all__ = [
|
||||
@ -140,7 +140,6 @@ DEFAULT_CLASSES = [
|
||||
TargetClassConfig(
|
||||
name="rhihip",
|
||||
tags=[data.Tag(key="class", value="Rhinolophus hipposideros")],
|
||||
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
||||
),
|
||||
TargetClassConfig(
|
||||
name="nyclei",
|
||||
@ -149,7 +148,6 @@ DEFAULT_CLASSES = [
|
||||
TargetClassConfig(
|
||||
name="rhifer",
|
||||
tags=[data.Tag(key="class", value="Rhinolophus ferrumequinum")],
|
||||
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
||||
),
|
||||
TargetClassConfig(
|
||||
name="pleaur",
|
||||
|
||||
@ -52,7 +52,7 @@ class ValLoaderConfig(BaseConfig):
|
||||
num_workers: int = 0
|
||||
|
||||
clipping_strategy: ClipConfig = Field(
|
||||
default_factory=lambda: RandomClipConfig()
|
||||
default_factory=lambda: PaddedClipConfig()
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -14,7 +14,8 @@ from loguru import logger
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.targets import iterate_encoded_sound_events
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
from batdetect2.targets import build_targets, iterate_encoded_sound_events
|
||||
from batdetect2.typing import (
|
||||
ClipLabeller,
|
||||
Heatmaps,
|
||||
@ -45,9 +46,9 @@ class LabelConfig(BaseConfig):
|
||||
|
||||
|
||||
def build_clip_labeler(
|
||||
targets: TargetProtocol,
|
||||
min_freq: float,
|
||||
max_freq: float,
|
||||
targets: Optional[TargetProtocol] = None,
|
||||
min_freq: float = MIN_FREQ,
|
||||
max_freq: float = MAX_FREQ,
|
||||
config: Optional[LabelConfig] = None,
|
||||
) -> ClipLabeller:
|
||||
"""Construct the final clip labelling function."""
|
||||
@ -56,6 +57,10 @@ def build_clip_labeler(
|
||||
"Building clip labeler with config: \n{}",
|
||||
lambda: config.to_yaml_string(),
|
||||
)
|
||||
|
||||
if targets is None:
|
||||
targets = build_targets()
|
||||
|
||||
return partial(
|
||||
generate_heatmaps,
|
||||
targets=targets,
|
||||
|
||||
@ -226,9 +226,9 @@ def build_trainer(
|
||||
|
||||
def build_train_loader(
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
audio_loader: AudioLoader,
|
||||
labeller: ClipLabeller,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
labeller: Optional[ClipLabeller] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
config: Optional[TrainLoaderConfig] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
) -> DataLoader:
|
||||
@ -260,9 +260,9 @@ def build_train_loader(
|
||||
|
||||
def build_val_loader(
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
audio_loader: AudioLoader,
|
||||
labeller: ClipLabeller,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
labeller: Optional[ClipLabeller] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
config: Optional[ValLoaderConfig] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
):
|
||||
@ -293,9 +293,9 @@ def build_val_loader(
|
||||
|
||||
def build_train_dataset(
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
audio_loader: AudioLoader,
|
||||
labeller: ClipLabeller,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
labeller: Optional[ClipLabeller] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
config: Optional[TrainLoaderConfig] = None,
|
||||
) -> TrainingDataset:
|
||||
logger.info("Building training dataset...")
|
||||
@ -303,6 +303,18 @@ def build_train_dataset(
|
||||
|
||||
clipper = build_clipper(config=config.clipping_strategy)
|
||||
|
||||
if audio_loader is None:
|
||||
audio_loader = build_audio_loader()
|
||||
|
||||
if preprocessor is None:
|
||||
preprocessor = build_preprocessor()
|
||||
|
||||
if labeller is None:
|
||||
labeller = build_clip_labeler(
|
||||
min_freq=preprocessor.min_freq,
|
||||
max_freq=preprocessor.max_freq,
|
||||
)
|
||||
|
||||
random_example_source = RandomAudioSource(
|
||||
clip_annotations,
|
||||
audio_loader=audio_loader,
|
||||
@ -332,14 +344,26 @@ def build_train_dataset(
|
||||
|
||||
def build_val_dataset(
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
audio_loader: AudioLoader,
|
||||
labeller: ClipLabeller,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
labeller: Optional[ClipLabeller] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
config: Optional[ValLoaderConfig] = None,
|
||||
) -> ValidationDataset:
|
||||
logger.info("Building validation dataset...")
|
||||
config = config or ValLoaderConfig()
|
||||
|
||||
if audio_loader is None:
|
||||
audio_loader = build_audio_loader()
|
||||
|
||||
if preprocessor is None:
|
||||
preprocessor = build_preprocessor()
|
||||
|
||||
if labeller is None:
|
||||
labeller = build_clip_labeler(
|
||||
min_freq=preprocessor.min_freq,
|
||||
max_freq=preprocessor.max_freq,
|
||||
)
|
||||
|
||||
clipper = build_clipper(config.clipping_strategy)
|
||||
return ValidationDataset(
|
||||
clip_annotations,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user