From b93d4c65c29aeff425e2b6157781f1822707348f Mon Sep 17 00:00:00 2001
From: mbsantiago
Date: Sat, 12 Apr 2025 16:48:40 +0100
Subject: [PATCH] Create separate targets module

---
 batdetect2/cli/train.py                      |   4 +-
 batdetect2/compat/params.py                  |  19 +-
 batdetect2/data/__init__.py                  |   6 +
 batdetect2/data/annotations.py               |   3 -
 batdetect2/data/annotations/types.py         |   4 -
 batdetect2/evaluate/__init__.py              |   9 +
 batdetect2/evaluate/evaluate.py              |  23 +-
 batdetect2/models/__init__.py                |  80 +---
 batdetect2/models/backbones.py               |   4 +-
 batdetect2/models/build.py                   |  58 +++
 batdetect2/models/config.py                  |  34 ++
 batdetect2/models/decoder.py                 |  15 -
 batdetect2/models/encoder.py                 |  15 -
 batdetect2/models/{typing.py => types.py}    |   0
 batdetect2/modules.py                        |  36 +-
 batdetect2/post_process.py                   |   2 +-
 batdetect2/postprocess/__init__.py           |   0
 batdetect2/postprocess/arrays.py             |  73 ++++
 batdetect2/postprocess/config.py             |  32 ++
 batdetect2/postprocess/non_max_supression.py |  50 +++
 batdetect2/postprocess/types.py              |  17 +
 batdetect2/preprocess/__init__.py            |   4 +
 batdetect2/preprocess/audio.py               |   8 +-
 batdetect2/preprocess/spectrogram.py         |  22 +-
 batdetect2/targets/__init__.py               |  49 +++
 batdetect2/{train => targets}/labels.py      |   0
 batdetect2/{train => targets}/targets.py     |   2 +-
 batdetect2/targets/terms.py                  | 370 +++++++++++++++++++
 batdetect2/terms.py                          |  88 -----
 batdetect2/train/__init__.py                 |  15 +-
 batdetect2/train/callbacks.py                |  55 +++
 batdetect2/train/legacy/train.py             |   2 +-
 batdetect2/train/losses.py                   |   2 +-
 batdetect2/train/preprocess.py               |   7 +-
 pyproject.toml                               |  12 +-
 tests/conftest.py                            | 110 +++++-
 tests/test_data/test_batdetect.py            |  12 +-
 tests/test_migration/test_training.py        |  20 +-
 tests/test_postprocessing/__init__.py        |   0
 tests/test_postprocessing/test_arrays.py     |  43 +++
 tests/test_targets/__init__.py               |   0
 tests/test_targets/test_terms.py             | 175 +++++++++
 tests/test_train/test_augmentations.py       |  49 ++-
 tests/test_train/test_labels.py              |   2 +-
 44 files changed, 1227 insertions(+), 304 deletions(-)
 create mode 100644 batdetect2/models/build.py
 create mode 100644 batdetect2/models/config.py
 delete mode 100644 batdetect2/models/decoder.py
 delete mode 100644 batdetect2/models/encoder.py
 rename batdetect2/models/{typing.py => types.py} (100%)
 create mode 100644 batdetect2/postprocess/__init__.py
 create mode 100644 batdetect2/postprocess/arrays.py
 create mode 100644 batdetect2/postprocess/config.py
 create mode 100644 batdetect2/postprocess/non_max_supression.py
 create mode 100644 batdetect2/postprocess/types.py
 create mode 100644 batdetect2/targets/__init__.py
 rename batdetect2/{train => targets}/labels.py (100%)
 rename batdetect2/{train => targets}/targets.py (98%)
 create mode 100644 batdetect2/targets/terms.py
 delete mode 100644 batdetect2/terms.py
 create mode 100644 batdetect2/train/callbacks.py
 create mode 100644 tests/test_postprocessing/__init__.py
 create mode 100644 tests/test_postprocessing/test_arrays.py
 create mode 100644 tests/test_targets/__init__.py
 create mode 100644 tests/test_targets/test_terms.py

diff --git a/batdetect2/cli/train.py b/batdetect2/cli/train.py
index 79e1aef..60fbc45 100644
--- a/batdetect2/cli/train.py
+++ b/batdetect2/cli/train.py
@@ -8,9 +8,11 @@ from batdetect2.data import load_dataset_from_config
 from batdetect2.preprocess import (
     load_preprocessing_config,
 )
