Create separate targets module

mbsantiago 2025-04-12 16:48:40 +01:00
parent acda71ea45
commit b93d4c65c2
44 changed files with 1227 additions and 304 deletions

View File

@ -8,9 +8,11 @@ from batdetect2.data import load_dataset_from_config
from batdetect2.preprocess import (
load_preprocessing_config,
)
-from batdetect2.train import (
+from batdetect2.targets import (
load_label_config,
load_target_config,
+)
+from batdetect2.train import (
preprocess_annotations,
)

View File

@ -12,10 +12,13 @@ from batdetect2.preprocess import (
STFTConfig,
)
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
-from batdetect2.terms import TagInfo
-from batdetect2.train.preprocess import (
+from batdetect2.targets import (
HeatmapsConfig,
+TagInfo,
TargetConfig,
+)
+from batdetect2.targets.labels import LabelConfig
+from batdetect2.train.preprocess import (
TrainPreprocessingConfig,
)
@ -91,11 +94,13 @@ def get_training_preprocessing_config(
for value in params["classes_to_ignore"]
],
),
-heatmaps=HeatmapsConfig(
-position="bottom-left",
-time_scale=1 / time_bin_width,
-frequency_scale=1 / freq_bin_width,
-sigma=params["target_sigma"],
+labels=LabelConfig(
+heatmaps=HeatmapsConfig(
+position="bottom-left",
+time_scale=1 / time_bin_width,
+frequency_scale=1 / freq_bin_width,
+sigma=params["target_sigma"],
+)
),
)
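For orientation, here is a sketch of how the heatmap settings nest after this change, using the field names visible in this diff; the concrete values below are made up and the exact defaults of these config models are not shown here.

```python
# Illustrative sketch only: heatmap settings now live under `labels` rather
# than at the top level of the training preprocessing config.
from batdetect2.targets import HeatmapsConfig, TargetConfig
from batdetect2.targets.labels import LabelConfig
from batdetect2.train.preprocess import TrainPreprocessingConfig

config = TrainPreprocessingConfig(
    target=TargetConfig(),
    labels=LabelConfig(
        heatmaps=HeatmapsConfig(
            position="bottom-left",
            time_scale=1000.0,  # made-up value
            frequency_scale=1 / 859.375,  # made-up value
            sigma=3.0,  # made-up value
        )
    ),
)
```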

View File

@ -1,12 +1,18 @@
from batdetect2.data.annotations import (
AnnotatedDataset,
+AOEFAnnotations,
+BatDetect2FilesAnnotations,
+BatDetect2MergedAnnotations,
load_annotated_dataset,
)
from batdetect2.data.data import load_dataset, load_dataset_from_config
from batdetect2.data.types import Dataset
__all__ = [
+"AOEFAnnotations",
"AnnotatedDataset",
+"BatDetect2FilesAnnotations",
+"BatDetect2MergedAnnotations",
"Dataset",
"load_annotated_dataset",
"load_dataset",

View File

@ -1,4 +1,3 @@
-import json
from pathlib import Path
from typing import Literal, Union
@ -32,5 +31,3 @@ AnnotationFormats = Union[
BatDetect2AnnotationFile,
AOEFAnnotationFile,
]

View File

@ -1,11 +1,9 @@
from pathlib import Path
-from typing import Literal, Union
from batdetect2.configs import BaseConfig
__all__ = [
"AnnotatedDataset",
-"BatDetect2MergedAnnotations",
]
@ -37,5 +35,3 @@ class AnnotatedDataset(BaseConfig):
name: str
audio_dir: Path
description: str = ""

View File

@ -0,0 +1,9 @@
from batdetect2.evaluate.evaluate import (
compute_error_auc,
match_predictions_and_annotations,
)
__all__ = [
"compute_error_auc",
"match_predictions_and_annotations",
]

View File

@ -1,13 +1,10 @@
from typing import List
import numpy as np
-import pandas as pd
from sklearn.metrics import auc, roc_curve
from soundevent import data
from soundevent.evaluation import match_geometries
-from batdetect2.train.targets import build_target_encoder, get_class_names
def match_predictions_and_annotations(
clip_annotation: data.ClipAnnotation,
@ -51,20 +48,13 @@ def match_predictions_and_annotations(
return matches
-def build_evaluation_dataframe(matches: List[data.Match]) -> pd.DataFrame:
-ret = []
-for match in matches:
-pass
def compute_error_auc(op_str, gt, pred, prob):
# classification error
pred_int = (pred > prob).astype(np.int32)
class_acc = (pred_int == gt).mean() * 100.0
# ROC - area under curve
-fpr, tpr, thresholds = roc_curve(gt, pred)
+fpr, tpr, _ = roc_curve(gt, pred)
roc_auc = auc(fpr, tpr)
print(
@ -177,7 +167,7 @@ def compute_pre_rec(
file_ids.append([pid] * valid_inds.sum())
confidence = np.hstack(confidence)
-file_ids = np.hstack(file_ids).astype(np.int)
+file_ids = np.hstack(file_ids).astype(int)
pred_boxes = np.vstack(pred_boxes)
if len(pred_class) > 0:
pred_class = np.hstack(pred_class)
@ -197,8 +187,7 @@ def compute_pre_rec(
# note, files with the incorrect duration will cause a problem
if (gg["start_times"] > file_dur).sum() > 0:
-print("Error: file duration incorrect for", gg["id"])
-assert False
+raise ValueError(f"Error: file duration incorrect for {gg['id']}")
boxes = np.vstack(
(
@ -244,6 +233,8 @@ def compute_pre_rec(
gt_id = file_ids[ind]
valid_det = False
+det_ind = 0
if gt_boxes[gt_id].shape[0] > 0:
# compute overlap
valid_det, det_ind = compute_affinity_1d(
@ -273,7 +264,7 @@ def compute_pre_rec(
# store threshold values - used for plotting
conf_sorted = np.sort(confidence)[::-1][valid_inds]
thresholds = np.linspace(0.1, 0.9, 9)
-thresholds_inds = np.zeros(len(thresholds), dtype=np.int)
+thresholds_inds = np.zeros(len(thresholds), dtype=int)
for ii, tt in enumerate(thresholds):
thresholds_inds[ii] = np.argmin(conf_sorted > tt)
thresholds_inds[thresholds_inds == 0] = -1
@ -385,7 +376,7 @@ def compute_file_accuracy(gts, preds, num_classes):
).mean(0)
best_thresh = np.argmax(acc_per_thresh)
best_acc = acc_per_thresh[best_thresh]
-pred_valid = pred_valid_all[:, best_thresh].astype(np.int).tolist()
+pred_valid = pred_valid_all[:, best_thresh].astype(int).tolist()
res = {}
res["num_valid_files"] = len(gt_valid)

View File

@ -1,92 +1,26 @@
-from enum import Enum
-from typing import Optional, Tuple
-from soundevent.data import PathLike
-from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.backbones import (
Net2DFast,
Net2DFastNoAttn,
Net2DFastNoCoordConv,
Net2DPlain,
)
+from batdetect2.models.build import build_architecture
+from batdetect2.models.config import ModelConfig, ModelType, load_model_config
from batdetect2.models.heads import BBoxHead, ClassifierHead
-from batdetect2.models.typing import BackboneModel
+from batdetect2.models.types import BackboneModel, ModelOutput
__all__ = [
"BBoxHead",
+"BackboneModel",
"ClassifierHead",
"ModelConfig",
+"ModelOutput",
"ModelType",
"Net2DFast",
"Net2DFastNoAttn",
"Net2DFastNoCoordConv",
+"Net2DPlain",
"build_architecture",
"load_model_config",
]
class ModelType(str, Enum):
Net2DFast = "Net2DFast"
Net2DFastNoAttn = "Net2DFastNoAttn"
Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
Net2DPlain = "Net2DPlain"
class ModelConfig(BaseConfig):
name: ModelType = ModelType.Net2DFast
input_height: int = 128
encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
bottleneck_channels: int = 256
decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
out_channels: int = 32
def load_model_config(
path: PathLike, field: Optional[str] = None
) -> ModelConfig:
return load_config(path, schema=ModelConfig, field=field)
def build_architecture(
config: Optional[ModelConfig] = None,
) -> BackboneModel:
config = config or ModelConfig()
if config.name == ModelType.Net2DFast:
return Net2DFast(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
if config.name == ModelType.Net2DFastNoAttn:
return Net2DFastNoAttn(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
if config.name == ModelType.Net2DFastNoCoordConv:
return Net2DFastNoCoordConv(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
if config.name == ModelType.Net2DPlain:
return Net2DPlain(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
raise ValueError(f"Unknown model type: {config.name}")

View File

@ -12,12 +12,13 @@ from batdetect2.models.blocks import (
UpscalingLayer,
VerticalConv,
)
-from batdetect2.models.typing import BackboneModel
+from batdetect2.models.types import BackboneModel
__all__ = [
"Net2DFast",
"Net2DFastNoAttn",
"Net2DFastNoCoordConv",
+"Net2DPlain",
]
@ -165,7 +166,6 @@ def pad_adjust(
spec: torch.Tensor,
factor: int = 32,
) -> Tuple[torch.Tensor, int, int]:
-print(spec.shape)
h, w = spec.shape[2:]
h_pad = -h % factor
w_pad = -w % factor

View File

@ -0,0 +1,58 @@
from typing import Optional
from batdetect2.models.backbones import (
Net2DFast,
Net2DFastNoAttn,
Net2DFastNoCoordConv,
Net2DPlain,
)
from batdetect2.models.config import ModelConfig, ModelType
from batdetect2.models.types import BackboneModel
__all__ = [
"build_architecture",
]
def build_architecture(
config: Optional[ModelConfig] = None,
) -> BackboneModel:
config = config or ModelConfig()
if config.name == ModelType.Net2DFast:
return Net2DFast(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
if config.name == ModelType.Net2DFastNoAttn:
return Net2DFastNoAttn(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
if config.name == ModelType.Net2DFastNoCoordConv:
return Net2DFastNoCoordConv(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
if config.name == ModelType.Net2DPlain:
return Net2DPlain(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
raise ValueError(f"Unknown model type: {config.name}")
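A minimal usage sketch for the new factory module (import paths taken from the code above; the default configuration is assumed to be acceptable):

```python
from batdetect2.models.build import build_architecture
from batdetect2.models.config import ModelConfig, ModelType

# Default backbone (Net2DFast), then a variant selected via the config enum.
backbone = build_architecture()
no_attn = build_architecture(ModelConfig(name=ModelType.Net2DFastNoAttn))
```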

View File

@ -0,0 +1,34 @@
from enum import Enum
from typing import Optional, Tuple
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
__all__ = [
"ModelType",
"ModelConfig",
"load_model_config",
]
class ModelType(str, Enum):
Net2DFast = "Net2DFast"
Net2DFastNoAttn = "Net2DFastNoAttn"
Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
Net2DPlain = "Net2DPlain"
class ModelConfig(BaseConfig):
name: ModelType = ModelType.Net2DFast
input_height: int = 128
encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
bottleneck_channels: int = 256
decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
out_channels: int = 32
def load_model_config(
path: PathLike, field: Optional[str] = None
) -> ModelConfig:
return load_config(path, schema=ModelConfig, field=field)
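A sketch of loading this schema from a file; `load_config` is assumed to parse YAML, as docstrings elsewhere in this commit suggest, and the file contents here are illustrative:

```python
from pathlib import Path

from batdetect2.models.config import load_model_config

Path("model.yaml").write_text(
    "name: Net2DFastNoAttn\n"
    "input_height: 128\n"
    "out_channels: 32\n"
)
config = load_model_config("model.yaml")
assert config.name == "Net2DFastNoAttn"  # ModelType is a str Enum
```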

View File

@ -1,15 +0,0 @@
import sys
from typing import Iterable, List, Literal, Sequence
import torch
from torch import nn
from batdetect2.models.blocks import ConvBlockUpF, ConvBlockUpStandard
if sys.version_info >= (3, 10):
from itertools import pairwise
else:
def pairwise(iterable: Sequence) -> Iterable:
for x, y in zip(iterable[:-1], iterable[1:]):
yield x, y

View File

@ -1,15 +0,0 @@
import sys
from typing import Iterable, List, Literal, Sequence
import torch
from torch import nn
from batdetect2.models.blocks import ConvBlockDownCoordF, ConvBlockDownStandard
if sys.version_info >= (3, 10):
from itertools import pairwise
else:
def pairwise(iterable: Sequence) -> Iterable:
for x, y in zip(iterable[:-1], iterable[1:]):
yield x, y

View File

@ -7,31 +7,27 @@ from pydantic import Field
from soundevent import data
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.utils.data import DataLoader
from batdetect2.configs import BaseConfig
-from batdetect2.evaluate.evaluate import match_predictions_and_annotations
from batdetect2.models import (
BBoxHead,
ClassifierHead,
ModelConfig,
+ModelOutput,
build_architecture,
)
-from batdetect2.models.typing import ModelOutput
from batdetect2.post_process import (
PostprocessConfig,
postprocess_model_outputs,
)
from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
-from batdetect2.train.config import TrainingConfig
-from batdetect2.train.dataset import LabeledDataset, TrainExample
-from batdetect2.train.losses import compute_loss
-from batdetect2.train.targets import (
+from batdetect2.targets import (
TargetConfig,
build_decoder,
build_target_encoder,
get_class_names,
)
+from batdetect2.train import TrainExample, TrainingConfig, compute_loss
__all__ = [
"DetectorModel",
@ -83,12 +79,9 @@ class DetectorModel(L.LightningModule):
replacement_rules=self.config.targets.replace,
)
self.decoder = build_decoder(self.config.targets.classes)
-self.validation_predictions = []
self.example_input_array = torch.randn([1, 1, 128, 512])
-def forward(self, spec: torch.Tensor) -> ModelOutput:  # type: ignore
+def forward(self, spec: torch.Tensor) -> ModelOutput:
features = self.backbone(spec)
detection_probs, classification_probs = self.classifier(features)
size_preds = self.bbox(features)
@ -130,27 +123,6 @@ class DetectorModel(L.LightningModule):
self.log("val/loss/size", losses.total, logger=True) self.log("val/loss/size", losses.total, logger=True)
self.log("val/loss/classification", losses.total, logger=True) self.log("val/loss/classification", losses.total, logger=True)
dataloaders = self.trainer.val_dataloaders
assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset
assert isinstance(dataset, LabeledDataset)
clip_annotation = dataset.get_clip_annotation(batch_idx)
clip_prediction = postprocess_model_outputs(
outputs,
clips=[clip_annotation.clip],
classes=self.class_names,
decoder=self.decoder,
config=self.config.postprocessing,
)[0]
matches = match_predictions_and_annotations(
clip_annotation,
clip_prediction,
)
self.validation_predictions.extend(matches)
def on_validation_epoch_end(self) -> None: def on_validation_epoch_end(self) -> None:
self.validation_predictions.clear() self.validation_predictions.clear()

View File

@ -9,7 +9,7 @@ from soundevent import data
from torch import nn
from batdetect2.configs import BaseConfig, load_config
-from batdetect2.models.typing import ModelOutput
+from batdetect2.models.types import ModelOutput
__all__ = [
"PostprocessConfig",

View File

View File

@ -0,0 +1,73 @@
import numpy as np
import xarray as xr
from soundevent.arrays import Dimensions
from batdetect2.models import ModelOutput
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
def to_xarray(
output: ModelOutput,
start_time: float,
end_time: float,
class_names: list[str],
min_freq: float = MIN_FREQ,
max_freq: float = MAX_FREQ,
):
detection = output.detection_probs
size = output.size_preds
classes = output.class_probs
features = output.features
if len(detection.shape) == 4:
if detection.shape[0] != 1:
raise ValueError(
"Expected a non-batched output or a batch of size 1, instead "
f"got an input of shape {detection.shape}"
)
detection = detection.squeeze(dim=0)
size = size.squeeze(dim=0)
classes = classes.squeeze(dim=0)
features = features.squeeze(dim=0)
_, width, height = detection.shape
times = np.linspace(start_time, end_time, width, endpoint=False)
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
if classes.shape[0] != len(class_names):
raise ValueError(
f"The number of classes does not coincide with the number of class names provided: ({classes.shape[0] = }) != ({len(class_names) = })"
)
return xr.Dataset(
data_vars={
"detection": (
[Dimensions.time.value, Dimensions.frequency.value],
detection.squeeze(dim=0).detach().numpy(),
),
"size": (
[
"dimension",
Dimensions.time.value,
Dimensions.frequency.value,
],
size.detach().numpy(),
),
"classes": (
[
"category",
Dimensions.time.value,
Dimensions.frequency.value,
],
classes.detach().numpy(),
),
},
coords={
Dimensions.time.value: times,
Dimensions.frequency.value: freqs,
"dimension": ["width", "height"],
"category": class_names,
},
)
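A sketch of calling `to_xarray` on a dummy, non-batched output. The `ModelOutput` field names follow the attributes accessed above; constructing it by keyword and the tensor shapes used here are assumptions.

```python
import torch

from batdetect2.models import ModelOutput
from batdetect2.postprocess.arrays import to_xarray

width, height, n_classes = 512, 128, 2
dummy = ModelOutput(
    detection_probs=torch.rand(1, width, height),
    size_preds=torch.rand(2, width, height),  # width/height scales
    class_probs=torch.rand(n_classes, width, height),
    features=torch.rand(32, width, height),
)
dataset = to_xarray(
    dummy,
    start_time=0.0,
    end_time=0.5,
    class_names=["Myotis myotis", "Pipistrellus pipistrellus"],
)
assert set(dataset.data_vars) == {"detection", "size", "classes"}
```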

View File

@ -0,0 +1,32 @@
from typing import Optional
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
__all__ = [
"PostprocessConfig",
"load_postprocess_config",
]
NMS_KERNEL_SIZE = 9
DETECTION_THRESHOLD = 0.01
TOP_K_PER_SEC = 200
class PostprocessConfig(BaseConfig):
"""Configuration for postprocessing model outputs."""
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
detection_threshold: float = Field(default=DETECTION_THRESHOLD, ge=0)
min_freq: int = Field(default=10000, gt=0)
max_freq: int = Field(default=120000, gt=0)
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
def load_postprocess_config(
path: data.PathLike,
field: Optional[str] = None,
) -> PostprocessConfig:
return load_config(path, schema=PostprocessConfig, field=field)
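A small sketch of overriding the defaults in code (the import path is assumed, since it is not shown in this diff):

```python
from batdetect2.postprocess.config import PostprocessConfig  # path assumed

# Stricter detection threshold and a smaller NMS neighbourhood than the defaults.
config = PostprocessConfig(detection_threshold=0.3, nms_kernel_size=5)
assert config.top_k_per_sec == 200  # unchanged default
```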

View File

@ -0,0 +1,50 @@
from typing import Tuple, Union
import torch
NMS_KERNEL_SIZE = 9
def non_max_suppression(
tensor: torch.Tensor,
kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
) -> torch.Tensor:
"""Run non-maximum suppression on a tensor.
This function removes values from the input tensor that are not local
maxima in the neighborhood of the given kernel size.
All non-maximum values are set to zero.
Parameters
----------
tensor : torch.Tensor
Input tensor.
kernel_size : Union[int, Tuple[int, int]], optional
Size of the neighborhood to consider for non-maximum suppression.
If an integer is given, the neighborhood will be a square of the
given size. If a tuple is given, the neighborhood will be a
rectangle with the given height and width.
Returns
-------
torch.Tensor
Tensor with non-maximum suppressed values.
"""
if isinstance(kernel_size, int):
kernel_size_h = kernel_size
kernel_size_w = kernel_size
else:
kernel_size_h, kernel_size_w = kernel_size
pad_h = (kernel_size_h - 1) // 2
pad_w = (kernel_size_w - 1) // 2
hmax = torch.nn.functional.max_pool2d(
tensor,
(kernel_size_h, kernel_size_w),
stride=1,
padding=(pad_h, pad_w),
)
keep = (hmax == tensor).float()
return tensor * keep
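A quick usage sketch; the import path for this module is assumed, since it is not shown in the diff:

```python
import torch

from batdetect2.postprocess.nms import non_max_suppression  # hypothetical path

# Random detection heatmap: batch x channel x freq x time.
scores = torch.rand(1, 1, 128, 512)

# Keep only the local maxima of each 9x9 neighbourhood; everything else is zeroed.
suppressed = non_max_suppression(scores, kernel_size=9)
assert suppressed.shape == scores.shape
```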

View File

@ -0,0 +1,17 @@
from typing import Dict, NamedTuple
import numpy as np
__all__ = [
"BatDetect2Prediction",
]
class BatDetect2Prediction(NamedTuple):
start_time: float
end_time: float
low_freq: float
high_freq: float
detection_score: float
class_scores: Dict[str, float]
features: np.ndarray

View File

@ -15,6 +15,8 @@ from batdetect2.preprocess.config import (
load_preprocessing_config,
)
from batdetect2.preprocess.spectrogram import (
+MAX_FREQ,
+MIN_FREQ,
AmplitudeScaleConfig,
FrequencyConfig,
LogScaleConfig,
@ -31,6 +33,8 @@ __all__ = [
"AudioConfig",
"FrequencyConfig",
"LogScaleConfig",
+"MAX_FREQ",
+"MIN_FREQ",
"PcenScaleConfig",
"PreprocessingConfig",
"ResampleConfig",

View File

@ -31,7 +31,7 @@ def load_file_audio(
path: data.PathLike,
config: Optional[AudioConfig] = None,
audio_dir: Optional[data.PathLike] = None,
-dtype: DTypeLike = np.float32,
+dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
recording = data.Recording.from_file(path)
return load_recording_audio(
@ -46,7 +46,7 @@ def load_recording_audio(
recording: data.Recording,
config: Optional[AudioConfig] = None,
audio_dir: Optional[data.PathLike] = None,
-dtype: DTypeLike = np.float32,
+dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
clip = data.Clip(
recording=recording,
@ -65,7 +65,7 @@ def load_clip_audio(
clip: data.Clip,
config: Optional[AudioConfig] = None,
audio_dir: Optional[data.PathLike] = None,
-dtype: DTypeLike = np.float32,
+dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
config = config or AudioConfig()
@ -122,7 +122,7 @@ def resample_audio(
wav: xr.DataArray,
samplerate: int = TARGET_SAMPLERATE_HZ,
mode: str = "poly",
-dtype: DTypeLike = np.float32,
+dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
if "time" not in wav.dims:
raise ValueError("Audio must have a time dimension")

View File

@ -11,6 +11,20 @@ from soundevent.arrays import operations as ops
from batdetect2.configs import BaseConfig
__all__ = [
"STFTConfig",
"FrequencyConfig",
"LogScaleConfig",
"PcenScaleConfig",
"AmplitudeScaleConfig",
"Scales",
"SpectrogramConfig",
"compute_spectrogram",
]
MIN_FREQ = 10_000
MAX_FREQ = 120_000
class STFTConfig(BaseConfig):
window_duration: float = Field(default=0.002, gt=0)
@ -70,7 +84,7 @@ class SpectrogramConfig(BaseConfig):
def compute_spectrogram(
wav: xr.DataArray,
config: Optional[SpectrogramConfig] = None,
-dtype: DTypeLike = np.float32,
+dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
config = config or SpectrogramConfig()
@ -124,7 +138,7 @@ def stft(
window_duration: float,
window_overlap: float,
window_fn: str = "hann",
-dtype: DTypeLike = np.float32,
+dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
start_time, end_time = arrays.get_dim_range(wave, dim="time")
step = arrays.get_dim_step(wave, dim="time")
@ -190,7 +204,7 @@ def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
def scale_spectrogram(
spec: xr.DataArray,
scale: Scales,
-dtype: DTypeLike = np.float32,
+dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
if scale.name == "log":
return scale_log(spec, dtype=dtype)
@ -230,7 +244,7 @@ def scale_pcen(
def scale_log(
spec: xr.DataArray,
-dtype: DTypeLike = np.float32,
+dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
samplerate = spec.attrs["original_samplerate"]
nfft = spec.attrs["nfft"]

View File

@ -0,0 +1,49 @@
"""Module that helps define what the training targets are.
The goal of this module is to configure how raw sound event annotations (tags)
are processed to determine which events are relevant for training and what
specific class label each relevant event should receive. Also, define how
predicted class labels map back to the original tag system for interpretation.
"""
from batdetect2.targets.labels import (
HeatmapsConfig,
LabelConfig,
generate_heatmaps,
load_label_config,
)
from batdetect2.targets.targets import (
TargetConfig,
build_decoder,
build_sound_event_filter,
build_target_encoder,
filter_sound_event,
get_class_names,
load_target_config,
)
from batdetect2.targets.terms import (
TagInfo,
TermInfo,
call_type,
get_tag_from_info,
individual,
)
__all__ = [
"HeatmapsConfig",
"LabelConfig",
"TagInfo",
"TargetConfig",
"TermInfo",
"build_decoder",
"build_sound_event_filter",
"build_target_encoder",
"call_type",
"filter_sound_event",
"generate_heatmaps",
"get_class_names",
"get_tag_from_info",
"individual",
"load_label_config",
"load_target_config",
]
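A rough sketch of how the exports above fit together, using default configurations (the concrete defaults are defined elsewhere and not shown in this diff):

```python
from batdetect2.targets import (
    TargetConfig,
    build_decoder,
    build_target_encoder,
    get_class_names,
)

config = TargetConfig()
class_names = get_class_names(config.classes)

# Encoder: tags on an annotated sound event -> class name (or None if not relevant).
encode = build_target_encoder(
    classes=config.classes,
    replacement_rules=config.replace,
)

# Decoder: predicted class name -> tags, for interpreting model output.
decode = build_decoder(config.classes)
```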

View File

@ -7,7 +7,7 @@ from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
-from batdetect2.terms import TagInfo, get_tag_from_info
+from batdetect2.targets.terms import TagInfo, get_tag_from_info
__all__ = [
"TargetConfig",

batdetect2/targets/terms.py (new file, 370 lines)
View File

@ -0,0 +1,370 @@
"""Manages the vocabulary (Terms and Tags) for defining training targets.
This module provides the necessary tools to declare, register, and manage the
set of `soundevent.data.Term` objects used throughout the `batdetect2.targets`
sub-package. It establishes a consistent vocabulary for filtering,
transforming, and classifying sound events based on their annotations (Tags).
The core component is the `TermRegistry`, which maps unique string keys
(aliases) to specific `Term` definitions. This allows users to refer to complex
terms using simple, consistent keys in configuration files and code.
Terms can be pre-defined, loaded from the `soundevent.terms` library, defined
programmatically, or loaded from external configuration files (e.g., YAML).
"""
from inspect import getmembers
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
from soundevent import data, terms
from batdetect2.configs import load_config
__all__ = [
"call_type",
"individual",
"get_tag_from_info",
"TermInfo",
"TagInfo",
]
# The default key used to reference the 'generic_class' term.
# Often used implicitly when defining classification targets.
GENERIC_CLASS_KEY = "class"
call_type = data.Term(
name="soundevent:call_type",
label="Call Type",
definition=(
"A broad categorization of animal vocalizations based on their "
"intended function or purpose (e.g., social, distress, mating, "
"territorial, echolocation)."
),
)
"""Term representing the broad functional category of a vocalization."""
individual = data.Term(
name="soundevent:individual",
label="Individual",
definition=(
"An id for an individual animal. In the context of bioacoustic "
"annotation, this term is used to label vocalizations that are "
"attributed to a specific individual."
),
)
"""Term used for tags identifying a specific individual animal."""
generic_class = data.Term(
name="soundevent:class",
label="Class",
definition=(
"A generic term representing the name of a class within a "
"classification model. Its specific meaning is determined by "
"the model's application."
),
)
"""Generic term representing a classification model's output class label."""
class TermRegistry:
"""Manages a registry mapping unique keys to Term definitions.
This class acts as the central repository for the vocabulary of terms
used within the target definition process. It allows registering terms
with simple string keys and retrieving them consistently.
"""
def __init__(self, terms: Optional[Dict[str, data.Term]] = None):
"""Initializes the TermRegistry.
Args:
terms: An optional dictionary of initial key-to-Term mappings
to populate the registry with. Defaults to an empty registry.
"""
self._terms = terms or {}
def add_term(self, key: str, term: data.Term) -> None:
"""Adds a Term object to the registry with the specified key.
Args:
key: The unique string key to associate with the term.
term: The soundevent.data.Term object to register.
Raises:
KeyError: If a term with the provided key already exists in the
registry.
"""
if key in self._terms:
raise KeyError("A term with the provided key already exists.")
self._terms[key] = term
def get_term(self, key: str) -> data.Term:
"""Retrieves a registered term by its unique key.
Args:
key: The unique string key of the term to retrieve.
Returns:
The corresponding soundevent.data.Term object.
Raises:
KeyError: If no term with the specified key is found, with a
helpful message suggesting listing available keys.
"""
try:
return self._terms[key]
except KeyError as err:
raise KeyError(
"No term found for key "
f"'{key}'. Ensure it is registered or loaded. "
"Use `get_term_keys()` to list available terms."
) from err
def add_custom_term(
self,
key: str,
name: Optional[str] = None,
uri: Optional[str] = None,
label: Optional[str] = None,
definition: Optional[str] = None,
) -> data.Term:
"""Creates a new Term from attributes and adds it to the registry.
This is useful for defining terms directly in code or when loading
from configuration files where only attributes are provided.
If optional fields (`name`, `label`, `definition`) are not provided,
reasonable defaults are used (`key` for name/label, "Unknown" for
definition).
Args:
key: The unique string key for the new term.
name: The name for the new term (defaults to `key`).
uri: The URI for the new term (optional).
label: The display label for the new term (defaults to `key`).
definition: The definition for the new term (defaults to "Unknown").
Returns:
The newly created and registered soundevent.data.Term object.
Raises:
KeyError: If a term with the provided key already exists.
"""
term = data.Term(
name=name or key,
label=label or key,
uri=uri,
definition=definition or "Unknown",
)
self.add_term(key, term)
return term
def get_keys(self) -> List[str]:
"""Returns a list of all keys currently registered.
Returns:
A list of strings representing the keys of all registered terms.
"""
return list(self._terms.keys())
registry = TermRegistry(
terms=dict(
[
*getmembers(terms, lambda x: isinstance(x, data.Term)),
("call_type", call_type),
("individual", individual),
(GENERIC_CLASS_KEY, generic_class),
]
)
)
"""The default, globally accessible TermRegistry instance.
It is pre-populated with standard terms from `soundevent.terms` and common
terms defined in this module (`call_type`, `individual`, `generic_class`).
Functions in this module use this registry by default unless another instance
is explicitly passed.
"""
def get_term_from_key(
key: str,
term_registry: TermRegistry = registry,
) -> data.Term:
"""Convenience function to retrieve a term by key from a registry.
Uses the global default registry unless a specific `term_registry`
instance is provided.
Args:
key: The unique key of the term to retrieve.
term_registry: The TermRegistry instance to search in. Defaults to
the global `registry`.
Returns:
The corresponding soundevent.data.Term object.
Raises:
KeyError: If the key is not found in the specified registry.
"""
return term_registry.get_term(key)
def get_term_keys(term_registry: TermRegistry = registry) -> List[str]:
"""Convenience function to get all registered keys from a registry.
Uses the global default registry unless a specific `term_registry`
instance is provided.
Args:
term_registry: The TermRegistry instance to query. Defaults to
the global `registry`.
Returns:
A list of strings representing the keys of all registered terms.
"""
return term_registry.get_keys()
class TagInfo(BaseModel):
"""Represents information needed to define a specific Tag.
This model is typically used in configuration files (e.g., YAML) to
specify tags used for filtering, target class definition, or associating
tags with output classes. It links a tag value to a term definition
via the term's registry key.
Attributes:
value: The value of the tag (e.g., "Myotis myotis", "Echolocation").
key: The key (alias) of the term associated with this tag, as
registered in the TermRegistry. Defaults to "class", implying
it represents a classification target label by default.
"""
value: str
key: str = GENERIC_CLASS_KEY
def get_tag_from_info(
tag_info: TagInfo,
term_registry: TermRegistry = registry,
) -> data.Tag:
"""Creates a soundevent.data.Tag object from TagInfo data.
Looks up the term using the key in the provided `tag_info` from the
specified registry and constructs a Tag object.
Args:
tag_info: The TagInfo object containing the value and term key.
term_registry: The TermRegistry instance to use for term lookup.
Defaults to the global `registry`.
Returns:
A soundevent.data.Tag object corresponding to the input info.
Raises:
KeyError: If the term key specified in `tag_info.key` is not found
in the registry.
"""
term = get_term_from_key(tag_info.key, term_registry=term_registry)
return data.Tag(term=term, value=tag_info.value)
class TermInfo(BaseModel):
"""Represents the definition of a Term within a configuration file.
This model allows users to define custom terms directly in configuration
files (e.g., YAML) which can then be loaded into the TermRegistry.
It mirrors the parameters of `TermRegistry.add_custom_term`.
Attributes:
key: The unique key (alias) that will be used to register and
reference this term.
label: The optional display label for the term. Defaults to `key`
if not provided during registration.
name: The optional formal name for the term. Defaults to `key`
if not provided during registration.
uri: The optional URI identifying the term (e.g., from a standard
vocabulary).
definition: The optional textual definition of the term. Defaults to
"Unknown" if not provided during registration.
"""
key: str
label: Optional[str] = None
name: Optional[str] = None
uri: Optional[str] = None
definition: Optional[str] = None
class TermConfig(BaseModel):
"""Pydantic schema for loading a list of term definitions from config.
This model typically corresponds to a section in a configuration file
(e.g., YAML) containing a list of term definitions to be registered.
Example YAML structure:
```yaml
terms:
- key: species
uri: dwc:scientificName
label: Scientific Name
- key: my_custom_term
name: My Custom Term
definition: Describes a specific project attribute.
# ... more TermInfo definitions
```
Attributes:
terms: A list of TermInfo objects, each defining a term to be
registered. Defaults to an empty list.
"""
terms: List[TermInfo] = Field(default_factory=list)
def load_terms_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = registry,
) -> Dict[str, data.Term]:
"""Loads term definitions from a configuration file and registers them.
Parses a configuration file (e.g., YAML) using the TermConfig schema,
extracts the list of TermInfo definitions, and adds each one as a
custom term to the specified TermRegistry instance.
Args:
path: The path to the configuration file.
field: Optional key indicating a specific section within the config
file where the 'terms' list is located. If None, expects the
list directly at the top level or within a structure matching
TermConfig schema.
term_registry: The TermRegistry instance to add the loaded terms to.
Defaults to the global `registry`.
Returns:
A dictionary mapping the keys of the newly added terms to their
corresponding Term objects.
Raises:
FileNotFoundError: If the config file path does not exist.
ValidationError: If the config file structure does not match the
TermConfig schema.
KeyError: If a term key loaded from the config conflicts with a key
already present in the registry.
"""
data = load_config(path, schema=TermConfig, field=field)
return {
info.key: term_registry.add_custom_term(
info.key,
name=info.name,
uri=info.uri,
label=info.label,
definition=info.definition,
)
for info in data.terms
}
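To illustrate the pieces above working together, a short sketch that registers a custom term and then builds a tag from config-style info (the key and values here are made up):

```python
from batdetect2.targets.terms import (
    TagInfo,
    TermRegistry,
    get_tag_from_info,
)

my_registry = TermRegistry()
my_registry.add_custom_term("project:site", label="Recording Site")

tag = get_tag_from_info(
    TagInfo(key="project:site", value="Cave A"),
    term_registry=my_registry,
)
assert tag.value == "Cave A"
```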

View File

@ -1,88 +0,0 @@
from inspect import getmembers
from typing import Optional
from pydantic import BaseModel
from soundevent import data, terms
__all__ = [
"call_type",
"individual",
"get_term_from_info",
"get_tag_from_info",
"TermInfo",
"TagInfo",
]
class TermInfo(BaseModel):
label: Optional[str]
name: Optional[str]
uri: Optional[str]
class TagInfo(BaseModel):
value: str
term: Optional[TermInfo] = None
key: Optional[str] = None
label: Optional[str] = None
call_type = data.Term(
name="soundevent:call_type",
label="Call Type",
definition="A broad categorization of animal vocalizations based on their intended function or purpose (e.g., social, distress, mating, territorial, echolocation).",
)
individual = data.Term(
name="soundevent:individual",
label="Individual",
definition="An id for an individual animal. In the context of bioacoustic annotation, this term is used to label vocalizations that are attributed to a specific individual.",
)
ALL_TERMS = [
*getmembers(terms, lambda x: isinstance(x, data.Term)),
call_type,
individual,
]
def get_term_from_info(term_info: TermInfo) -> data.Term:
for term in ALL_TERMS:
if term_info.name and term_info.name == term.name:
return term
if term_info.label and term_info.label == term.label:
return term
if term_info.uri and term_info.uri == term.uri:
return term
if term_info.name is None:
if term_info.label is None:
raise ValueError("At least one of name or label must be provided.")
term_info.name = (
f"soundevent:{term_info.label.lower().replace(' ', '_')}"
)
if term_info.label is None:
term_info.label = term_info.name
return data.Term(
name=term_info.name,
label=term_info.label,
uri=term_info.uri,
definition="Unknown",
)
def get_tag_from_info(tag_info: TagInfo) -> data.Tag:
if tag_info.term:
term = get_term_from_info(tag_info.term)
elif tag_info.key:
term = data.term_from_key(tag_info.key)
else:
raise ValueError("Either term or key must be provided in tag info.")
return data.Tag(term=term, value=tag_info.value)

View File

@ -17,37 +17,26 @@ from batdetect2.train.dataset import (
TrainExample,
get_preprocessed_files,
)
-from batdetect2.train.labels import LabelConfig, load_label_config
+from batdetect2.train.losses import compute_loss
from batdetect2.train.preprocess import (
generate_train_example,
preprocess_annotations,
)
-from batdetect2.train.targets import (
-TagInfo,
-TargetConfig,
-build_target_encoder,
-load_target_config,
-)
from batdetect2.train.train import TrainerConfig, load_trainer_config, train
__all__ = [
"AugmentationsConfig",
-"LabelConfig",
"LabeledDataset",
"SubclipConfig",
-"TagInfo",
-"TargetConfig",
"TrainExample",
"TrainerConfig",
"TrainingConfig",
"add_echo",
"augment_example",
-"build_target_encoder",
+"compute_loss",
"generate_train_example",
"get_preprocessed_files",
"load_agumentation_config",
-"load_label_config",
-"load_target_config",
"load_train_config",
"load_trainer_config",
"mask_frequency",

View File

@ -0,0 +1,55 @@
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from torch.utils.data import DataLoader
from batdetect2.evaluate import match_predictions_and_annotations
from batdetect2.post_process import postprocess_model_outputs
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.types import ModelOutput
class ValidationMetrics(Callback):
def __init__(self):
super().__init__()
self.predictions = []
def on_validation_epoch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
) -> None:
self.predictions = []
return super().on_validation_epoch_start(trainer, pl_module)
def on_validation_batch_end( # type: ignore
self,
trainer: Trainer,
pl_module: LightningModule,
outputs: ModelOutput,
batch: TrainExample,
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
dataloaders = trainer.val_dataloaders
assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset
assert isinstance(dataset, LabeledDataset)
clip_annotation = dataset.get_clip_annotation(batch_idx)
# The detector module owns the class names, decoder and postprocessing config.
clip_prediction = postprocess_model_outputs(
outputs,
clips=[clip_annotation.clip],
classes=pl_module.class_names,
decoder=pl_module.decoder,
config=pl_module.config.postprocessing,
)[0]
matches = match_predictions_and_annotations(
clip_annotation,
clip_prediction,
)
self.predictions.extend(matches)
return super().on_validation_batch_end(
trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
)
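A sketch of attaching the callback to a standard Lightning trainer; the callback's module path is assumed, since the file name is not shown in this diff:

```python
from lightning import Trainer

from batdetect2.modules import DetectorModel
from batdetect2.train.callbacks import ValidationMetrics  # path assumed

model = DetectorModel()
trainer = Trainer(max_epochs=1, callbacks=[ValidationMetrics()])
# trainer.fit(model, train_dataloaders=..., val_dataloaders=...)
```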

View File

@ -6,7 +6,7 @@ from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
-from batdetect2.models.typing import DetectionModel
+from batdetect2.models.types import DetectionModel
from batdetect2.train.dataset import LabeledDataset

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
from pydantic import Field
from batdetect2.configs import BaseConfig
-from batdetect2.models.typing import ModelOutput
+from batdetect2.models.types import ModelOutput
from batdetect2.train.dataset import TrainExample
__all__ = [

View File

@ -17,11 +17,12 @@ from batdetect2.preprocess import (
compute_spectrogram,
load_clip_audio,
)
-from batdetect2.train.labels import LabelConfig, generate_heatmaps
-from batdetect2.train.targets import (
+from batdetect2.targets import (
+LabelConfig,
TargetConfig,
-build_target_encoder,
build_sound_event_filter,
+build_target_encoder,
+generate_heatmaps,
get_class_names,
)

View File

@ -91,7 +91,15 @@ convention = "numpy"
[tool.pyright]
include = ["batdetect2", "tests"]
+venvPath = "."
+venv = ".venv"
pythonVersion = "3.9"
pythonPlatform = "All"
exclude = [
"batdetect2/detector/",
"batdetect2/finetune",
"batdetect2/utils",
"batdetect2/plotting",
"batdetect2/plot",
"batdetect2/api",
"batdetect2/evaluate/legacy",
"batdetect2/train/legacy",
]

View File

@ -1,11 +1,18 @@
import uuid
from pathlib import Path
-from typing import Callable, List, Optional
+from typing import Callable, Iterable, List, Optional
import numpy as np
import pytest
import soundfile as sf
-from soundevent import data
+from soundevent import data, terms
+from batdetect2.targets import (
+TargetConfig,
+build_target_encoder,
+call_type,
+get_class_names,
+)
@pytest.fixture
@ -107,3 +114,102 @@ def recording_factory(wav_factory: Callable[..., Path]):
)
return _recording_factory
@pytest.fixture
def recording(
recording_factory: Callable[..., data.Recording],
) -> data.Recording:
return recording_factory()
@pytest.fixture
def clip(recording: data.Recording) -> data.Clip:
return data.Clip(recording=recording, start_time=0, end_time=0.5)
@pytest.fixture
def echolocation_call(recording: data.Recording) -> data.SoundEventAnnotation:
return data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.BoundingBox(coordinates=[0.1, 67_000, 0.11, 73_000]),
recording=recording,
),
tags=[
data.Tag(term=terms.scientific_name, value="Myotis myotis"),
data.Tag(term=call_type, value="Echolocation"),
],
)
@pytest.fixture
def generic_call(recording: data.Recording) -> data.SoundEventAnnotation:
return data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.BoundingBox(
coordinates=[0.34, 35_000, 0.348, 62_000]
),
recording=recording,
),
tags=[
data.Tag(term=terms.order, value="Chiroptera"),
data.Tag(term=call_type, value="Echolocation"),
],
)
@pytest.fixture
def non_relevant_sound_event(
recording: data.Recording,
) -> data.SoundEventAnnotation:
return data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.BoundingBox(
coordinates=[0.22, 50_000, 0.24, 58_000]
),
recording=recording,
),
tags=[
data.Tag(
term=terms.scientific_name,
value="Muscardinus avellanarius",
),
],
)
@pytest.fixture
def clip_annotation(
clip: data.Clip,
echolocation_call: data.SoundEventAnnotation,
generic_call: data.SoundEventAnnotation,
non_relevant_sound_event: data.SoundEventAnnotation,
) -> data.ClipAnnotation:
return data.ClipAnnotation(
clip=clip,
sound_events=[
echolocation_call,
generic_call,
non_relevant_sound_event,
],
)
@pytest.fixture
def target_config() -> TargetConfig:
return TargetConfig()
@pytest.fixture
def class_names(target_config: TargetConfig) -> List[str]:
return get_class_names(target_config.classes)
@pytest.fixture
def encoder(
target_config: TargetConfig,
) -> Callable[[Iterable[data.Tag]], Optional[str]]:
return build_target_encoder(
classes=target_config.classes,
replacement_rules=target_config.replace,
)

View File

@ -4,7 +4,7 @@ from pathlib import Path
from soundevent import data
-from batdetect2.compat.data import load_annotation_project_from_dir
+from batdetect2.data import BatDetect2FilesAnnotations, load_annotated_dataset
ROOT_DIR = Path(__file__).parent.parent.parent
@ -12,8 +12,14 @@ ROOT_DIR = Path(__file__).parent.parent.parent
def test_load_example_annotation_project():
path = ROOT_DIR / "example_data" / "anns"
audio_dir = ROOT_DIR / "example_data" / "audio"
-project = load_annotation_project_from_dir(path, audio_dir=audio_dir)
+project = load_annotated_dataset(
+BatDetect2FilesAnnotations(
+name="test",
+audio_dir=audio_dir,
+annotations_dir=path,
+)
+)
assert isinstance(project, data.AnnotationProject)
-assert project.name == str(path)
+assert project.name == "test"
assert len(project.clip_annotations) == 3
assert len(project.tasks) == 3

View File

@ -5,8 +5,8 @@ from typing import List
import numpy as np
import pytest
-from batdetect2.compat.data import load_annotation_project_from_dir
from batdetect2.compat.params import get_training_preprocessing_config
+from batdetect2.data import BatDetect2FilesAnnotations, load_annotated_dataset
from batdetect2.train.preprocess import generate_train_example
@ -26,6 +26,8 @@ def test_can_generate_similar_training_inputs(
old_parameters = json.loads((regression_dir / "params.json").read_text())
config = get_training_preprocessing_config(old_parameters)
+assert config is not None
for audio_file in example_audio_files:
example_file = regression_dir / f"{audio_file.name}.npz"
@ -36,9 +38,12 @@ def test_can_generate_similar_training_inputs(
size_mask = dataset["size_mask"]
class_mask = dataset["class_mask"]
-project = load_annotation_project_from_dir(
-example_anns_dir,
-audio_dir=example_audio_dir,
+project = load_annotated_dataset(
+BatDetect2FilesAnnotations(
+name="test",
+annotations_dir=example_anns_dir,
+audio_dir=example_audio_dir,
+)
)
clip_annotation = next(
@ -47,7 +52,12 @@ def test_can_generate_similar_training_inputs(
if ann.clip.recording.path == audio_file
)
-new_dataset = generate_train_example(clip_annotation, config)
+new_dataset = generate_train_example(
+clip_annotation,
+preprocessing_config=config.preprocessing,
+target_config=config.target,
+label_config=config.labels,
+)
new_spec = new_dataset["spectrogram"].values
new_detection_mask = new_dataset["detection"].values
new_size_mask = new_dataset["size"].values

View File

View File

@ -0,0 +1,43 @@
from typing import List
import numpy as np
import torch
import xarray as xr
from soundevent import data
from batdetect2.modules import DetectorModel
from batdetect2.postprocess.arrays import to_xarray
from batdetect2.preprocess import preprocess_audio_clip
def test_this(clip: data.Clip, class_names: List[str]):
spec = xr.DataArray(
data=np.random.rand(100, 100),
dims=["time", "frequency"],
coords={
"time": np.linspace(0, 100, 100, endpoint=False),
"frequency": np.linspace(0, 100, 100, endpoint=False),
},
)
model = DetectorModel()
spec = preprocess_audio_clip(
clip,
config=model.config.preprocessing,
)
tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)
outputs = model(tensor)
arrays = to_xarray(
outputs,
start_time=clip.start_time,
end_time=clip.end_time,
class_names=class_names,
)
assert isinstance(arrays, xr.Dataset)
assert set(arrays.data_vars) == {"detection", "size", "classes"}

View File

View File

@ -0,0 +1,175 @@
import pytest
import yaml
from soundevent import data
from batdetect2.targets import terms
from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
load_terms_from_config,
)
def test_term_registry_initialization():
registry = TermRegistry()
assert registry._terms == {}
initial_terms = {
"test_term": data.Term(name="test", label="Test", definition="test")
}
registry = TermRegistry(terms=initial_terms)
assert registry._terms == initial_terms
def test_term_registry_add_term():
registry = TermRegistry()
term = data.Term(name="test", label="Test", definition="test")
registry.add_term("test_key", term)
assert registry._terms["test_key"] == term
def test_term_registry_get_term():
registry = TermRegistry()
term = data.Term(name="test", label="Test", definition="test")
registry.add_term("test_key", term)
retrieved_term = registry.get_term("test_key")
assert retrieved_term == term
def test_term_registry_add_custom_term():
registry = TermRegistry()
term = registry.add_custom_term(
"custom_key", name="custom", label="Custom", definition="A custom term"
)
assert registry._terms["custom_key"] == term
assert term.name == "custom"
assert term.label == "Custom"
assert term.definition == "A custom term"
def test_term_registry_add_duplicate_term():
registry = TermRegistry()
term = data.Term(name="test", label="Test", definition="test")
registry.add_term("test_key", term)
with pytest.raises(KeyError):
registry.add_term("test_key", term)
def test_term_registry_get_term_not_found():
registry = TermRegistry()
with pytest.raises(KeyError):
registry.get_term("non_existent_key")
def test_term_registry_get_keys():
registry = TermRegistry()
term1 = data.Term(name="test1", label="Test1", definition="test")
term2 = data.Term(name="test2", label="Test2", definition="test")
registry.add_term("key1", term1)
registry.add_term("key2", term2)
keys = registry.get_keys()
assert set(keys) == {"key1", "key2"}
def test_get_term_from_key():
term = terms.get_term_from_key("call_type")
assert term == terms.call_type
custom_registry = TermRegistry()
custom_term = data.Term(name="custom", label="Custom", definition="test")
custom_registry.add_term("custom_key", custom_term)
term = terms.get_term_from_key("custom_key", term_registry=custom_registry)
assert term == custom_term
def test_get_term_keys():
keys = terms.get_term_keys()
assert "call_type" in keys
assert "individual" in keys
assert terms.GENERIC_CLASS_KEY in keys
custom_registry = TermRegistry()
custom_term = data.Term(name="custom", label="Custom", definition="test")
custom_registry.add_term("custom_key", custom_term)
keys = terms.get_term_keys(term_registry=custom_registry)
assert "custom_key" in keys
def test_tag_info_and_get_tag_from_info():
tag_info = TagInfo(value="Myotis myotis", key="call_type")
tag = terms.get_tag_from_info(tag_info)
assert tag.value == "Myotis myotis"
assert tag.term == terms.call_type
def test_get_tag_from_info_key_not_found():
tag_info = TagInfo(value="test", key="non_existent_key")
with pytest.raises(KeyError):
terms.get_tag_from_info(tag_info)
def test_load_terms_from_config(tmp_path):
config_data = {
"terms": [
{
"key": "species",
"name": "dwc:scientificName",
"label": "Scientific Name",
},
{
"key": "my_custom_term",
"name": "soundevent:custom_term",
"definition": "Describes a specific project attribute",
},
]
}
config_file = tmp_path / "config.yaml"
with open(config_file, "w") as f:
yaml.dump(config_data, f)
loaded_terms = load_terms_from_config(config_file)
assert "species" in loaded_terms
assert "my_custom_term" in loaded_terms
assert loaded_terms["species"].name == "dwc:scientificName"
assert loaded_terms["my_custom_term"].name == "soundevent:custom_term"
def test_load_terms_from_config_file_not_found():
with pytest.raises(FileNotFoundError):
load_terms_from_config("non_existent_file.yaml")
def test_load_terms_from_config_validation_error(tmp_path):
config_data = {
"terms": [
{
"key": "species",
"uri": "dwc:scientificName",
"label": 123,
}, # Invalid label type
]
}
config_file = tmp_path / "config.yaml"
with open(config_file, "w") as f:
yaml.dump(config_data, f)
with pytest.raises(ValueError):
load_terms_from_config(config_file)
def test_load_terms_from_config_key_already_exists(tmp_path):
config_data = {
"terms": [
{
"key": "call_type",
"uri": "dwc:scientificName",
"label": "Scientific Name",
}, # Duplicate key
]
}
config_file = tmp_path / "config.yaml"
with open(config_file, "w") as f:
yaml.dump(config_data, f)
with pytest.raises(KeyError):
load_terms_from_config(config_file)

View File

@ -30,8 +30,18 @@ def test_mix_examples(
config = TrainPreprocessingConfig()
-example1 = generate_train_example(clip_annotation_1, config)
-example2 = generate_train_example(clip_annotation_2, config)
+example1 = generate_train_example(
+clip_annotation_1,
+preprocessing_config=config.preprocessing,
+target_config=config.target,
+label_config=config.labels,
+)
+example2 = generate_train_example(
+clip_annotation_2,
+preprocessing_config=config.preprocessing,
+target_config=config.target,
+label_config=config.labels,
+)
mixed = mix_examples(example1, example2, config=config.preprocessing)
@ -59,8 +69,18 @@ def test_mix_examples_of_different_durations(
config = TrainPreprocessingConfig()
-example1 = generate_train_example(clip_annotation_1, config)
-example2 = generate_train_example(clip_annotation_2, config)
+example1 = generate_train_example(
+clip_annotation_1,
+preprocessing_config=config.preprocessing,
+target_config=config.target,
+label_config=config.labels,
+)
+example2 = generate_train_example(
+clip_annotation_2,
+preprocessing_config=config.preprocessing,
+target_config=config.target,
+label_config=config.labels,
+)
mixed = mix_examples(example1, example2, config=config.preprocessing)
@ -78,7 +98,12 @@ def test_add_echo(
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
config = TrainPreprocessingConfig()
-original = generate_train_example(clip_annotation_1, config)
+original = generate_train_example(
+clip_annotation_1,
+preprocessing_config=config.preprocessing,
+target_config=config.target,
+label_config=config.labels,
+)
with_echo = add_echo(original, config=config.preprocessing)
assert with_echo["spectrogram"].shape == original["spectrogram"].shape
@ -94,7 +119,12 @@ def test_selected_random_subclip_has_the_correct_width(
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
config = TrainPreprocessingConfig()
-original = generate_train_example(clip_annotation_1, config)
+original = generate_train_example(
+clip_annotation_1,
+preprocessing_config=config.preprocessing,
+target_config=config.target,
+label_config=config.labels,
+)
subclip = select_subclip(original, width=100)
assert subclip["spectrogram"].shape[1] == 100
@ -107,7 +137,12 @@ def test_add_echo_after_subclip(
clip1 = data.Clip(recording=recording1, start_time=0, end_time=1)
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
config = TrainPreprocessingConfig()
-original = generate_train_example(clip_annotation_1, config)
+original = generate_train_example(
+clip_annotation_1,
+preprocessing_config=config.preprocessing,
+target_config=config.target,
+label_config=config.labels,
+)
assert original.sizes["time"] > 512

View File

@ -4,7 +4,7 @@ import numpy as np
import xarray as xr import xarray as xr
from soundevent import data from soundevent import data
from batdetect2.train.labels import generate_heatmaps from batdetect2.targets import generate_heatmaps
recording = data.Recording( recording = data.Recording(
samplerate=256_000, samplerate=256_000,