diff --git a/src/batdetect2/models/blocks.py b/src/batdetect2/models/blocks.py
index 6e64542..ebc380e 100644
--- a/src/batdetect2/models/blocks.py
+++ b/src/batdetect2/models/blocks.py
@@ -225,7 +225,7 @@ class ConvBlock(nn.Module):
             kernel_size=kernel_size,
             padding=pad_size,
         )
-        self.conv_bn = nn.BatchNorm2d(out_channels)
+        self.batch_norm = nn.BatchNorm2d(out_channels)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply Conv -> BN -> ReLU.
@@ -240,7 +240,7 @@ class ConvBlock(nn.Module):
         torch.Tensor
             Output tensor, shape `(B, C_out, H, W)`.
         """
-        return F.relu_(self.conv_bn(self.conv(x)))
+        return F.relu_(self.batch_norm(self.conv(x)))
 
 
 class VerticalConv(nn.Module):
@@ -364,7 +364,7 @@ class FreqCoordConvDownBlock(nn.Module):
             padding=pad_size,
             stride=1,
         )
-        self.conv_bn = nn.BatchNorm2d(out_channels)
+        self.batch_norm = nn.BatchNorm2d(out_channels)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply CoordF -> Conv -> MaxPool -> BN -> ReLU.
@@ -383,7 +383,7 @@
         freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
         x = torch.cat((x, freq_info), 1)
         x = F.max_pool2d(self.conv(x), 2, 2)
-        x = F.relu(self.conv_bn(x), inplace=True)
+        x = F.relu(self.batch_norm(x), inplace=True)
         return x
 
 
@@ -438,7 +438,7 @@ class StandardConvDownBlock(nn.Module):
             padding=pad_size,
             stride=1,
         )
-        self.conv_bn = nn.BatchNorm2d(out_channels)
+        self.batch_norm = nn.BatchNorm2d(out_channels)
 
     def forward(self, x):
         """Apply Conv -> MaxPool -> BN -> ReLU.
@@ -454,7 +454,7 @@
             Output tensor, shape `(B, C_out, H/2, W/2)`.
         """
         x = F.max_pool2d(self.conv(x), 2, 2)
-        return F.relu(self.conv_bn(x), inplace=True)
+        return F.relu(self.batch_norm(x), inplace=True)
 
 
 class FreqCoordConvUpConfig(BaseConfig):
@@ -534,7 +534,7 @@ class FreqCoordConvUpBlock(nn.Module):
             kernel_size=kernel_size,
             padding=pad_size,
         )
-        self.conv_bn = nn.BatchNorm2d(out_channels)
+        self.batch_norm = nn.BatchNorm2d(out_channels)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply Interpolate -> Concat Coords -> Conv -> BN -> ReLU.
@@ -562,7 +562,7 @@
         freq_info = self.coords.repeat(op.shape[0], 1, 1, op.shape[3])
         op = torch.cat((op, freq_info), 1)
         op = self.conv(op)
-        op = F.relu(self.conv_bn(op), inplace=True)
+        op = F.relu(self.batch_norm(op), inplace=True)
         return op
 
 
@@ -625,7 +625,7 @@ class StandardConvUpBlock(nn.Module):
             kernel_size=kernel_size,
             padding=pad_size,
         )
-        self.conv_bn = nn.BatchNorm2d(out_channels)
+        self.batch_norm = nn.BatchNorm2d(out_channels)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply Interpolate -> Conv -> BN -> ReLU.
@@ -650,7 +650,7 @@ class StandardConvUpBlock(nn.Module):
             align_corners=False,
         )
        op = self.conv(op)
-        op = F.relu(self.conv_bn(op), inplace=True)
+        op = F.relu(self.batch_norm(op), inplace=True)
         return op
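
The blocks.py hunks above are a pure rename of the `conv_bn` submodule to `batch_norm`. Because the attribute is a registered module, the rename also changes checkpoint state_dict keys ("conv_bn.weight" becomes "batch_norm.weight"), so checkpoints saved before this patch need their keys remapped before loading. A minimal migration sketch, assuming no other key contains the substring "conv_bn"; the helper name is hypothetical and not part of the patch:

import torch

def remap_conv_bn_keys(state_dict: dict) -> dict:
    # Rename legacy "conv_bn" entries to the new "batch_norm" attribute name.
    return {
        key.replace("conv_bn", "batch_norm"): value
        for key, value in state_dict.items()
    }

# Usage sketch:
# state = torch.load("old_checkpoint.pt", map_location="cpu")
# model.load_state_dict(remap_conv_bn_keys(state))
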
diff --git a/src/batdetect2/postprocess/__init__.py b/src/batdetect2/postprocess/__init__.py
index b1c0a3a..fec9c91 100644
--- a/src/batdetect2/postprocess/__init__.py
+++ b/src/batdetect2/postprocess/__init__.py
@@ -51,7 +51,7 @@ __all__ = [
 
 DEFAULT_DETECTION_THRESHOLD = 0.01
 
-TOP_K_PER_SEC = 200
+TOP_K_PER_SEC = 100
 
 
 class PostprocessConfig(BaseConfig):
@@ -206,11 +206,13 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
         if clips is None:
             return detections
 
+        width = output.detection_probs.shape[-1]
+        duration = width / self.samplerate
         return [
             map_detection_to_clip(
                 detection,
                 start_time=clip.start_time,
-                end_time=clip.end_time,
+                end_time=clip.start_time + duration,
                 min_freq=self.min_freq,
                 max_freq=self.max_freq,
             )
@@ -220,9 +222,9 @@
 
 
 def get_raw_predictions(
     output: ModelOutput,
-    clips: List[data.Clip],
     targets: TargetProtocol,
     postprocessor: PostprocessorProtocol,
+    clips: Optional[List[data.Clip]] = None,
 ) -> List[List[RawPrediction]]:
     """Extract intermediate RawPrediction objects for a batch.
@@ -259,9 +261,9 @@ def get_sound_event_predictions(
 ) -> List[List[BatDetect2Prediction]]:
     raw_predictions = get_raw_predictions(
         output,
-        clips,
         targets=targets,
         postprocessor=postprocessor,
+        clips=clips,
     )
     return [
         [
@@ -308,9 +310,9 @@
     """
     raw_predictions = get_raw_predictions(
         output,
-        clips,
         targets=targets,
         postprocessor=postprocessor,
+        clips=clips,
     )
     return [
         convert_raw_predictions_to_clip_prediction(
diff --git a/src/batdetect2/targets/classes.py b/src/batdetect2/targets/classes.py
index 4c73e1a..2f660f0 100644
--- a/src/batdetect2/targets/classes.py
+++ b/src/batdetect2/targets/classes.py
@@ -14,7 +14,7 @@ from batdetect2.data.conditions import (
     SoundEventConditionConfig,
     build_sound_event_condition,
 )
-from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig
+from batdetect2.targets.rois import ROIMapperConfig
 from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
 
 __all__ = [
@@ -140,7 +140,6 @@ DEFAULT_CLASSES = [
     TargetClassConfig(
         name="rhihip",
         tags=[data.Tag(key="class", value="Rhinolophus hipposideros")],
-        roi=AnchorBBoxMapperConfig(anchor="top-left"),
     ),
     TargetClassConfig(
         name="nyclei",
@@ -149,7 +148,6 @@ DEFAULT_CLASSES = [
     TargetClassConfig(
         name="rhifer",
         tags=[data.Tag(key="class", value="Rhinolophus ferrumequinum")],
-        roi=AnchorBBoxMapperConfig(anchor="top-left"),
     ),
     TargetClassConfig(
         name="pleaur",
diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py
index bcebfcb..e4d586a 100644
--- a/src/batdetect2/train/config.py
+++ b/src/batdetect2/train/config.py
@@ -52,7 +52,7 @@ class ValLoaderConfig(BaseConfig):
     num_workers: int = 0
 
     clipping_strategy: ClipConfig = Field(
-        default_factory=lambda: RandomClipConfig()
+        default_factory=lambda: PaddedClipConfig()
     )
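
Two notes on the postprocess hunks above. First, `clips` moves from the second positional parameter of get_raw_predictions to a trailing keyword argument with a None default, so callers must pass it by keyword; the internal call sites in get_sound_event_predictions and get_predictions are updated accordingly. Second, a detection's end_time is now derived from the width of the detection heatmap rather than taken from the clip's declared end, so the mapped time span follows the model output's actual extent. A self-contained sketch of that arithmetic; the numbers are made up for illustration, and in the patch the real values come from output.detection_probs and the postprocessor's samplerate:

# Illustrative values only, not taken from the codebase.
width = 512          # output.detection_probs.shape[-1]
samplerate = 256.0   # detection heatmap frames per second
start_time = 10.0    # clip.start_time in seconds

duration = width / samplerate
end_time = start_time + duration
print(end_time)  # 12.0, independent of the clip's nominal end_time
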
diff --git a/src/batdetect2/train/labels.py b/src/batdetect2/train/labels.py
index 163e787..868738b 100644
--- a/src/batdetect2/train/labels.py
+++ b/src/batdetect2/train/labels.py
@@ -14,7 +14,8 @@
 from loguru import logger
 from soundevent import data
 from batdetect2.configs import BaseConfig, load_config
-from batdetect2.targets import iterate_encoded_sound_events
+from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
+from batdetect2.targets import build_targets, iterate_encoded_sound_events
 from batdetect2.typing import (
     ClipLabeller,
     Heatmaps,
@@ -45,9 +46,9 @@ class LabelConfig(BaseConfig):
 
 
 def build_clip_labeler(
-    targets: TargetProtocol,
-    min_freq: float,
-    max_freq: float,
+    targets: Optional[TargetProtocol] = None,
+    min_freq: float = MIN_FREQ,
+    max_freq: float = MAX_FREQ,
     config: Optional[LabelConfig] = None,
 ) -> ClipLabeller:
     """Construct the final clip labelling function."""
@@ -56,6 +57,10 @@
         "Building clip labeler with config: \n{}",
         lambda: config.to_yaml_string(),
     )
+
+    if targets is None:
+        targets = build_targets()
+
     return partial(
         generate_heatmaps,
         targets=targets,
diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py
index dfb4770..660a2f1 100644
--- a/src/batdetect2/train/train.py
+++ b/src/batdetect2/train/train.py
@@ -226,9 +226,9 @@ def build_trainer(
 
 def build_train_loader(
     clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: AudioLoader,
-    labeller: ClipLabeller,
-    preprocessor: PreprocessorProtocol,
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
     config: Optional[TrainLoaderConfig] = None,
     num_workers: Optional[int] = None,
 ) -> DataLoader:
@@ -260,9 +260,9 @@
 
 def build_val_loader(
     clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: AudioLoader,
-    labeller: ClipLabeller,
-    preprocessor: PreprocessorProtocol,
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
     config: Optional[ValLoaderConfig] = None,
     num_workers: Optional[int] = None,
 ):
@@ -293,9 +293,9 @@
 
 def build_train_dataset(
     clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: AudioLoader,
-    labeller: ClipLabeller,
-    preprocessor: PreprocessorProtocol,
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
     config: Optional[TrainLoaderConfig] = None,
 ) -> TrainingDataset:
     logger.info("Building training dataset...")
@@ -303,6 +303,18 @@
 
     clipper = build_clipper(config=config.clipping_strategy)
 
+    if audio_loader is None:
+        audio_loader = build_audio_loader()
+
+    if preprocessor is None:
+        preprocessor = build_preprocessor()
+
+    if labeller is None:
+        labeller = build_clip_labeler(
+            min_freq=preprocessor.min_freq,
+            max_freq=preprocessor.max_freq,
+        )
+
     random_example_source = RandomAudioSource(
         clip_annotations,
         audio_loader=audio_loader,
@@ -332,14 +344,26 @@
 
 def build_val_dataset(
     clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: AudioLoader,
-    labeller: ClipLabeller,
-    preprocessor: PreprocessorProtocol,
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
     config: Optional[ValLoaderConfig] = None,
 ) -> ValidationDataset:
     logger.info("Building validation dataset...")
     config = config or ValLoaderConfig()
 
+    if audio_loader is None:
+        audio_loader = build_audio_loader()
+
+    if preprocessor is None:
+        preprocessor = build_preprocessor()
+
+    if labeller is None:
+        labeller = build_clip_labeler(
+            min_freq=preprocessor.min_freq,
+            max_freq=preprocessor.max_freq,
+        )
+
     clipper = build_clipper(config.clipping_strategy)
     return ValidationDataset(
         clip_annotations,
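
Taken together, the labels.py and train.py hunks make every collaborator optional with a lazily built default: explicit-injection call sites keep working unchanged, while the minimal call shrinks to just the clip annotations. A hedged usage sketch; the import paths follow the file paths in this diff, and `annotations` is assumed to be a Sequence[data.ClipAnnotation] obtained elsewhere:

from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.train import build_train_dataset, build_val_dataset

# targets, min_freq and max_freq may now all be omitted; build_targets()
# and the MIN_FREQ/MAX_FREQ constants supply the defaults.
labeller = build_clip_labeler()

# The dataset builders construct their own audio loader, preprocessor and
# labeller when none are injected.
train_dataset = build_train_dataset(annotations)
val_dataset = build_val_dataset(annotations)  # now clipped via PaddedClipConfig
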