-from batdetect2.train import (
+from batdetect2.targets import (
     load_label_config,
     load_target_config,
+)
+from batdetect2.train import (
     preprocess_annotations,
 )
 
diff --git a/batdetect2/compat/params.py b/batdetect2/compat/params.py
index d19710f..2f9f767 100644
--- a/batdetect2/compat/params.py
+++ b/batdetect2/compat/params.py
@@ -12,10 +12,13 @@ from batdetect2.preprocess import (
     STFTConfig,
 )
 from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
-from batdetect2.terms import TagInfo
-from batdetect2.train.preprocess import (
+from batdetect2.targets import (
     HeatmapsConfig,
+    TagInfo,
     TargetConfig,
+)
+from batdetect2.targets.labels import LabelConfig
+from batdetect2.train.preprocess import (
     TrainPreprocessingConfig,
 )
 
@@ -91,11 +94,13 @@ def get_training_preprocessing_config(
                 for value in params["classes_to_ignore"]
             ],
         ),
-        heatmaps=HeatmapsConfig(
-            position="bottom-left",
-            time_scale=1 / time_bin_width,
-            frequency_scale=1 / freq_bin_width,
-            sigma=params["target_sigma"],
+        labels=LabelConfig(
+            heatmaps=HeatmapsConfig(
+                position="bottom-left",
+                time_scale=1 / time_bin_width,
+                frequency_scale=1 / freq_bin_width,
+                sigma=params["target_sigma"],
+            )
         ),
     )
 
diff --git a/batdetect2/data/__init__.py b/batdetect2/data/__init__.py
index 512ed87..104dfdb 100644
--- a/batdetect2/data/__init__.py
+++ b/batdetect2/data/__init__.py
@@ -1,12 +1,18 @@
 from batdetect2.data.annotations import (
     AnnotatedDataset,
+    AOEFAnnotations,
+    BatDetect2FilesAnnotations,
+    BatDetect2MergedAnnotations,
     load_annotated_dataset,
 )
 from batdetect2.data.data import load_dataset, load_dataset_from_config
 from batdetect2.data.types import Dataset
 
 __all__ = [
+    "AOEFAnnotations",
     "AnnotatedDataset",
+    "BatDetect2FilesAnnotations",
+    "BatDetect2MergedAnnotations",
     "Dataset",
     "load_annotated_dataset",
     "load_dataset",
diff --git a/batdetect2/data/annotations.py b/batdetect2/data/annotations.py
index a69b79f..a1e6ed2 100644
--- a/batdetect2/data/annotations.py
+++ b/batdetect2/data/annotations.py
@@ -1,4 +1,3 @@
-import json
 from pathlib import Path
 from typing import Literal, Union
 
@@ -32,5 +31,3 @@ AnnotationFormats = Union[
     BatDetect2AnnotationFile,
     AOEFAnnotationFile,
 ]
-
-
diff --git a/batdetect2/data/annotations/types.py b/batdetect2/data/annotations/types.py
index a184b23..188eb3d 100644
--- a/batdetect2/data/annotations/types.py
+++ b/batdetect2/data/annotations/types.py
@@ -1,11 +1,9 @@
 from pathlib import Path
-from typing import Literal, Union
 
 from batdetect2.configs import BaseConfig
 
 __all__ = [
     "AnnotatedDataset",
-    "BatDetect2MergedAnnotations",
 ]
 
 
@@ -37,5 +35,3 @@ class AnnotatedDataset(BaseConfig):
     name: str
     audio_dir: Path
     description: str = ""
-
-
diff --git a/batdetect2/evaluate/__init__.py b/batdetect2/evaluate/__init__.py
index e69de29..58c0b3f 100644
--- a/batdetect2/evaluate/__init__.py
+++ b/batdetect2/evaluate/__init__.py
@@ -0,0 +1,9 @@
+from batdetect2.evaluate.evaluate import (
+    compute_error_auc,
+    match_predictions_and_annotations,
+)
+
+__all__ = [
+    "compute_error_auc",
+    "match_predictions_and_annotations",
+]
diff --git a/batdetect2/evaluate/evaluate.py b/batdetect2/evaluate/evaluate.py
index 5e403b8..af8f603 100755
--- a/batdetect2/evaluate/evaluate.py
+++ b/batdetect2/evaluate/evaluate.py
@@ -1,13 +1,10 @@
 from typing import List
 
 import numpy as np
-import pandas as pd
 from sklearn.metrics import auc, roc_curve
 from soundevent import data
 from soundevent.evaluation import match_geometries
 
-from batdetect2.train.targets import build_target_encoder, get_class_names
-
 
 def match_predictions_and_annotations(
     clip_annotation: data.ClipAnnotation,
@@ -51,20 +48,13 @@ def match_predictions_and_annotations(
     return matches
 
 
-def build_evaluation_dataframe(matches: List[data.Match]) -> pd.DataFrame:
-    ret = []
-
-    for match in matches:
-        pass
-
-
 def compute_error_auc(op_str, gt, pred, prob):
     # classification error
     pred_int = (pred > prob).astype(np.int32)
     class_acc = (pred_int == gt).mean() * 100.0
 
     # ROC - area under curve
-    fpr, tpr, thresholds = roc_curve(gt, pred)
+    fpr, tpr, _ = roc_curve(gt, pred)
     roc_auc = auc(fpr, tpr)
 
     print(
@@ -177,7 +167,7 @@ def compute_pre_rec(
         file_ids.append([pid] * valid_inds.sum())
 
     confidence = np.hstack(confidence)
-    file_ids = np.hstack(file_ids).astype(np.int)
+    file_ids = np.hstack(file_ids).astype(int)
     pred_boxes = np.vstack(pred_boxes)
     if len(pred_class) > 0:
         pred_class = np.hstack(pred_class)
@@ -197,8 +187,7 @@ def compute_pre_rec(
 
         # note, files with the incorrect duration will cause a problem
         if (gg["start_times"] > file_dur).sum() > 0:
-            print("Error: file duration incorrect for", gg["id"])
-            assert False
+            raise ValueError(f"Error: file duration incorrect for {gg['id']}")
 
         boxes = np.vstack(
             (
@@ -244,6 +233,8 @@ def compute_pre_rec(
         gt_id = file_ids[ind]
 
         valid_det = False
+        det_ind = 0
+
         if gt_boxes[gt_id].shape[0] > 0:
             # compute overlap
             valid_det, det_ind = compute_affinity_1d(
@@ -273,7 +264,7 @@ def compute_pre_rec(
     # store threshold values - used for plotting
     conf_sorted = np.sort(confidence)[::-1][valid_inds]
     thresholds = np.linspace(0.1, 0.9, 9)
-    thresholds_inds = np.zeros(len(thresholds), dtype=np.int)
+    thresholds_inds = np.zeros(len(thresholds), dtype=int)
     for ii, tt in enumerate(thresholds):
         thresholds_inds[ii] = np.argmin(conf_sorted > tt)
     thresholds_inds[thresholds_inds == 0] = -1
@@ -385,7 +376,7 @@ def compute_file_accuracy(gts, preds, num_classes):
     ).mean(0)
     best_thresh = np.argmax(acc_per_thresh)
     best_acc = acc_per_thresh[best_thresh]
-    pred_valid = pred_valid_all[:, best_thresh].astype(np.int).tolist()
+    pred_valid = pred_valid_all[:, best_thresh].astype(int).tolist()
 
     res = {}
     res["num_valid_files"] = len(gt_valid)
diff --git a/batdetect2/models/__init__.py b/batdetect2/models/__init__.py
index 5a88a09..6d909dd 100644
--- a/batdetect2/models/__init__.py
+++ b/batdetect2/models/__init__.py
@@ -1,92 +1,25 @@
-from enum import Enum
-from typing import Optional, Tuple
-
-from soundevent.data import PathLike
-
-from batdetect2.configs import BaseConfig, load_config
 from batdetect2.models.backbones import (
     Net2DFast,
     Net2DFastNoAttn,
     Net2DFastNoCoordConv,
     Net2DPlain,
 )
+from batdetect2.models.build import build_architecture
+from batdetect2.models.config import ModelConfig, ModelType, load_model_config
 from batdetect2.models.heads import BBoxHead, ClassifierHead
-from batdetect2.models.typing import BackboneModel
+from batdetect2.models.types import BackboneModel, ModelOutput
 
 __all__ = [
     "BBoxHead",
+    "BackboneModel",
     "ClassifierHead",
     "ModelConfig",
+    "ModelOutput",
     "ModelType",
     "Net2DFast",
     "Net2DFastNoAttn",
     "Net2DFastNoCoordConv",
+    "Net2DPlain",
+    "build_architecture",
     "load_model_config",
 ]
-
-
-class ModelType(str, Enum):
-    Net2DFast = "Net2DFast"
-    Net2DFastNoAttn = "Net2DFastNoAttn"
-    Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
-    Net2DPlain = "Net2DPlain"
-
-
-class ModelConfig(BaseConfig):
-    name: ModelType = ModelType.Net2DFast
-    input_height: int = 128
-    encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
-    bottleneck_channels: int = 256
-    decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
-    out_channels: int = 32
-
-
-def load_model_config(
-    path: PathLike, field: Optional[str] = None
-) -> ModelConfig:
-    return load_config(path, schema=ModelConfig, field=field)
-
-
-def build_architecture(
-    config: Optional[ModelConfig] = None,
-) -> BackboneModel:
-    config = config or ModelConfig()
-
-    if config.name == ModelType.Net2DFast:
-        return Net2DFast(
-            input_height=config.input_height,
-            encoder_channels=config.encoder_channels,
-            bottleneck_channels=config.bottleneck_channels,
-            decoder_channels=config.decoder_channels,
-            out_channels=config.out_channels,
-        )
-
-    if config.name == ModelType.Net2DFastNoAttn:
-        return Net2DFastNoAttn(
-            input_height=config.input_height,
-            encoder_channels=config.encoder_channels,
-            bottleneck_channels=config.bottleneck_channels,
-            decoder_channels=config.decoder_channels,
-            out_channels=config.out_channels,
-        )
-
-    if config.name == ModelType.Net2DFastNoCoordConv:
-        return Net2DFastNoCoordConv(
-            input_height=config.input_height,
-            encoder_channels=config.encoder_channels,
-            bottleneck_channels=config.bottleneck_channels,
-            decoder_channels=config.decoder_channels,
-            out_channels=config.out_channels,
-        )
-
-    if config.name == ModelType.Net2DPlain:
-        return Net2DPlain(
-            input_height=config.input_height,
-            encoder_channels=config.encoder_channels,
-            bottleneck_channels=config.bottleneck_channels,
-            decoder_channels=config.decoder_channels,
-            out_channels=config.out_channels,
-        )
-
-    raise ValueError(f"Unknown model type: {config.name}")
diff --git a/batdetect2/models/backbones.py b/batdetect2/models/backbones.py
index 2e53f3b..5bdf5e4 100644
--- a/batdetect2/models/backbones.py
+++ b/batdetect2/models/backbones.py
@@ -12,12 +12,13 @@ from batdetect2.models.blocks import (
     UpscalingLayer,
     VerticalConv,
 )
-from batdetect2.models.typing import BackboneModel
+from batdetect2.models.types import BackboneModel
 
 __all__ = [
     "Net2DFast",
     "Net2DFastNoAttn",
     "Net2DFastNoCoordConv",
+    "Net2DPlain",
 ]
 
 
@@ -165,7 +166,6 @@ def pad_adjust(
     spec: torch.Tensor,
     factor: int = 32,
 ) -> Tuple[torch.Tensor, int, int]:
-    print(spec.shape)
     h, w = spec.shape[2:]
     h_pad = -h % factor
     w_pad = -w % factor
diff --git a/batdetect2/models/build.py b/batdetect2/models/build.py
new file mode 100644
index 0000000..474002c
--- /dev/null
+++ b/batdetect2/models/build.py
@@ -0,0 +1,58 @@
+from typing import Optional
+
+from batdetect2.models.backbones import (
+    Net2DFast,
+    Net2DFastNoAttn,
+    Net2DFastNoCoordConv,
+    Net2DPlain,
+)
+from batdetect2.models.config import ModelConfig, ModelType
+from batdetect2.models.types import BackboneModel
+
+__all__ = [
+    "build_architecture",
+]
+
+
+def build_architecture(
+    config: Optional[ModelConfig] = None,
+) -> BackboneModel:
+    config = config or ModelConfig()
+
+    if config.name == ModelType.Net2DFast:
+        return Net2DFast(
+            input_height=config.input_height,
+            encoder_channels=config.encoder_channels,
+            bottleneck_channels=config.bottleneck_channels,
+            decoder_channels=config.decoder_channels,
+            out_channels=config.out_channels,
+        )
+
+    if config.name == ModelType.Net2DFastNoAttn:
+        return Net2DFastNoAttn(
+            input_height=config.input_height,
+            encoder_channels=config.encoder_channels,
+            bottleneck_channels=config.bottleneck_channels,
+            decoder_channels=config.decoder_channels,
+            out_channels=config.out_channels,
+        )
+
+    if config.name == ModelType.Net2DFastNoCoordConv:
+        return Net2DFastNoCoordConv(
+            input_height=config.input_height,
+            encoder_channels=config.encoder_channels,
+            bottleneck_channels=config.bottleneck_channels,
+            decoder_channels=config.decoder_channels,
+            out_channels=config.out_channels,
+        )
+
+    if config.name == ModelType.Net2DPlain:
+        return Net2DPlain(
+            input_height=config.input_height,
+            encoder_channels=config.encoder_channels,
+            bottleneck_channels=config.bottleneck_channels,
+            decoder_channels=config.decoder_channels,
+            out_channels=config.out_channels,
+        )
+
+    raise ValueError(f"Unknown model type: {config.name}")
diff --git a/batdetect2/models/config.py b/batdetect2/models/config.py
new file mode 100644
index 0000000..1feef1d
--- /dev/null
+++ b/batdetect2/models/config.py
@@ -0,0 +1,34 @@
+from enum import Enum
+from typing import Optional, Tuple
+
+from soundevent.data import PathLike
+
+from batdetect2.configs import BaseConfig, load_config
+
+__all__ = [
+    "ModelType",
+    "ModelConfig",
+    "load_model_config",
+]
+
+
+class ModelType(str, Enum):
+    Net2DFast = "Net2DFast"
+    Net2DFastNoAttn = "Net2DFastNoAttn"
+    Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
+    Net2DPlain = "Net2DPlain"
+
+
+class ModelConfig(BaseConfig):
+    name: ModelType = ModelType.Net2DFast
+    input_height: int = 128
+    encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
+    bottleneck_channels: int = 256
+    decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
+    out_channels: int = 32
+
+
+def load_model_config(
+    path: PathLike, field: Optional[str] = None
+) -> ModelConfig:
+    return load_config(path, schema=ModelConfig, field=field)
diff --git a/batdetect2/models/decoder.py b/batdetect2/models/decoder.py
deleted file mode 100644
index 75863c0..0000000
--- a/batdetect2/models/decoder.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import sys
-from typing import Iterable, List, Literal, Sequence
-
-import torch
-from torch import nn
-
-from batdetect2.models.blocks import ConvBlockUpF, ConvBlockUpStandard
-
-if sys.version_info >= (3, 10):
-    from itertools import pairwise
-else:
-
-    def pairwise(iterable: Sequence) -> Iterable:
-        for x, y in zip(iterable[:-1], iterable[1:]):
-            yield x, y
diff --git a/batdetect2/models/encoder.py b/batdetect2/models/encoder.py
deleted file mode 100644
index 4192e62..0000000
--- a/batdetect2/models/encoder.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import sys
-from typing import Iterable, List, Literal, Sequence
-
-import torch
-from torch import nn
-
-from batdetect2.models.blocks import ConvBlockDownCoordF, ConvBlockDownStandard
-
-if sys.version_info >= (3, 10):
-    from itertools import pairwise
-else:
-
-    def pairwise(iterable: Sequence) -> Iterable:
-        for x, y in zip(iterable[:-1], iterable[1:]):
-            yield x, y
diff --git a/batdetect2/models/typing.py b/batdetect2/models/types.py
similarity index 100%
rename from batdetect2/models/typing.py
rename to batdetect2/models/types.py
diff --git a/batdetect2/modules.py b/batdetect2/modules.py
index d97a8ef..b379771 100644
--- a/batdetect2/modules.py
+++ b/batdetect2/modules.py
@@ -7,31 +7,27 @@ from pydantic import Field
 from soundevent import data
 from torch.optim.adam import Adam
 from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.utils.data import DataLoader
 
 from batdetect2.configs import BaseConfig
-from batdetect2.evaluate.evaluate import match_predictions_and_annotations
 from batdetect2.models import (
     BBoxHead,
     ClassifierHead,
     ModelConfig,
+    ModelOutput,
     build_architecture,
 )
-from batdetect2.models.typing import ModelOutput
 from batdetect2.post_process import (
     PostprocessConfig,
     postprocess_model_outputs,
 )
 from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
-from batdetect2.train.config import TrainingConfig
-from batdetect2.train.dataset import LabeledDataset, TrainExample
-from batdetect2.train.losses import compute_loss
-from batdetect2.train.targets import (
+from batdetect2.targets import (
     TargetConfig,
     build_decoder,
     build_target_encoder,
     get_class_names,
 )
+from batdetect2.train import TrainExample, TrainingConfig, compute_loss
 
 __all__ = [
     "DetectorModel",
@@ -83,12 +79,9 @@ class DetectorModel(L.LightningModule):
             replacement_rules=self.config.targets.replace,
         )
         self.decoder = build_decoder(self.config.targets.classes)
-
-        self.validation_predictions = []
-
         self.example_input_array = torch.randn([1, 1, 128, 512])
 
-    def forward(self, spec: torch.Tensor) -> ModelOutput:  # type: ignore
+    def forward(self, spec: torch.Tensor) -> ModelOutput:
         features = self.backbone(spec)
         detection_probs, classification_probs = self.classifier(features)
         size_preds = self.bbox(features)
@@ -130,27 +123,4 @@ class DetectorModel(L.LightningModule):
         self.log("val/loss/size", losses.total, logger=True)
         self.log("val/loss/classification", losses.total, logger=True)
 
-        dataloaders = self.trainer.val_dataloaders
-        assert isinstance(dataloaders, DataLoader)
-        dataset = dataloaders.dataset
-        assert isinstance(dataset, LabeledDataset)
-        clip_annotation = dataset.get_clip_annotation(batch_idx)
-
-        clip_prediction = postprocess_model_outputs(
-            outputs,
-            clips=[clip_annotation.clip],
-            classes=self.class_names,
-            decoder=self.decoder,
-            config=self.config.postprocessing,
-        )[0]
-
-        matches = match_predictions_and_annotations(
-            clip_annotation,
-            clip_prediction,
-        )
-
-        self.validation_predictions.extend(matches)
-
-    def on_validation_epoch_end(self) -> None:
-        self.validation_predictions.clear()
diff --git a/batdetect2/post_process.py b/batdetect2/post_process.py
index 85bb546..f266225 100644
--- a/batdetect2/post_process.py
+++ b/batdetect2/post_process.py
@@ -9,7 +9,7 @@ from soundevent import data
 from torch import nn
 
 from batdetect2.configs import BaseConfig, load_config
-from batdetect2.models.typing import ModelOutput
+from batdetect2.models.types import ModelOutput
 
 __all__ = [
     "PostprocessConfig",
diff --git a/batdetect2/postprocess/__init__.py b/batdetect2/postprocess/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/batdetect2/postprocess/arrays.py b/batdetect2/postprocess/arrays.py
new file mode 100644
index 0000000..f657e28
--- /dev/null
+++ b/batdetect2/postprocess/arrays.py
@@ -0,0 +1,73 @@
+import numpy as np
+import xarray as xr
+from soundevent.arrays import Dimensions
+
+from batdetect2.models import ModelOutput
+from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
+
+
+def to_xarray(
+    output: ModelOutput,
+    start_time: float,
+    end_time: float,
+    class_names: list[str],
+    min_freq: float = MIN_FREQ,
+    max_freq: float = MAX_FREQ,
+):
+    detection = output.detection_probs
+    size = output.size_preds
+    classes = output.class_probs
+    features = output.features
+
+    if len(detection.shape) == 4:
+        if detection.shape[0] != 1:
+            raise ValueError(
+                "Expected a non-batched output or a batch of size 1, instead "
+                f"got an input of shape {detection.shape}"
+            )
+
+        detection = detection.squeeze(dim=0)
+        size = size.squeeze(dim=0)
+        classes = classes.squeeze(dim=0)
+        features = features.squeeze(dim=0)
+
+    _, width, height = detection.shape
+
+    times = np.linspace(start_time, end_time, width, endpoint=False)
+    freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
+
+    if classes.shape[0] != len(class_names):
+        raise ValueError(
+            f"The number of classes does not coincide with the number of class names provided: ({classes.shape[0] = }) != ({len(class_names) = })"
+        )
+
+    return xr.Dataset(
+        data_vars={
+            "detection": (
+                [Dimensions.time.value, Dimensions.frequency.value],
+                detection.squeeze(dim=0).detach().numpy(),
+            ),
+            "size": (
+                [
+                    "dimension",
+                    Dimensions.time.value,
+                    Dimensions.frequency.value,
+                ],
+                size.detach().numpy(),
+            ),
+            "classes": (
+                [
+                    "category",
+                    Dimensions.time.value,
+                    Dimensions.frequency.value,
+                ],
+                classes.detach().numpy(),
+            ),
+        },
+        coords={
+            Dimensions.time.value: times,
+            Dimensions.frequency.value: freqs,
+            "dimension": ["width", "height"],
+            "category": class_names,
+        },
+    )
diff --git a/batdetect2/postprocess/config.py b/batdetect2/postprocess/config.py
new file mode 100644
index 0000000..3c4bada
--- /dev/null
+++ b/batdetect2/postprocess/config.py
@@ -0,0 +1,32 @@
+from typing import Optional
+
+from pydantic import Field
+from soundevent import data
+
+from batdetect2.configs import BaseConfig, load_config
+
+__all__ = [
+    "PostprocessConfig",
+    "load_postprocess_config",
+]
+
+NMS_KERNEL_SIZE = 9
+DETECTION_THRESHOLD = 0.01
+TOP_K_PER_SEC = 200
+
+
+class PostprocessConfig(BaseConfig):
+    """Configuration for postprocessing model outputs."""
+
+    nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
+    detection_threshold: float = Field(default=DETECTION_THRESHOLD, ge=0)
+    min_freq: int = Field(default=10000, gt=0)
+    max_freq: int = Field(default=120000, gt=0)
+    top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
+
+
+def load_postprocess_config(
+    path: data.PathLike,
+    field: Optional[str] = None,
+) -> PostprocessConfig:
+    return load_config(path, schema=PostprocessConfig, field=field)
+ """ + if isinstance(kernel_size, int): + kernel_size_h = kernel_size + kernel_size_w = kernel_size + else: + kernel_size_h, kernel_size_w = kernel_size + + pad_h = (kernel_size_h - 1) // 2 + pad_w = (kernel_size_w - 1) // 2 + + hmax = torch.nn.functional.max_pool2d( + tensor, + (kernel_size_h, kernel_size_w), + stride=1, + padding=(pad_h, pad_w), + ) + keep = (hmax == tensor).float() + return tensor * keep diff --git a/batdetect2/postprocess/types.py b/batdetect2/postprocess/types.py new file mode 100644 index 0000000..e3ea131 --- /dev/null +++ b/batdetect2/postprocess/types.py @@ -0,0 +1,17 @@ +from typing import Dict, NamedTuple + +import numpy as np + +__all__ = [ + "BatDetect2Prediction", +] + + +class BatDetect2Prediction(NamedTuple): + start_time: float + end_time: float + low_freq: float + high_freq: float + detection_score: float + class_scores: Dict[str, float] + features: np.ndarray diff --git a/batdetect2/preprocess/__init__.py b/batdetect2/preprocess/__init__.py index fcdc63b..ea4b0a4 100644 --- a/batdetect2/preprocess/__init__.py +++ b/batdetect2/preprocess/__init__.py @@ -15,6 +15,8 @@ from batdetect2.preprocess.config import ( load_preprocessing_config, ) from batdetect2.preprocess.spectrogram import ( + MAX_FREQ, + MIN_FREQ, AmplitudeScaleConfig, FrequencyConfig, LogScaleConfig, @@ -31,6 +33,8 @@ __all__ = [ "AudioConfig", "FrequencyConfig", "LogScaleConfig", + "MAX_FREQ", + "MIN_FREQ", "PcenScaleConfig", "PreprocessingConfig", "ResampleConfig", diff --git a/batdetect2/preprocess/audio.py b/batdetect2/preprocess/audio.py index e5cfab2..66820d2 100644 --- a/batdetect2/preprocess/audio.py +++ b/batdetect2/preprocess/audio.py @@ -31,7 +31,7 @@ def load_file_audio( path: data.PathLike, config: Optional[AudioConfig] = None, audio_dir: Optional[data.PathLike] = None, - dtype: DTypeLike = np.float32, + dtype: DTypeLike = np.float32, # type: ignore ) -> xr.DataArray: recording = data.Recording.from_file(path) return load_recording_audio( @@ -46,7 +46,7 @@ def load_recording_audio( recording: data.Recording, config: Optional[AudioConfig] = None, audio_dir: Optional[data.PathLike] = None, - dtype: DTypeLike = np.float32, + dtype: DTypeLike = np.float32, # type: ignore ) -> xr.DataArray: clip = data.Clip( recording=recording, @@ -65,7 +65,7 @@ def load_clip_audio( clip: data.Clip, config: Optional[AudioConfig] = None, audio_dir: Optional[data.PathLike] = None, - dtype: DTypeLike = np.float32, + dtype: DTypeLike = np.float32, # type: ignore ) -> xr.DataArray: config = config or AudioConfig() @@ -122,7 +122,7 @@ def resample_audio( wav: xr.DataArray, samplerate: int = TARGET_SAMPLERATE_HZ, mode: str = "poly", - dtype: DTypeLike = np.float32, + dtype: DTypeLike = np.float32, # type: ignore ) -> xr.DataArray: if "time" not in wav.dims: raise ValueError("Audio must have a time dimension") diff --git a/batdetect2/preprocess/spectrogram.py b/batdetect2/preprocess/spectrogram.py index bf228f5..7f7b2d1 100644 --- a/batdetect2/preprocess/spectrogram.py +++ b/batdetect2/preprocess/spectrogram.py @@ -11,6 +11,20 @@ from soundevent.arrays import operations as ops from batdetect2.configs import BaseConfig +__all__ = [ + "STFTConfig", + "FrequencyConfig", + "LogScaleConfig", + "PcenScaleConfig", + "AmplitudeScaleConfig", + "Scales", + "SpectrogramConfig", + "compute_spectrogram", +] + +MIN_FREQ = 10_000 +MAX_FREQ = 120_000 + class STFTConfig(BaseConfig): window_duration: float = Field(default=0.002, gt=0) @@ -70,7 +84,7 @@ class SpectrogramConfig(BaseConfig): def compute_spectrogram( wav: 
diff --git a/batdetect2/preprocess/spectrogram.py b/batdetect2/preprocess/spectrogram.py
index bf228f5..7f7b2d1 100644
--- a/batdetect2/preprocess/spectrogram.py
+++ b/batdetect2/preprocess/spectrogram.py
@@ -11,6 +11,20 @@ from soundevent.arrays import operations as ops
 
 from batdetect2.configs import BaseConfig
 
+__all__ = [
+    "STFTConfig",
+    "FrequencyConfig",
+    "LogScaleConfig",
+    "PcenScaleConfig",
+    "AmplitudeScaleConfig",
+    "Scales",
+    "SpectrogramConfig",
+    "compute_spectrogram",
+]
+
+MIN_FREQ = 10_000
+MAX_FREQ = 120_000
+
 
 class STFTConfig(BaseConfig):
     window_duration: float = Field(default=0.002, gt=0)
@@ -70,7 +84,7 @@ class SpectrogramConfig(BaseConfig):
 
 def compute_spectrogram(
     wav: xr.DataArray,
     config: Optional[SpectrogramConfig] = None,
-    dtype: DTypeLike = np.float32,
+    dtype: DTypeLike = np.float32,  # type: ignore
 ) -> xr.DataArray:
     config = config or SpectrogramConfig()
 
@@ -124,7 +138,7 @@ def stft(
     window_duration: float,
     window_overlap: float,
     window_fn: str = "hann",
-    dtype: DTypeLike = np.float32,
+    dtype: DTypeLike = np.float32,  # type: ignore
 ) -> xr.DataArray:
     start_time, end_time = arrays.get_dim_range(wave, dim="time")
     step = arrays.get_dim_step(wave, dim="time")
@@ -190,7 +204,7 @@ def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
 
 def scale_spectrogram(
     spec: xr.DataArray,
     scale: Scales,
-    dtype: DTypeLike = np.float32,
+    dtype: DTypeLike = np.float32,  # type: ignore
 ) -> xr.DataArray:
     if scale.name == "log":
         return scale_log(spec, dtype=dtype)
@@ -230,7 +244,7 @@ def scale_pcen(
 
 def scale_log(
     spec: xr.DataArray,
-    dtype: DTypeLike = np.float32,
+    dtype: DTypeLike = np.float32,  # type: ignore
 ) -> xr.DataArray:
     samplerate = spec.attrs["original_samplerate"]
     nfft = spec.attrs["nfft"]
diff --git a/batdetect2/targets/__init__.py b/batdetect2/targets/__init__.py
new file mode 100644
index 0000000..fddccda
--- /dev/null
+++ b/batdetect2/targets/__init__.py
@@ -0,0 +1,49 @@
+"""Module that helps define what the training targets are.
+
+The goal of this module is to configure how raw sound event annotations (tags)
+are processed to determine which events are relevant for training and what
+specific class label each relevant event should receive. It also defines how
+predicted class labels map back to the original tag system for interpretation.
+"""
+
+from batdetect2.targets.labels import (
+    HeatmapsConfig,
+    LabelConfig,
+    generate_heatmaps,
+    load_label_config,
+)
+from batdetect2.targets.targets import (
+    TargetConfig,
+    build_decoder,
+    build_sound_event_filter,
+    build_target_encoder,
+    filter_sound_event,
+    get_class_names,
+    load_target_config,
+)
+from batdetect2.targets.terms import (
+    TagInfo,
+    TermInfo,
+    call_type,
+    get_tag_from_info,
+    individual,
+)
+
+__all__ = [
+    "HeatmapsConfig",
+    "LabelConfig",
+    "TagInfo",
+    "TargetConfig",
+    "TermInfo",
+    "build_decoder",
+    "build_sound_event_filter",
+    "build_target_encoder",
+    "call_type",
+    "filter_sound_event",
+    "generate_heatmaps",
+    "get_class_names",
+    "get_tag_from_info",
+    "individual",
+    "load_label_config",
+    "load_target_config",
+]
diff --git a/batdetect2/train/labels.py b/batdetect2/targets/labels.py
similarity index 100%
rename from batdetect2/train/labels.py
rename to batdetect2/targets/labels.py
diff --git a/batdetect2/train/targets.py b/batdetect2/targets/targets.py
similarity index 98%
rename from batdetect2/train/targets.py
rename to batdetect2/targets/targets.py
index 0b622fd..e04005e 100644
--- a/batdetect2/train/targets.py
+++ b/batdetect2/targets/targets.py
@@ -7,7 +7,7 @@ from pydantic import Field
 from soundevent import data
 
 from batdetect2.configs import BaseConfig, load_config
-from batdetect2.terms import TagInfo, get_tag_from_info
+from batdetect2.targets.terms import TagInfo, get_tag_from_info
 
 __all__ = [
     "TargetConfig",
diff --git a/batdetect2/targets/terms.py b/batdetect2/targets/terms.py
new file mode 100644
index 0000000..f9b0a9c
--- /dev/null
+++ b/batdetect2/targets/terms.py
@@ -0,0 +1,370 @@
+"""Manages the vocabulary (Terms and Tags) for defining training targets.
+
+This module provides the necessary tools to declare, register, and manage the
+set of `soundevent.data.Term` objects used throughout the `batdetect2.targets`
+sub-package. It establishes a consistent vocabulary for filtering,
+transforming, and classifying sound events based on their annotations (Tags).
+
+The core component is the `TermRegistry`, which maps unique string keys
+(aliases) to specific `Term` definitions. This allows users to refer to complex
+terms using simple, consistent keys in configuration files and code.
+
+Terms can be pre-defined, loaded from the `soundevent.terms` library, defined
+programmatically, or loaded from external configuration files (e.g., YAML).
+"""
+
+from inspect import getmembers
+from typing import Dict, List, Optional
+
+from pydantic import BaseModel, Field
+from soundevent import data, terms
+
+from batdetect2.configs import load_config
+
+__all__ = [
+    "call_type",
+    "individual",
+    "get_tag_from_info",
+    "TermInfo",
+    "TagInfo",
+]
+
+# The default key used to reference the 'generic_class' term.
+# Often used implicitly when defining classification targets.
+GENERIC_CLASS_KEY = "class"
+
+
+call_type = data.Term(
+    name="soundevent:call_type",
+    label="Call Type",
+    definition=(
+        "A broad categorization of animal vocalizations based on their "
+        "intended function or purpose (e.g., social, distress, mating, "
+        "territorial, echolocation)."
+    ),
+)
+"""Term representing the broad functional category of a vocalization."""
+
+individual = data.Term(
+    name="soundevent:individual",
+    label="Individual",
+    definition=(
+        "An id for an individual animal. In the context of bioacoustic "
+        "annotation, this term is used to label vocalizations that are "
+        "attributed to a specific individual."
+    ),
+)
+"""Term used for tags identifying a specific individual animal."""
+
+generic_class = data.Term(
+    name="soundevent:class",
+    label="Class",
+    definition=(
+        "A generic term representing the name of a class within a "
+        "classification model. Its specific meaning is determined by "
+        "the model's application."
+    ),
+)
+"""Generic term representing a classification model's output class label."""
+
+
+class TermRegistry:
+    """Manages a registry mapping unique keys to Term definitions.
+
+    This class acts as the central repository for the vocabulary of terms
+    used within the target definition process. It allows registering terms
+    with simple string keys and retrieving them consistently.
+    """
+
+    def __init__(self, terms: Optional[Dict[str, data.Term]] = None):
+        """Initializes the TermRegistry.
+
+        Args:
+            terms: An optional dictionary of initial key-to-Term mappings
+                to populate the registry with. Defaults to an empty registry.
+        """
+        self._terms = terms or {}
+
+    def add_term(self, key: str, term: data.Term) -> None:
+        """Adds a Term object to the registry with the specified key.
+
+        Args:
+            key: The unique string key to associate with the term.
+            term: The soundevent.data.Term object to register.
+
+        Raises:
+            KeyError: If a term with the provided key already exists in the
+                registry.
+        """
+        if key in self._terms:
+            raise KeyError(f"A term with the key '{key}' already exists.")
+
+        self._terms[key] = term
+
+    def get_term(self, key: str) -> data.Term:
+        """Retrieves a registered term by its unique key.
+
+        Args:
+            key: The unique string key of the term to retrieve.
+
+        Returns:
+            The corresponding soundevent.data.Term object.
+
+        Raises:
+            KeyError: If no term with the specified key is found, with a
+                helpful message suggesting listing available keys.
+        """
+        try:
+            return self._terms[key]
+        except KeyError as err:
+            raise KeyError(
+                "No term found for key "
+                f"'{key}'. Ensure it is registered or loaded. "
+                "Use `get_term_keys()` to list available terms."
+            ) from err
+
+    def add_custom_term(
+        self,
+        key: str,
+        name: Optional[str] = None,
+        uri: Optional[str] = None,
+        label: Optional[str] = None,
+        definition: Optional[str] = None,
+    ) -> data.Term:
+        """Creates a new Term from attributes and adds it to the registry.
+
+        This is useful for defining terms directly in code or when loading
+        from configuration files where only attributes are provided.
+
+        If optional fields (`name`, `label`, `definition`) are not provided,
+        reasonable defaults are used (`key` for name/label, "Unknown" for
+        definition).
+
+        Args:
+            key: The unique string key for the new term.
+            name: The name for the new term (defaults to `key`).
+            uri: The URI for the new term (optional).
+            label: The display label for the new term (defaults to `key`).
+            definition: The definition for the new term (defaults to "Unknown").
+
+        Returns:
+            The newly created and registered soundevent.data.Term object.
+
+        Raises:
+            KeyError: If a term with the provided key already exists.
+        """
+        term = data.Term(
+            name=name or key,
+            label=label or key,
+            uri=uri,
+            definition=definition or "Unknown",
+        )
+        self.add_term(key, term)
+        return term
+
+    def get_keys(self) -> List[str]:
+        """Returns a list of all keys currently registered.
+
+        Returns:
+            A list of strings representing the keys of all registered terms.
+        """
+        return list(self._terms.keys())
+
+
+registry = TermRegistry(
+    terms=dict(
+        [
+            *getmembers(terms, lambda x: isinstance(x, data.Term)),
+            ("call_type", call_type),
+            ("individual", individual),
+            (GENERIC_CLASS_KEY, generic_class),
+        ]
+    )
+)
+"""The default, globally accessible TermRegistry instance.
+
+It is pre-populated with standard terms from `soundevent.terms` and common
+terms defined in this module (`call_type`, `individual`, `generic_class`).
+Functions in this module use this registry by default unless another instance
+is explicitly passed.
+"""
+
+
+def get_term_from_key(
+    key: str,
+    term_registry: TermRegistry = registry,
+) -> data.Term:
+    """Convenience function to retrieve a term by key from a registry.
+
+    Uses the global default registry unless a specific `term_registry`
+    instance is provided.
+
+    Args:
+        key: The unique key of the term to retrieve.
+        term_registry: The TermRegistry instance to search in. Defaults to
+            the global `registry`.
+
+    Returns:
+        The corresponding soundevent.data.Term object.
+
+    Raises:
+        KeyError: If the key is not found in the specified registry.
+    """
+    return term_registry.get_term(key)
+
+
+def get_term_keys(term_registry: TermRegistry = registry) -> List[str]:
+    """Convenience function to get all registered keys from a registry.
+
+    Uses the global default registry unless a specific `term_registry`
+    instance is provided.
+
+    Args:
+        term_registry: The TermRegistry instance to query. Defaults to
+            the global `registry`.
+
+    Returns:
+        A list of strings representing the keys of all registered terms.
+    """
+    return term_registry.get_keys()
+
+
+class TagInfo(BaseModel):
+    """Represents information needed to define a specific Tag.
+
+    This model is typically used in configuration files (e.g., YAML) to
+    specify tags used for filtering, target class definition, or associating
+    tags with output classes. It links a tag value to a term definition
+    via the term's registry key.
+
+    Attributes:
+        value: The value of the tag (e.g., "Myotis myotis", "Echolocation").
+        key: The key (alias) of the term associated with this tag, as
+            registered in the TermRegistry. Defaults to "class", implying
+            it represents a classification target label.
+    """
+
+    value: str
+    key: str = GENERIC_CLASS_KEY
+
+
+def get_tag_from_info(
+    tag_info: TagInfo,
+    term_registry: TermRegistry = registry,
+) -> data.Tag:
+    """Creates a soundevent.data.Tag object from TagInfo data.
+
+    Looks up the term using the key in the provided `tag_info` from the
+    specified registry and constructs a Tag object.
+
+    Args:
+        tag_info: The TagInfo object containing the value and term key.
+        term_registry: The TermRegistry instance to use for term lookup.
+            Defaults to the global `registry`.
+
+    Returns:
+        A soundevent.data.Tag object corresponding to the input info.
+
+    Raises:
+        KeyError: If the term key specified in `tag_info.key` is not found
+            in the registry.
+    """
+    term = get_term_from_key(tag_info.key, term_registry=term_registry)
+    return data.Tag(term=term, value=tag_info.value)
+
+
+class TermInfo(BaseModel):
+    """Represents the definition of a Term within a configuration file.
+
+    This model allows users to define custom terms directly in configuration
+    files (e.g., YAML) which can then be loaded into the TermRegistry.
+    It mirrors the parameters of `TermRegistry.add_custom_term`.
+
+    Attributes:
+        key: The unique key (alias) that will be used to register and
+            reference this term.
+        label: The optional display label for the term. Defaults to `key`
+            if not provided during registration.
+        name: The optional formal name for the term. Defaults to `key`
+            if not provided during registration.
+        uri: The optional URI identifying the term (e.g., from a standard
+            vocabulary).
+        definition: The optional textual definition of the term. Defaults to
+            "Unknown" if not provided during registration.
+    """
+
+    key: str
+    label: Optional[str] = None
+    name: Optional[str] = None
+    uri: Optional[str] = None
+    definition: Optional[str] = None
+
+
+class TermConfig(BaseModel):
+    """Pydantic schema for loading a list of term definitions from config.
+
+    This model typically corresponds to a section in a configuration file
+    (e.g., YAML) containing a list of term definitions to be registered.
+
+    Example YAML structure:
+    ```yaml
+    terms:
+      - key: species
+        uri: dwc:scientificName
+        label: Scientific Name
+      - key: my_custom_term
+        name: My Custom Term
+        definition: Describes a specific project attribute.
+      # ... more TermInfo definitions
+    ```
+
+    Attributes:
+        terms: A list of TermInfo objects, each defining a term to be
+            registered. Defaults to an empty list.
+    """
+
+    terms: List[TermInfo] = Field(default_factory=list)
+
+
+def load_terms_from_config(
+    path: data.PathLike,
+    field: Optional[str] = None,
+    term_registry: TermRegistry = registry,
+) -> Dict[str, data.Term]:
+    """Loads term definitions from a configuration file and registers them.
+
+    Parses a configuration file (e.g., YAML) using the TermConfig schema,
+    extracts the list of TermInfo definitions, and adds each one as a
+    custom term to the specified TermRegistry instance.
+
+    Args:
+        path: The path to the configuration file.
+        field: Optional key indicating a specific section within the config
+            file where the 'terms' list is located. If None, expects the
+            list directly at the top level or within a structure matching
+            TermConfig schema.
+        term_registry: The TermRegistry instance to add the loaded terms to.
+            Defaults to the global `registry`.
+
+    Returns:
+        A dictionary mapping the keys of the newly added terms to their
+        corresponding Term objects.
+
+    Raises:
+        FileNotFoundError: If the config file path does not exist.
+        ValidationError: If the config file structure does not match the
+            TermConfig schema.
+        KeyError: If a term key loaded from the config conflicts with a key
+            already present in the registry.
+    """
+    config = load_config(path, schema=TermConfig, field=field)
+    return {
+        info.key: term_registry.add_custom_term(
+            info.key,
+            name=info.name,
+            uri=info.uri,
+            label=info.label,
+            definition=info.definition,
+        )
+        for info in config.terms
+    }
diff --git a/batdetect2/terms.py b/batdetect2/terms.py
deleted file mode 100644
index e60e3f2..0000000
--- a/batdetect2/terms.py
+++ /dev/null
@@ -1,88 +0,0 @@
-from inspect import getmembers
-from typing import Optional
-
-from pydantic import BaseModel
-from soundevent import data, terms
-
-__all__ = [
-    "call_type",
-    "individual",
-    "get_term_from_info",
-    "get_tag_from_info",
-    "TermInfo",
-    "TagInfo",
-]
-
-
-class TermInfo(BaseModel):
-    label: Optional[str]
-    name: Optional[str]
-    uri: Optional[str]
-
-
-class TagInfo(BaseModel):
-    value: str
-    term: Optional[TermInfo] = None
-    key: Optional[str] = None
-    label: Optional[str] = None
-
-
-call_type = data.Term(
-    name="soundevent:call_type",
-    label="Call Type",
-    definition="A broad categorization of animal vocalizations based on their intended function or purpose (e.g., social, distress, mating, territorial, echolocation).",
-)
-
-individual = data.Term(
-    name="soundevent:individual",
-    label="Individual",
-    definition="An id for an individual animal. In the context of bioacoustic annotation, this term is used to label vocalizations that are attributed to a specific individual.",
-)
-
-
-ALL_TERMS = [
-    *getmembers(terms, lambda x: isinstance(x, data.Term)),
-    call_type,
-    individual,
-]
-
-
-def get_term_from_info(term_info: TermInfo) -> data.Term:
-    for term in ALL_TERMS:
-        if term_info.name and term_info.name == term.name:
-            return term
-
-        if term_info.label and term_info.label == term.label:
-            return term
-
-        if term_info.uri and term_info.uri == term.uri:
-            return term
-
-    if term_info.name is None:
-        if term_info.label is None:
-            raise ValueError("At least one of name or label must be provided.")
-
-        term_info.name = (
-            f"soundevent:{term_info.label.lower().replace(' ', '_')}"
-        )
-
-    if term_info.label is None:
-        term_info.label = term_info.name
-
-    return data.Term(
-        name=term_info.name,
-        label=term_info.label,
-        uri=term_info.uri,
-        definition="Unknown",
-    )
-
-
-def get_tag_from_info(tag_info: TagInfo) -> data.Tag:
-    if tag_info.term:
-        term = get_term_from_info(tag_info.term)
-    elif tag_info.key:
-        term = data.term_from_key(tag_info.key)
-    else:
-        raise ValueError("Either term or key must be provided in tag info.")
-
-    return data.Tag(term=term, value=tag_info.value)
"SubclipConfig", - "TagInfo", - "TargetConfig", "TrainExample", "TrainerConfig", "TrainingConfig", "add_echo", "augment_example", - "build_target_encoder", + "compute_loss", "generate_train_example", "get_preprocessed_files", "load_agumentation_config", - "load_label_config", - "load_target_config", "load_train_config", "load_trainer_config", "mask_frequency", diff --git a/batdetect2/train/callbacks.py b/batdetect2/train/callbacks.py new file mode 100644 index 0000000..b6bb05e --- /dev/null +++ b/batdetect2/train/callbacks.py @@ -0,0 +1,55 @@ +from lightning import LightningModule, Trainer +from lightning.pytorch.callbacks import Callback +from torch.utils.data import DataLoader + +from batdetect2.evaluate import match_predictions_and_annotations +from batdetect2.post_process import postprocess_model_outputs +from batdetect2.train.dataset import LabeledDataset, TrainExample +from batdetect2.types import ModelOutput + + +class ValidationMetrics(Callback): + def __init__(self): + super().__init__() + self.predictions = [] + + def on_validation_epoch_start( + self, + trainer: Trainer, + pl_module: LightningModule, + ) -> None: + self.predictions = [] + return super().on_validation_epoch_start(trainer, pl_module) + + def on_validation_batch_end( # type: ignore + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: ModelOutput, + batch: TrainExample, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + dataloaders = trainer.val_dataloaders + assert isinstance(dataloaders, DataLoader) + dataset = dataloaders.dataset + assert isinstance(dataset, LabeledDataset) + clip_annotation = dataset.get_clip_annotation(batch_idx) + + clip_prediction = postprocess_model_outputs( + outputs, + clips=[clip_annotation.clip], + classes=self.class_names, + decoder=self.decoder, + config=self.config.postprocessing, + )[0] + + matches = match_predictions_and_annotations( + clip_annotation, + clip_prediction, + ) + + self.validation_predictions.extend(matches) + return super().on_validation_batch_end( + trainer, pl_module, outputs, batch, batch_idx, dataloader_idx + ) diff --git a/batdetect2/train/legacy/train.py b/batdetect2/train/legacy/train.py index a59b13a..29fed77 100644 --- a/batdetect2/train/legacy/train.py +++ b/batdetect2/train/legacy/train.py @@ -6,7 +6,7 @@ from torch.optim import Adam from torch.optim.lr_scheduler import CosineAnnealingLR from torch.utils.data import DataLoader -from batdetect2.models.typing import DetectionModel +from batdetect2.models.types import DetectionModel from batdetect2.train.dataset import LabeledDataset diff --git a/batdetect2/train/losses.py b/batdetect2/train/losses.py index 74739d1..27a132c 100644 --- a/batdetect2/train/losses.py +++ b/batdetect2/train/losses.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from pydantic import Field from batdetect2.configs import BaseConfig -from batdetect2.models.typing import ModelOutput +from batdetect2.models.types import ModelOutput from batdetect2.train.dataset import TrainExample __all__ = [ diff --git a/batdetect2/train/preprocess.py b/batdetect2/train/preprocess.py index 8b9da80..3a67135 100644 --- a/batdetect2/train/preprocess.py +++ b/batdetect2/train/preprocess.py @@ -17,11 +17,12 @@ from batdetect2.preprocess import ( compute_spectrogram, load_clip_audio, ) -from batdetect2.train.labels import LabelConfig, generate_heatmaps -from batdetect2.train.targets import ( +from batdetect2.targets import ( + LabelConfig, TargetConfig, - build_target_encoder, build_sound_event_filter, + build_target_encoder, + 
diff --git a/batdetect2/train/preprocess.py b/batdetect2/train/preprocess.py
index 8b9da80..3a67135 100644
--- a/batdetect2/train/preprocess.py
+++ b/batdetect2/train/preprocess.py
@@ -17,11 +17,12 @@ from batdetect2.preprocess import (
     compute_spectrogram,
     load_clip_audio,
 )
-from batdetect2.train.labels import LabelConfig, generate_heatmaps
-from batdetect2.train.targets import (
+from batdetect2.targets import (
+    LabelConfig,
     TargetConfig,
-    build_target_encoder,
     build_sound_event_filter,
+    build_target_encoder,
+    generate_heatmaps,
     get_class_names,
 )
 
diff --git a/pyproject.toml b/pyproject.toml
index b52300f..0247083 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -91,7 +91,15 @@ convention = "numpy"
 
 [tool.pyright]
 include = ["batdetect2", "tests"]
-venvPath = "."
-venv = ".venv"
 pythonVersion = "3.9"
 pythonPlatform = "All"
+exclude = [
+  "batdetect2/detector/",
+  "batdetect2/finetune",
+  "batdetect2/utils",
+  "batdetect2/plotting",
+  "batdetect2/plot",
+  "batdetect2/api",
+  "batdetect2/evaluate/legacy",
+  "batdetect2/train/legacy",
+]
diff --git a/tests/conftest.py b/tests/conftest.py
index f34e8fd..f98459a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,11 +1,18 @@
 import uuid
 from pathlib import Path
-from typing import Callable, List, Optional
+from typing import Callable, Iterable, List, Optional
 
 import numpy as np
 import pytest
 import soundfile as sf
-from soundevent import data
+from soundevent import data, terms
+
+from batdetect2.targets import (
+    TargetConfig,
+    build_target_encoder,
+    call_type,
+    get_class_names,
+)
 
 
 @pytest.fixture
@@ -107,3 +114,102 @@ def recording_factory(wav_factory: Callable[..., Path]):
         )
 
     return _recording_factory
+
+
+@pytest.fixture
+def recording(
+    recording_factory: Callable[..., data.Recording],
+) -> data.Recording:
+    return recording_factory()
+
+
+@pytest.fixture
+def clip(recording: data.Recording) -> data.Clip:
+    return data.Clip(recording=recording, start_time=0, end_time=0.5)
+
+
+@pytest.fixture
+def echolocation_call(recording: data.Recording) -> data.SoundEventAnnotation:
+    return data.SoundEventAnnotation(
+        sound_event=data.SoundEvent(
+            geometry=data.BoundingBox(coordinates=[0.1, 67_000, 0.11, 73_000]),
+            recording=recording,
+        ),
+        tags=[
+            data.Tag(term=terms.scientific_name, value="Myotis myotis"),
+            data.Tag(term=call_type, value="Echolocation"),
+        ],
+    )
+
+
+@pytest.fixture
+def generic_call(recording: data.Recording) -> data.SoundEventAnnotation:
+    return data.SoundEventAnnotation(
+        sound_event=data.SoundEvent(
+            geometry=data.BoundingBox(
+                coordinates=[0.34, 35_000, 0.348, 62_000]
+            ),
+            recording=recording,
+        ),
+        tags=[
+            data.Tag(term=terms.order, value="Chiroptera"),
+            data.Tag(term=call_type, value="Echolocation"),
+        ],
+    )
+
+
+@pytest.fixture
+def non_relevant_sound_event(
+    recording: data.Recording,
+) -> data.SoundEventAnnotation:
+    return data.SoundEventAnnotation(
+        sound_event=data.SoundEvent(
+            geometry=data.BoundingBox(
+                coordinates=[0.22, 50_000, 0.24, 58_000]
+            ),
+            recording=recording,
+        ),
+        tags=[
+            data.Tag(
+                term=terms.scientific_name,
+                value="Muscardinus avellanarius",
+            ),
+        ],
+    )
+
+
+@pytest.fixture
+def clip_annotation(
+    clip: data.Clip,
+    echolocation_call: data.SoundEventAnnotation,
+    generic_call: data.SoundEventAnnotation,
+    non_relevant_sound_event: data.SoundEventAnnotation,
+) -> data.ClipAnnotation:
+    return data.ClipAnnotation(
+        clip=clip,
+        sound_events=[
+            echolocation_call,
+            generic_call,
+            non_relevant_sound_event,
+        ],
+    )
+
+
+@pytest.fixture
+def target_config() -> TargetConfig:
+    return TargetConfig()
+
+
+@pytest.fixture
+def class_names(target_config: TargetConfig) -> List[str]:
+    return get_class_names(target_config.classes)
+
+
+@pytest.fixture
+def encoder(
+    target_config: TargetConfig,
+) -> Callable[[Iterable[data.Tag]], Optional[str]]:
+    return build_target_encoder(
+        classes=target_config.classes,
+        replacement_rules=target_config.replace,
+    )
diff --git a/tests/test_data/test_batdetect.py b/tests/test_data/test_batdetect.py
index e99505c..baaf907 100644
---
 a/tests/test_data/test_batdetect.py
+++ b/tests/test_data/test_batdetect.py
@@ -4,7 +4,7 @@ from pathlib import Path
 
 from soundevent import data
 
-from batdetect2.compat.data import load_annotation_project_from_dir
+from batdetect2.data import BatDetect2FilesAnnotations, load_annotated_dataset
 
 ROOT_DIR = Path(__file__).parent.parent.parent
 
@@ -12,8 +12,14 @@ ROOT_DIR = Path(__file__).parent.parent.parent
 def test_load_example_annotation_project():
     path = ROOT_DIR / "example_data" / "anns"
     audio_dir = ROOT_DIR / "example_data" / "audio"
-    project = load_annotation_project_from_dir(path, audio_dir=audio_dir)
+    project = load_annotated_dataset(
+        BatDetect2FilesAnnotations(
+            name="test",
+            audio_dir=audio_dir,
+            annotations_dir=path,
+        )
+    )
     assert isinstance(project, data.AnnotationProject)
-    assert project.name == str(path)
+    assert project.name == "test"
     assert len(project.clip_annotations) == 3
     assert len(project.tasks) == 3
diff --git a/tests/test_migration/test_training.py b/tests/test_migration/test_training.py
index fa46a3e..e146c0a 100644
--- a/tests/test_migration/test_training.py
+++ b/tests/test_migration/test_training.py
@@ -5,8 +5,8 @@ from typing import List
 import numpy as np
 import pytest
 
-from batdetect2.compat.data import load_annotation_project_from_dir
 from batdetect2.compat.params import get_training_preprocessing_config
+from batdetect2.data import BatDetect2FilesAnnotations, load_annotated_dataset
 from batdetect2.train.preprocess import generate_train_example
 
 
@@ -26,6 +26,8 @@ def test_can_generate_similar_training_inputs(
     old_parameters = json.loads((regression_dir / "params.json").read_text())
 
     config = get_training_preprocessing_config(old_parameters)
 
+    assert config is not None
+
     for audio_file in example_audio_files:
         example_file = regression_dir / f"{audio_file.name}.npz"
 
@@ -36,9 +38,12 @@ def test_can_generate_similar_training_inputs(
         size_mask = dataset["size_mask"]
         class_mask = dataset["class_mask"]
 
-        project = load_annotation_project_from_dir(
-            example_anns_dir,
-            audio_dir=example_audio_dir,
+        project = load_annotated_dataset(
+            BatDetect2FilesAnnotations(
+                name="test",
+                annotations_dir=example_anns_dir,
+                audio_dir=example_audio_dir,
+            )
         )
 
         clip_annotation = next(
@@ -47,7 +52,12 @@ def test_can_generate_similar_training_inputs(
             if ann.clip.recording.path == audio_file
         )
 
-        new_dataset = generate_train_example(clip_annotation, config)
+        new_dataset = generate_train_example(
+            clip_annotation,
+            preprocessing_config=config.preprocessing,
+            target_config=config.target,
+            label_config=config.labels,
+        )
         new_spec = new_dataset["spectrogram"].values
         new_detection_mask = new_dataset["detection"].values
         new_size_mask = new_dataset["size"].values
diff --git a/tests/test_postprocessing/__init__.py b/tests/test_postprocessing/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_postprocessing/test_arrays.py b/tests/test_postprocessing/test_arrays.py
new file mode 100644
index 0000000..c94f64e
--- /dev/null
+++ b/tests/test_postprocessing/test_arrays.py
@@ -0,0 +1,35 @@
+from typing import List
+
+import torch
+import xarray as xr
+from soundevent import data
+
+from batdetect2.modules import DetectorModel
+from batdetect2.postprocess.arrays import to_xarray
+from batdetect2.preprocess import preprocess_audio_clip
+
+
+def test_to_xarray_returns_expected_dataset(
+    clip: data.Clip, class_names: List[str]
+):
+    model = DetectorModel()
+
+    spec = preprocess_audio_clip(
+        clip,
+        config=model.config.preprocessing,
+    )
+
+    tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)
+
+    outputs = model(tensor)
+
+    arrays = to_xarray(
+        outputs,
+        start_time=clip.start_time,
+        end_time=clip.end_time,
+        class_names=class_names,
+    )
+
+    assert isinstance(arrays, xr.Dataset)
+    assert set(arrays.data_vars) == {"detection", "size", "classes"}
+    assert list(arrays.coords["category"].values) == class_names
diff --git a/tests/test_targets/__init__.py b/tests/test_targets/__init__.py
new file mode 100644
index 0000000..e69de29
+    keys = terms.get_term_keys(term_registry=custom_registry)
+    assert "custom_key" in keys
+
+
+def test_tag_info_and_get_tag_from_info():
+    tag_info = TagInfo(value="Myotis myotis", key="call_type")
+    tag = terms.get_tag_from_info(tag_info)
+    assert tag.value == "Myotis myotis"
+    assert tag.term == terms.call_type
+
+
+def test_get_tag_from_info_key_not_found():
+    tag_info = TagInfo(value="test", key="non_existent_key")
+    with pytest.raises(KeyError):
+        terms.get_tag_from_info(tag_info)
+
+
+def test_load_terms_from_config(tmp_path):
+    config_data = {
+        "terms": [
+            {
+                "key": "species",
+                "name": "dwc:scientificName",
+                "label": "Scientific Name",
+            },
+            {
+                "key": "my_custom_term",
+                "name": "soundevent:custom_term",
+                "definition": "Describes a specific project attribute",
+            },
+        ]
+    }
+    config_file = tmp_path / "config.yaml"
+    with open(config_file, "w") as f:
+        yaml.dump(config_data, f)
+
+    loaded_terms = load_terms_from_config(config_file)
+    assert "species" in loaded_terms
+    assert "my_custom_term" in loaded_terms
+    assert loaded_terms["species"].name == "dwc:scientificName"
+    assert loaded_terms["my_custom_term"].name == "soundevent:custom_term"
+
+
+def test_load_terms_from_config_file_not_found():
+    with pytest.raises(FileNotFoundError):
+        load_terms_from_config("non_existent_file.yaml")
+
+
+def test_load_terms_from_config_validation_error(tmp_path):
+    config_data = {
+        "terms": [
+            {
+                "key": "species",
+                "uri": "dwc:scientificName",
+                "label": 123,
+            },  # Invalid label type
+        ]
+    }
+    config_file = tmp_path / "config.yaml"
+    with open(config_file, "w") as f:
+        yaml.dump(config_data, f)
+
+    with pytest.raises(ValueError):
+        load_terms_from_config(config_file)
+
+
+def test_load_terms_from_config_key_already_exists(tmp_path):
+    config_data = {
+        "terms": [
+            {
+                "key": "call_type",
+                "uri": "dwc:scientificName",
+                "label": "Scientific Name",
+            },  # Duplicate key
+        ]
+    }
+    config_file = tmp_path / "config.yaml"
+    with open(config_file, "w") as f:
+        yaml.dump(config_data, f)
+
+    with pytest.raises(KeyError):
+        load_terms_from_config(config_file)
diff --git a/tests/test_train/test_augmentations.py b/tests/test_train/test_augmentations.py
index 6430d26..e69f107 100644
--- a/tests/test_train/test_augmentations.py
+++ b/tests/test_train/test_augmentations.py
@@ -30,8 +30,18 @@ def test_mix_examples(
 
     config = TrainPreprocessingConfig()
 
-    example1 = generate_train_example(clip_annotation_1, config)
-    example2 = generate_train_example(clip_annotation_2, config)
+    example1 = generate_train_example(
+        clip_annotation_1,
+        preprocessing_config=config.preprocessing,
+        target_config=config.target,
+        label_config=config.labels,
+    )
+    example2 = generate_train_example(
+        clip_annotation_2,
+        preprocessing_config=config.preprocessing,
+        target_config=config.target,
+        label_config=config.labels,
+    )
 
     mixed = mix_examples(example1, example2, config=config.preprocessing)
 
@@ -59,8 +69,18 @@ def test_mix_examples_of_different_durations(
 
     config = TrainPreprocessingConfig()
 
-    example1 = generate_train_example(clip_annotation_1, config)
-    example2 = generate_train_example(clip_annotation_2, config)
+    example1 = generate_train_example(
+        clip_annotation_1,
+        preprocessing_config=config.preprocessing,
+        target_config=config.target,
+        label_config=config.labels,
+    )
+    example2 = generate_train_example(
+        clip_annotation_2,
+        preprocessing_config=config.preprocessing,
+        target_config=config.target,
+        label_config=config.labels,
+    )
 
     mixed = mix_examples(example1, example2, config=config.preprocessing)
 
@@ -78,7 +98,12 @@ def test_add_echo(
     clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
    config = TrainPreprocessingConfig()
-    original = generate_train_example(clip_annotation_1, config)
+    original = generate_train_example(
+        clip_annotation_1,
+        preprocessing_config=config.preprocessing,
+        target_config=config.target,
+        label_config=config.labels,
+    )
 
     with_echo = add_echo(original, config=config.preprocessing)
     assert with_echo["spectrogram"].shape == original["spectrogram"].shape
@@ -94,7 +119,12 @@ def test_selected_random_subclip_has_the_correct_width(
     clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
     config = TrainPreprocessingConfig()
-    original = generate_train_example(clip_annotation_1, config)
+    original = generate_train_example(
+        clip_annotation_1,
+        preprocessing_config=config.preprocessing,
+        target_config=config.target,
+        label_config=config.labels,
+    )
 
     subclip = select_subclip(original, width=100)
     assert subclip["spectrogram"].shape[1] == 100
@@ -107,7 +137,12 @@ def test_add_echo_after_subclip(
     clip1 = data.Clip(recording=recording1, start_time=0, end_time=1)
     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
     config = TrainPreprocessingConfig()
-    original = generate_train_example(clip_annotation_1, config)
+    original = generate_train_example(
+        clip_annotation_1,
+        preprocessing_config=config.preprocessing,
+        target_config=config.target,
+        label_config=config.labels,
+    )
 
     assert original.sizes["time"] > 512
 
diff --git a/tests/test_train/test_labels.py b/tests/test_train/test_labels.py
index e86aae7..9b68f54 100644
--- a/tests/test_train/test_labels.py
+++ b/tests/test_train/test_labels.py
@@ -4,7 +4,7 @@ import numpy as np
 import xarray as xr
 from soundevent import data
 
-from batdetect2.train.labels import generate_heatmaps
+from batdetect2.targets import generate_heatmaps
 
 recording = data.Recording(
     samplerate=256_000,