mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00

Create separate targets module

This commit is contained in:
  parent acda71ea45
  commit b93d4c65c2
@@ -8,9 +8,11 @@ from batdetect2.data import load_dataset_from_config
from batdetect2.preprocess import (
    load_preprocessing_config,
)
from batdetect2.train import (
from batdetect2.targets import (
    load_label_config,
    load_target_config,
)
from batdetect2.train import (
    preprocess_annotations,
)
@@ -12,10 +12,13 @@ from batdetect2.preprocess import (
    STFTConfig,
)
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
from batdetect2.terms import TagInfo
from batdetect2.train.preprocess import (
from batdetect2.targets import (
    HeatmapsConfig,
    TagInfo,
    TargetConfig,
)
from batdetect2.targets.labels import LabelConfig
from batdetect2.train.preprocess import (
    TrainPreprocessingConfig,
)
@@ -91,11 +94,13 @@ def get_training_preprocessing_config(
            for value in params["classes_to_ignore"]
        ],
    ),
    heatmaps=HeatmapsConfig(
        position="bottom-left",
        time_scale=1 / time_bin_width,
        frequency_scale=1 / freq_bin_width,
        sigma=params["target_sigma"],
    labels=LabelConfig(
        heatmaps=HeatmapsConfig(
            position="bottom-left",
            time_scale=1 / time_bin_width,
            frequency_scale=1 / freq_bin_width,
            sigma=params["target_sigma"],
        )
    ),
)
@@ -1,12 +1,18 @@
from batdetect2.data.annotations import (
    AnnotatedDataset,
    AOEFAnnotations,
    BatDetect2FilesAnnotations,
    BatDetect2MergedAnnotations,
    load_annotated_dataset,
)
from batdetect2.data.data import load_dataset, load_dataset_from_config
from batdetect2.data.types import Dataset

__all__ = [
    "AOEFAnnotations",
    "AnnotatedDataset",
    "BatDetect2FilesAnnotations",
    "BatDetect2MergedAnnotations",
    "Dataset",
    "load_annotated_dataset",
    "load_dataset",
@@ -1,4 +1,3 @@
import json
from pathlib import Path
from typing import Literal, Union

@@ -32,5 +31,3 @@ AnnotationFormats = Union[
    BatDetect2AnnotationFile,
    AOEFAnnotationFile,
]
@@ -1,11 +1,9 @@
from pathlib import Path
from typing import Literal, Union

from batdetect2.configs import BaseConfig

__all__ = [
    "AnnotatedDataset",
    "BatDetect2MergedAnnotations",
]

@@ -37,5 +35,3 @@ class AnnotatedDataset(BaseConfig):
    name: str
    audio_dir: Path
    description: str = ""
@@ -0,0 +1,9 @@
from batdetect2.evaluate.evaluate import (
    compute_error_auc,
    match_predictions_and_annotations,
)

__all__ = [
    "compute_error_auc",
    "match_predictions_and_annotations",
]
@@ -1,13 +1,10 @@
from typing import List

import numpy as np
import pandas as pd
from sklearn.metrics import auc, roc_curve
from soundevent import data
from soundevent.evaluation import match_geometries

from batdetect2.train.targets import build_target_encoder, get_class_names


def match_predictions_and_annotations(
    clip_annotation: data.ClipAnnotation,
@@ -51,20 +48,13 @@ def match_predictions_and_annotations(
    return matches


def build_evaluation_dataframe(matches: List[data.Match]) -> pd.DataFrame:
    ret = []

    for match in matches:
        pass


def compute_error_auc(op_str, gt, pred, prob):
    # classification error
    pred_int = (pred > prob).astype(np.int32)
    class_acc = (pred_int == gt).mean() * 100.0

    # ROC - area under curve
    fpr, tpr, thresholds = roc_curve(gt, pred)
    fpr, tpr, _ = roc_curve(gt, pred)
    roc_auc = auc(fpr, tpr)

    print(
@@ -177,7 +167,7 @@ def compute_pre_rec(
        file_ids.append([pid] * valid_inds.sum())

    confidence = np.hstack(confidence)
    file_ids = np.hstack(file_ids).astype(np.int)
    file_ids = np.hstack(file_ids).astype(int)
    pred_boxes = np.vstack(pred_boxes)
    if len(pred_class) > 0:
        pred_class = np.hstack(pred_class)
@@ -197,8 +187,7 @@ def compute_pre_rec(

        # note, files with the incorrect duration will cause a problem
        if (gg["start_times"] > file_dur).sum() > 0:
            print("Error: file duration incorrect for", gg["id"])
            assert False
            raise ValueError(f"Error: file duration incorrect for {gg['id']}")

        boxes = np.vstack(
            (
@@ -244,6 +233,8 @@ def compute_pre_rec(
        gt_id = file_ids[ind]

        valid_det = False
        det_ind = 0

        if gt_boxes[gt_id].shape[0] > 0:
            # compute overlap
            valid_det, det_ind = compute_affinity_1d(
@@ -273,7 +264,7 @@ def compute_pre_rec(
    # store threshold values - used for plotting
    conf_sorted = np.sort(confidence)[::-1][valid_inds]
    thresholds = np.linspace(0.1, 0.9, 9)
    thresholds_inds = np.zeros(len(thresholds), dtype=np.int)
    thresholds_inds = np.zeros(len(thresholds), dtype=int)
    for ii, tt in enumerate(thresholds):
        thresholds_inds[ii] = np.argmin(conf_sorted > tt)
    thresholds_inds[thresholds_inds == 0] = -1
@@ -385,7 +376,7 @@ def compute_file_accuracy(gts, preds, num_classes):
    ).mean(0)
    best_thresh = np.argmax(acc_per_thresh)
    best_acc = acc_per_thresh[best_thresh]
    pred_valid = pred_valid_all[:, best_thresh].astype(np.int).tolist()
    pred_valid = pred_valid_all[:, best_thresh].astype(int).tolist()

    res = {}
    res["num_valid_files"] = len(gt_valid)
@@ -1,92 +1,26 @@
from enum import Enum
from typing import Optional, Tuple

from soundevent.data import PathLike

from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.backbones import (
    Net2DFast,
    Net2DFastNoAttn,
    Net2DFastNoCoordConv,
    Net2DPlain,
)
from batdetect2.models.build import build_architecture
from batdetect2.models.config import ModelConfig, ModelType, load_model_config
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.models.typing import BackboneModel
from batdetect2.models.types import BackboneModel, ModelOutput

__all__ = [
    "BBoxHead",
    "BackboneModel",
    "ClassifierHead",
    "ModelConfig",
    "ModelOutput",
    "ModelType",
    "Net2DFast",
    "Net2DFastNoAttn",
    "Net2DFastNoCoordConv",
    "Net2DPlain",
    "build_architecture",
    "load_model_config",
]


class ModelType(str, Enum):
    Net2DFast = "Net2DFast"
    Net2DFastNoAttn = "Net2DFastNoAttn"
    Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
    Net2DPlain = "Net2DPlain"


class ModelConfig(BaseConfig):
    name: ModelType = ModelType.Net2DFast
    input_height: int = 128
    encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
    bottleneck_channels: int = 256
    decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
    out_channels: int = 32


def load_model_config(
    path: PathLike, field: Optional[str] = None
) -> ModelConfig:
    return load_config(path, schema=ModelConfig, field=field)


def build_architecture(
    config: Optional[ModelConfig] = None,
) -> BackboneModel:
    config = config or ModelConfig()

    if config.name == ModelType.Net2DFast:
        return Net2DFast(
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )

    if config.name == ModelType.Net2DFastNoAttn:
        return Net2DFastNoAttn(
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )

    if config.name == ModelType.Net2DFastNoCoordConv:
        return Net2DFastNoCoordConv(
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )

    if config.name == ModelType.Net2DPlain:
        return Net2DPlain(
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )

    raise ValueError(f"Unknown model type: {config.name}")
@@ -12,12 +12,13 @@ from batdetect2.models.blocks import (
    UpscalingLayer,
    VerticalConv,
)
from batdetect2.models.typing import BackboneModel
from batdetect2.models.types import BackboneModel

__all__ = [
    "Net2DFast",
    "Net2DFastNoAttn",
    "Net2DFastNoCoordConv",
    "Net2DPlain",
]

@@ -165,7 +166,6 @@ def pad_adjust(
    spec: torch.Tensor,
    factor: int = 32,
) -> Tuple[torch.Tensor, int, int]:
    print(spec.shape)
    h, w = spec.shape[2:]
    h_pad = -h % factor
    w_pad = -w % factor
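As an aside, the `-h % factor` idiom above computes the smallest non-negative padding that rounds a dimension up to a multiple of `factor`. A standalone sketch of the arithmetic (illustration only, not part of the commit):

```python
# In Python, -h % factor is always in [0, factor), so (h + pad) % factor == 0.
def pad_amount(h: int, factor: int = 32) -> int:
    return -h % factor

assert pad_amount(100, 32) == 28  # 100 + 28 = 128 = 4 * 32
assert pad_amount(128, 32) == 0   # already a multiple, no padding needed
```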
batdetect2/models/build.py (Normal file, 58 lines)
@@ -0,0 +1,58 @@
from typing import Optional

from batdetect2.models.backbones import (
    Net2DFast,
    Net2DFastNoAttn,
    Net2DFastNoCoordConv,
    Net2DPlain,
)
from batdetect2.models.config import ModelConfig, ModelType
from batdetect2.models.types import BackboneModel

__all__ = [
    "build_architecture",
]


def build_architecture(
    config: Optional[ModelConfig] = None,
) -> BackboneModel:
    config = config or ModelConfig()

    if config.name == ModelType.Net2DFast:
        return Net2DFast(
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )

    if config.name == ModelType.Net2DFastNoAttn:
        return Net2DFastNoAttn(
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )

    if config.name == ModelType.Net2DFastNoCoordConv:
        return Net2DFastNoCoordConv(
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )

    if config.name == ModelType.Net2DPlain:
        return Net2DPlain(
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )

    raise ValueError(f"Unknown model type: {config.name}")
batdetect2/models/config.py (Normal file, 34 lines)
@@ -0,0 +1,34 @@
from enum import Enum
from typing import Optional, Tuple

from soundevent.data import PathLike

from batdetect2.configs import BaseConfig, load_config

__all__ = [
    "ModelType",
    "ModelConfig",
    "load_model_config",
]


class ModelType(str, Enum):
    Net2DFast = "Net2DFast"
    Net2DFastNoAttn = "Net2DFastNoAttn"
    Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
    Net2DPlain = "Net2DPlain"


class ModelConfig(BaseConfig):
    name: ModelType = ModelType.Net2DFast
    input_height: int = 128
    encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
    bottleneck_channels: int = 256
    decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
    out_channels: int = 32


def load_model_config(
    path: PathLike, field: Optional[str] = None
) -> ModelConfig:
    return load_config(path, schema=ModelConfig, field=field)
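Since `ModelConfig` is a pydantic schema read through `load_config`, a YAML file along these lines should validate against it (values mirror the defaults above; the file name is hypothetical):

```yaml
# model.yaml (hypothetical path)
name: Net2DFast
input_height: 128
encoder_channels: [1, 32, 64, 128]
bottleneck_channels: 256
decoder_channels: [256, 64, 32, 32]
out_channels: 32
```

which would then be read with `load_model_config("model.yaml")`.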
@@ -1,15 +0,0 @@
import sys
from typing import Iterable, List, Literal, Sequence

import torch
from torch import nn

from batdetect2.models.blocks import ConvBlockUpF, ConvBlockUpStandard

if sys.version_info >= (3, 10):
    from itertools import pairwise
else:

    def pairwise(iterable: Sequence) -> Iterable:
        for x, y in zip(iterable[:-1], iterable[1:]):
            yield x, y
@@ -1,15 +0,0 @@
import sys
from typing import Iterable, List, Literal, Sequence

import torch
from torch import nn

from batdetect2.models.blocks import ConvBlockDownCoordF, ConvBlockDownStandard

if sys.version_info >= (3, 10):
    from itertools import pairwise
else:

    def pairwise(iterable: Sequence) -> Iterable:
        for x, y in zip(iterable[:-1], iterable[1:]):
            yield x, y
@@ -7,31 +7,27 @@ from pydantic import Field
from soundevent import data
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from batdetect2.configs import BaseConfig
from batdetect2.evaluate.evaluate import match_predictions_and_annotations
from batdetect2.models import (
    BBoxHead,
    ClassifierHead,
    ModelConfig,
    ModelOutput,
    build_architecture,
)
from batdetect2.models.typing import ModelOutput
from batdetect2.post_process import (
    PostprocessConfig,
    postprocess_model_outputs,
)
from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
from batdetect2.train.config import TrainingConfig
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.losses import compute_loss
from batdetect2.train.targets import (
from batdetect2.targets import (
    TargetConfig,
    build_decoder,
    build_target_encoder,
    get_class_names,
)
from batdetect2.train import TrainExample, TrainingConfig, compute_loss

__all__ = [
    "DetectorModel",
@@ -83,12 +79,9 @@ class DetectorModel(L.LightningModule):
            replacement_rules=self.config.targets.replace,
        )
        self.decoder = build_decoder(self.config.targets.classes)

        self.validation_predictions = []

        self.example_input_array = torch.randn([1, 1, 128, 512])

    def forward(self, spec: torch.Tensor) -> ModelOutput:  # type: ignore
    def forward(self, spec: torch.Tensor) -> ModelOutput:
        features = self.backbone(spec)
        detection_probs, classification_probs = self.classifier(features)
        size_preds = self.bbox(features)
@@ -130,27 +123,6 @@ class DetectorModel(L.LightningModule):
        self.log("val/loss/size", losses.total, logger=True)
        self.log("val/loss/classification", losses.total, logger=True)

        dataloaders = self.trainer.val_dataloaders
        assert isinstance(dataloaders, DataLoader)
        dataset = dataloaders.dataset
        assert isinstance(dataset, LabeledDataset)
        clip_annotation = dataset.get_clip_annotation(batch_idx)

        clip_prediction = postprocess_model_outputs(
            outputs,
            clips=[clip_annotation.clip],
            classes=self.class_names,
            decoder=self.decoder,
            config=self.config.postprocessing,
        )[0]

        matches = match_predictions_and_annotations(
            clip_annotation,
            clip_prediction,
        )

        self.validation_predictions.extend(matches)

    def on_validation_epoch_end(self) -> None:
        self.validation_predictions.clear()
@@ -9,7 +9,7 @@ from soundevent import data
from torch import nn

from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.typing import ModelOutput
from batdetect2.models.types import ModelOutput

__all__ = [
    "PostprocessConfig",
batdetect2/postprocess/__init__.py (Normal file, 0 lines)

batdetect2/postprocess/arrays.py (Normal file, 73 lines)
@@ -0,0 +1,73 @@
import numpy as np
import xarray as xr
from soundevent.arrays import Dimensions

from batdetect2.models import ModelOutput
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ


def to_xarray(
    output: ModelOutput,
    start_time: float,
    end_time: float,
    class_names: list[str],
    min_freq: float = MIN_FREQ,
    max_freq: float = MAX_FREQ,
):
    detection = output.detection_probs
    size = output.size_preds
    classes = output.class_probs
    features = output.features

    if len(detection.shape) == 4:
        if detection.shape[0] != 1:
            raise ValueError(
                "Expected a non-batched output or a batch of size 1, instead "
                f"got an input of shape {detection.shape}"
            )

        detection = detection.squeeze(dim=0)
        size = size.squeeze(dim=0)
        classes = classes.squeeze(dim=0)
        features = features.squeeze(dim=0)

    _, width, height = detection.shape

    times = np.linspace(start_time, end_time, width, endpoint=False)
    freqs = np.linspace(min_freq, max_freq, height, endpoint=False)

    if classes.shape[0] != len(class_names):
        raise ValueError(
            "The number of classes does not coincide with the number of "
            f"class names provided: ({classes.shape[0] = }) != "
            f"({len(class_names) = })"
        )

    return xr.Dataset(
        data_vars={
            "detection": (
                [Dimensions.time.value, Dimensions.frequency.value],
                detection.squeeze(dim=0).detach().numpy(),
            ),
            "size": (
                [
                    "dimension",
                    Dimensions.time.value,
                    Dimensions.frequency.value,
                ],
                detection.detach().numpy(),
            ),
            "classes": (
                [
                    "category",
                    Dimensions.time.value,
                    Dimensions.frequency.value,
                ],
                classes.detach().numpy(),
            ),
        },
        coords={
            Dimensions.time.value: times,
            Dimensions.frequency.value: freqs,
            "dimension": ["width", "height"],
            "category": class_names,
        },
    )
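A hedged sketch of how the new helper might be driven (compare the test added in `tests/test_postprocessing/test_arrays.py` below; `model` and `spec_tensor` are assumed to exist, and the class names are illustrative):

```python
from batdetect2.postprocess.arrays import to_xarray

outputs = model(spec_tensor)  # raw ModelOutput from a forward pass
dataset = to_xarray(
    outputs,
    start_time=0.0,
    end_time=0.5,
    class_names=["Myotis myotis", "Pipistrellus pipistrellus"],
)

# The labeled coordinates make lookups by time/frequency straightforward.
print(dataset["detection"].sel(time=0.1, frequency=50_000, method="nearest"))
```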
batdetect2/postprocess/config.py (Normal file, 32 lines)
@@ -0,0 +1,32 @@
from typing import Optional

from pydantic import Field
from soundevent import data

from batdetect2.configs import BaseConfig, load_config

__all__ = [
    "PostprocessConfig",
    "load_postprocess_config",
]

NMS_KERNEL_SIZE = 9
DETECTION_THRESHOLD = 0.01
TOP_K_PER_SEC = 200


class PostprocessConfig(BaseConfig):
    """Configuration for postprocessing model outputs."""

    nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
    detection_threshold: float = Field(default=DETECTION_THRESHOLD, ge=0)
    min_freq: int = Field(default=10000, gt=0)
    max_freq: int = Field(default=120000, gt=0)
    top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)


def load_postprocess_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> PostprocessConfig:
    return load_config(path, schema=PostprocessConfig, field=field)
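A matching config section would look roughly like this (values are the defaults; the section name `postprocessing` is a hypothetical choice for the `field` argument):

```yaml
postprocessing:
  nms_kernel_size: 9
  detection_threshold: 0.01
  min_freq: 10000
  max_freq: 120000
  top_k_per_sec: 200
```

loaded with `load_postprocess_config("config.yaml", field="postprocessing")`.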
batdetect2/postprocess/non_max_supression.py (Normal file, 50 lines)
@@ -0,0 +1,50 @@
from typing import Tuple, Union

import torch

NMS_KERNEL_SIZE = 9


def non_max_suppression(
    tensor: torch.Tensor,
    kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
) -> torch.Tensor:
    """Run non-maximum suppression on a tensor.

    This function removes values from the input tensor that are not local
    maxima in the neighborhood of the given kernel size.

    All non-maximum values are set to zero.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor.
    kernel_size : Union[int, Tuple[int, int]], optional
        Size of the neighborhood to consider for non-maximum suppression.
        If an integer is given, the neighborhood will be a square of the
        given size. If a tuple is given, the neighborhood will be a
        rectangle with the given height and width.

    Returns
    -------
    torch.Tensor
        Tensor with non-maximum suppressed values.
    """
    if isinstance(kernel_size, int):
        kernel_size_h = kernel_size
        kernel_size_w = kernel_size
    else:
        kernel_size_h, kernel_size_w = kernel_size

    pad_h = (kernel_size_h - 1) // 2
    pad_w = (kernel_size_w - 1) // 2

    hmax = torch.nn.functional.max_pool2d(
        tensor,
        (kernel_size_h, kernel_size_w),
        stride=1,
        padding=(pad_h, pad_w),
    )
    keep = (hmax == tensor).float()
    return tensor * keep
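A quick sketch of the suppression behavior (illustration only; the input shape is made up):

```python
import torch

from batdetect2.postprocess.non_max_supression import non_max_suppression

heatmap = torch.rand(1, 1, 128, 512)  # (batch, channel, freq, time)
suppressed = non_max_suppression(heatmap, kernel_size=9)

# Shape is preserved; only local maxima survive, everything else is zeroed,
# and surviving values are unchanged.
assert suppressed.shape == heatmap.shape
assert (suppressed[suppressed > 0] == heatmap[suppressed > 0]).all()
```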
batdetect2/postprocess/types.py (Normal file, 17 lines)
@@ -0,0 +1,17 @@
from typing import Dict, NamedTuple

import numpy as np

__all__ = [
    "BatDetect2Prediction",
]


class BatDetect2Prediction(NamedTuple):
    start_time: float
    end_time: float
    low_freq: float
    high_freq: float
    detection_score: float
    class_scores: Dict[str, float]
    features: np.ndarray
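For reference, constructing one of these records is direct (values below are purely illustrative):

```python
import numpy as np

from batdetect2.postprocess.types import BatDetect2Prediction

pred = BatDetect2Prediction(
    start_time=0.10,
    end_time=0.11,
    low_freq=67_000.0,
    high_freq=73_000.0,
    detection_score=0.93,
    class_scores={"Myotis myotis": 0.81},
    features=np.zeros(32, dtype=np.float32),
)
```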
@@ -15,6 +15,8 @@ from batdetect2.preprocess.config import (
    load_preprocessing_config,
)
from batdetect2.preprocess.spectrogram import (
    MAX_FREQ,
    MIN_FREQ,
    AmplitudeScaleConfig,
    FrequencyConfig,
    LogScaleConfig,
@@ -31,6 +33,8 @@ __all__ = [
    "AudioConfig",
    "FrequencyConfig",
    "LogScaleConfig",
    "MAX_FREQ",
    "MIN_FREQ",
    "PcenScaleConfig",
    "PreprocessingConfig",
    "ResampleConfig",
@@ -31,7 +31,7 @@ def load_file_audio(
    path: data.PathLike,
    config: Optional[AudioConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
    recording = data.Recording.from_file(path)
    return load_recording_audio(
@@ -46,7 +46,7 @@ def load_recording_audio(
    recording: data.Recording,
    config: Optional[AudioConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
    clip = data.Clip(
        recording=recording,
@@ -65,7 +65,7 @@ def load_clip_audio(
    clip: data.Clip,
    config: Optional[AudioConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
    config = config or AudioConfig()

@@ -122,7 +122,7 @@ def resample_audio(
    wav: xr.DataArray,
    samplerate: int = TARGET_SAMPLERATE_HZ,
    mode: str = "poly",
    dtype: DTypeLike = np.float32,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
    if "time" not in wav.dims:
        raise ValueError("Audio must have a time dimension")
@@ -11,6 +11,20 @@ from soundevent.arrays import operations as ops

from batdetect2.configs import BaseConfig

__all__ = [
    "STFTConfig",
    "FrequencyConfig",
    "LogScaleConfig",
    "PcenScaleConfig",
    "AmplitudeScaleConfig",
    "Scales",
    "SpectrogramConfig",
    "compute_spectrogram",
]

MIN_FREQ = 10_000
MAX_FREQ = 120_000


class STFTConfig(BaseConfig):
    window_duration: float = Field(default=0.002, gt=0)
@@ -70,7 +84,7 @@ class SpectrogramConfig(BaseConfig):
def compute_spectrogram(
    wav: xr.DataArray,
    config: Optional[SpectrogramConfig] = None,
    dtype: DTypeLike = np.float32,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
    config = config or SpectrogramConfig()

@@ -124,7 +138,7 @@ def stft(
    window_duration: float,
    window_overlap: float,
    window_fn: str = "hann",
    dtype: DTypeLike = np.float32,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
    start_time, end_time = arrays.get_dim_range(wave, dim="time")
    step = arrays.get_dim_step(wave, dim="time")
@@ -190,7 +204,7 @@ def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
def scale_spectrogram(
    spec: xr.DataArray,
    scale: Scales,
    dtype: DTypeLike = np.float32,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
    if scale.name == "log":
        return scale_log(spec, dtype=dtype)
@@ -230,7 +244,7 @@ def scale_pcen(

def scale_log(
    spec: xr.DataArray,
    dtype: DTypeLike = np.float32,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
    samplerate = spec.attrs["original_samplerate"]
    nfft = spec.attrs["nfft"]
batdetect2/targets/__init__.py (Normal file, 49 lines)
@@ -0,0 +1,49 @@
"""Module that helps define what the training targets are.

The goal of this module is to configure how raw sound event annotations (tags)
are processed to determine which events are relevant for training and what
specific class label each relevant event should receive. It also defines how
predicted class labels map back to the original tag system for interpretation.
"""

from batdetect2.targets.labels import (
    HeatmapsConfig,
    LabelConfig,
    generate_heatmaps,
    load_label_config,
)
from batdetect2.targets.targets import (
    TargetConfig,
    build_decoder,
    build_sound_event_filter,
    build_target_encoder,
    filter_sound_event,
    get_class_names,
    load_target_config,
)
from batdetect2.targets.terms import (
    TagInfo,
    TermInfo,
    call_type,
    get_tag_from_info,
    individual,
)

__all__ = [
    "HeatmapsConfig",
    "LabelConfig",
    "TagInfo",
    "TargetConfig",
    "TermInfo",
    "build_decoder",
    "build_sound_event_filter",
    "build_target_encoder",
    "call_type",
    "filter_sound_event",
    "generate_heatmaps",
    "get_class_names",
    "get_tag_from_info",
    "individual",
    "load_label_config",
    "load_target_config",
]
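A minimal sketch of the intended flow through the new module, pieced together from how the rest of this commit calls it (the default `TargetConfig()` is used for brevity):

```python
from batdetect2.targets import (
    TargetConfig,
    build_decoder,
    build_target_encoder,
    get_class_names,
)

config = TargetConfig()  # or load_target_config("targets.yaml")
class_names = get_class_names(config.classes)
encoder = build_target_encoder(
    classes=config.classes,
    replacement_rules=config.replace,
)
decoder = build_decoder(config.classes)

# encoder maps a sound event's tags to a class name (or None to skip it);
# decoder maps a predicted class name back to the original tag system.
```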
@@ -7,7 +7,7 @@ from pydantic import Field
from soundevent import data

from batdetect2.configs import BaseConfig, load_config
from batdetect2.terms import TagInfo, get_tag_from_info
from batdetect2.targets.terms import TagInfo, get_tag_from_info

__all__ = [
    "TargetConfig",
batdetect2/targets/terms.py (Normal file, 370 lines)
@@ -0,0 +1,370 @@
"""Manages the vocabulary (Terms and Tags) for defining training targets.

This module provides the necessary tools to declare, register, and manage the
set of `soundevent.data.Term` objects used throughout the `batdetect2.targets`
sub-package. It establishes a consistent vocabulary for filtering,
transforming, and classifying sound events based on their annotations (Tags).

The core component is the `TermRegistry`, which maps unique string keys
(aliases) to specific `Term` definitions. This allows users to refer to complex
terms using simple, consistent keys in configuration files and code.

Terms can be pre-defined, loaded from the `soundevent.terms` library, defined
programmatically, or loaded from external configuration files (e.g., YAML).
"""

from inspect import getmembers
from typing import Dict, List, Optional

from pydantic import BaseModel, Field
from soundevent import data, terms

from batdetect2.configs import load_config

__all__ = [
    "call_type",
    "individual",
    "get_tag_from_info",
    "TermInfo",
    "TagInfo",
]

# The default key used to reference the 'generic_class' term.
# Often used implicitly when defining classification targets.
GENERIC_CLASS_KEY = "class"


call_type = data.Term(
    name="soundevent:call_type",
    label="Call Type",
    definition=(
        "A broad categorization of animal vocalizations based on their "
        "intended function or purpose (e.g., social, distress, mating, "
        "territorial, echolocation)."
    ),
)
"""Term representing the broad functional category of a vocalization."""

individual = data.Term(
    name="soundevent:individual",
    label="Individual",
    definition=(
        "An id for an individual animal. In the context of bioacoustic "
        "annotation, this term is used to label vocalizations that are "
        "attributed to a specific individual."
    ),
)
"""Term used for tags identifying a specific individual animal."""

generic_class = data.Term(
    name="soundevent:class",
    label="Class",
    definition=(
        "A generic term representing the name of a class within a "
        "classification model. Its specific meaning is determined by "
        "the model's application."
    ),
)
"""Generic term representing a classification model's output class label."""


class TermRegistry:
    """Manages a registry mapping unique keys to Term definitions.

    This class acts as the central repository for the vocabulary of terms
    used within the target definition process. It allows registering terms
    with simple string keys and retrieving them consistently.
    """

    def __init__(self, terms: Optional[Dict[str, data.Term]] = None):
        """Initializes the TermRegistry.

        Args:
            terms: An optional dictionary of initial key-to-Term mappings
                to populate the registry with. Defaults to an empty registry.
        """
        self._terms = terms or {}

    def add_term(self, key: str, term: data.Term) -> None:
        """Adds a Term object to the registry with the specified key.

        Args:
            key: The unique string key to associate with the term.
            term: The soundevent.data.Term object to register.

        Raises:
            KeyError: If a term with the provided key already exists in the
                registry.
        """
        if key in self._terms:
            raise KeyError("A term with the provided key already exists.")

        self._terms[key] = term

    def get_term(self, key: str) -> data.Term:
        """Retrieves a registered term by its unique key.

        Args:
            key: The unique string key of the term to retrieve.

        Returns:
            The corresponding soundevent.data.Term object.

        Raises:
            KeyError: If no term with the specified key is found, with a
                helpful message suggesting listing available keys.
        """
        try:
            return self._terms[key]
        except KeyError as err:
            raise KeyError(
                "No term found for key "
                f"'{key}'. Ensure it is registered or loaded. "
                "Use `get_term_keys()` to list available terms."
            ) from err

    def add_custom_term(
        self,
        key: str,
        name: Optional[str] = None,
        uri: Optional[str] = None,
        label: Optional[str] = None,
        definition: Optional[str] = None,
    ) -> data.Term:
        """Creates a new Term from attributes and adds it to the registry.

        This is useful for defining terms directly in code or when loading
        from configuration files where only attributes are provided.

        If optional fields (`name`, `label`, `definition`) are not provided,
        reasonable defaults are used (`key` for name/label, "Unknown" for
        definition).

        Args:
            key: The unique string key for the new term.
            name: The name for the new term (defaults to `key`).
            uri: The URI for the new term (optional).
            label: The display label for the new term (defaults to `key`).
            definition: The definition for the new term (defaults to
                "Unknown").

        Returns:
            The newly created and registered soundevent.data.Term object.

        Raises:
            KeyError: If a term with the provided key already exists.
        """
        term = data.Term(
            name=name or key,
            label=label or key,
            uri=uri,
            definition=definition or "Unknown",
        )
        self.add_term(key, term)
        return term

    def get_keys(self) -> List[str]:
        """Returns a list of all keys currently registered.

        Returns:
            A list of strings representing the keys of all registered terms.
        """
        return list(self._terms.keys())


registry = TermRegistry(
    terms=dict(
        [
            *getmembers(terms, lambda x: isinstance(x, data.Term)),
            ("call_type", call_type),
            ("individual", individual),
            (GENERIC_CLASS_KEY, generic_class),
        ]
    )
)
"""The default, globally accessible TermRegistry instance.

It is pre-populated with standard terms from `soundevent.terms` and common
terms defined in this module (`call_type`, `individual`, `generic_class`).
Functions in this module use this registry by default unless another instance
is explicitly passed.
"""


def get_term_from_key(
    key: str,
    term_registry: TermRegistry = registry,
) -> data.Term:
    """Convenience function to retrieve a term by key from a registry.

    Uses the global default registry unless a specific `term_registry`
    instance is provided.

    Args:
        key: The unique key of the term to retrieve.
        term_registry: The TermRegistry instance to search in. Defaults to
            the global `registry`.

    Returns:
        The corresponding soundevent.data.Term object.

    Raises:
        KeyError: If the key is not found in the specified registry.
    """
    return term_registry.get_term(key)


def get_term_keys(term_registry: TermRegistry = registry) -> List[str]:
    """Convenience function to get all registered keys from a registry.

    Uses the global default registry unless a specific `term_registry`
    instance is provided.

    Args:
        term_registry: The TermRegistry instance to query. Defaults to
            the global `registry`.

    Returns:
        A list of strings representing the keys of all registered terms.
    """
    return term_registry.get_keys()


class TagInfo(BaseModel):
    """Represents information needed to define a specific Tag.

    This model is typically used in configuration files (e.g., YAML) to
    specify tags used for filtering, target class definition, or associating
    tags with output classes. It links a tag value to a term definition
    via the term's registry key.

    Attributes:
        value: The value of the tag (e.g., "Myotis myotis", "Echolocation").
        key: The key (alias) of the term associated with this tag, as
            registered in the TermRegistry. Defaults to "class", implying
            it represents a classification target label by default.
    """

    value: str
    key: str = GENERIC_CLASS_KEY


def get_tag_from_info(
    tag_info: TagInfo,
    term_registry: TermRegistry = registry,
) -> data.Tag:
    """Creates a soundevent.data.Tag object from TagInfo data.

    Looks up the term using the key in the provided `tag_info` from the
    specified registry and constructs a Tag object.

    Args:
        tag_info: The TagInfo object containing the value and term key.
        term_registry: The TermRegistry instance to use for term lookup.
            Defaults to the global `registry`.

    Returns:
        A soundevent.data.Tag object corresponding to the input info.

    Raises:
        KeyError: If the term key specified in `tag_info.key` is not found
            in the registry.
    """
    term = get_term_from_key(tag_info.key, term_registry=term_registry)
    return data.Tag(term=term, value=tag_info.value)


class TermInfo(BaseModel):
    """Represents the definition of a Term within a configuration file.

    This model allows users to define custom terms directly in configuration
    files (e.g., YAML) which can then be loaded into the TermRegistry.
    It mirrors the parameters of `TermRegistry.add_custom_term`.

    Attributes:
        key: The unique key (alias) that will be used to register and
            reference this term.
        label: The optional display label for the term. Defaults to `key`
            if not provided during registration.
        name: The optional formal name for the term. Defaults to `key`
            if not provided during registration.
        uri: The optional URI identifying the term (e.g., from a standard
            vocabulary).
        definition: The optional textual definition of the term. Defaults to
            "Unknown" if not provided during registration.
    """

    key: str
    label: Optional[str] = None
    name: Optional[str] = None
    uri: Optional[str] = None
    definition: Optional[str] = None


class TermConfig(BaseModel):
    """Pydantic schema for loading a list of term definitions from config.

    This model typically corresponds to a section in a configuration file
    (e.g., YAML) containing a list of term definitions to be registered.

    Example YAML structure:
    ```yaml
    terms:
      - key: species
        uri: dwc:scientificName
        label: Scientific Name
      - key: my_custom_term
        name: My Custom Term
        definition: Describes a specific project attribute.
      # ... more TermInfo definitions
    ```

    Attributes:
        terms: A list of TermInfo objects, each defining a term to be
            registered. Defaults to an empty list.
    """

    terms: List[TermInfo] = Field(default_factory=list)


def load_terms_from_config(
    path: data.PathLike,
    field: Optional[str] = None,
    term_registry: TermRegistry = registry,
) -> Dict[str, data.Term]:
    """Loads term definitions from a configuration file and registers them.

    Parses a configuration file (e.g., YAML) using the TermConfig schema,
    extracts the list of TermInfo definitions, and adds each one as a
    custom term to the specified TermRegistry instance.

    Args:
        path: The path to the configuration file.
        field: Optional key indicating a specific section within the config
            file where the 'terms' list is located. If None, expects the
            list directly at the top level or within a structure matching
            TermConfig schema.
        term_registry: The TermRegistry instance to add the loaded terms to.
            Defaults to the global `registry`.

    Returns:
        A dictionary mapping the keys of the newly added terms to their
        corresponding Term objects.

    Raises:
        FileNotFoundError: If the config file path does not exist.
        ValidationError: If the config file structure does not match the
            TermConfig schema.
        KeyError: If a term key loaded from the config conflicts with a key
            already present in the registry.
    """
    data = load_config(path, schema=TermConfig, field=field)
    return {
        info.key: term_registry.add_custom_term(
            info.key,
            name=info.name,
            uri=info.uri,
            label=info.label,
            definition=info.definition,
        )
        for info in data.terms
    }
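A short sketch of the registry in use (the `habitat` key and its definition are made up for illustration; the API is exactly as defined above):

```python
from batdetect2.targets.terms import TagInfo, get_tag_from_info, registry

# Register a project-specific term under a simple alias...
registry.add_custom_term(
    "habitat",
    definition="The habitat in which the recording was made.",
)

# ...then reference it by key when building tags from config data.
tag = get_tag_from_info(TagInfo(key="habitat", value="woodland"))
assert tag.term.name == "habitat"  # name defaults to the key
```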
@@ -1,88 +0,0 @@
from inspect import getmembers
from typing import Optional

from pydantic import BaseModel
from soundevent import data, terms

__all__ = [
    "call_type",
    "individual",
    "get_term_from_info",
    "get_tag_from_info",
    "TermInfo",
    "TagInfo",
]


class TermInfo(BaseModel):
    label: Optional[str]
    name: Optional[str]
    uri: Optional[str]


class TagInfo(BaseModel):
    value: str
    term: Optional[TermInfo] = None
    key: Optional[str] = None
    label: Optional[str] = None


call_type = data.Term(
    name="soundevent:call_type",
    label="Call Type",
    definition="A broad categorization of animal vocalizations based on their intended function or purpose (e.g., social, distress, mating, territorial, echolocation).",
)

individual = data.Term(
    name="soundevent:individual",
    label="Individual",
    definition="An id for an individual animal. In the context of bioacoustic annotation, this term is used to label vocalizations that are attributed to a specific individual.",
)


ALL_TERMS = [
    *getmembers(terms, lambda x: isinstance(x, data.Term)),
    call_type,
    individual,
]


def get_term_from_info(term_info: TermInfo) -> data.Term:
    for term in ALL_TERMS:
        if term_info.name and term_info.name == term.name:
            return term

        if term_info.label and term_info.label == term.label:
            return term

        if term_info.uri and term_info.uri == term.uri:
            return term

    if term_info.name is None:
        if term_info.label is None:
            raise ValueError("At least one of name or label must be provided.")

        term_info.name = (
            f"soundevent:{term_info.label.lower().replace(' ', '_')}"
        )

    if term_info.label is None:
        term_info.label = term_info.name

    return data.Term(
        name=term_info.name,
        label=term_info.label,
        uri=term_info.uri,
        definition="Unknown",
    )


def get_tag_from_info(tag_info: TagInfo) -> data.Tag:
    if tag_info.term:
        term = get_term_from_info(tag_info.term)
    elif tag_info.key:
        term = data.term_from_key(tag_info.key)
    else:
        raise ValueError("Either term or key must be provided in tag info.")

    return data.Tag(term=term, value=tag_info.value)
@@ -17,37 +17,26 @@ from batdetect2.train.dataset import (
    TrainExample,
    get_preprocessed_files,
)
from batdetect2.train.labels import LabelConfig, load_label_config
from batdetect2.train.losses import compute_loss
from batdetect2.train.preprocess import (
    generate_train_example,
    preprocess_annotations,
)
from batdetect2.train.targets import (
    TagInfo,
    TargetConfig,
    build_target_encoder,
    load_target_config,
)
from batdetect2.train.train import TrainerConfig, load_trainer_config, train

__all__ = [
    "AugmentationsConfig",
    "LabelConfig",
    "LabeledDataset",
    "SubclipConfig",
    "TagInfo",
    "TargetConfig",
    "TrainExample",
    "TrainerConfig",
    "TrainingConfig",
    "add_echo",
    "augment_example",
    "build_target_encoder",
    "compute_loss",
    "generate_train_example",
    "get_preprocessed_files",
    "load_agumentation_config",
    "load_label_config",
    "load_target_config",
    "load_train_config",
    "load_trainer_config",
    "mask_frequency",
batdetect2/train/callbacks.py (Normal file, 55 lines)
@@ -0,0 +1,55 @@
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from torch.utils.data import DataLoader

from batdetect2.evaluate import match_predictions_and_annotations
from batdetect2.post_process import postprocess_model_outputs
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.types import ModelOutput


class ValidationMetrics(Callback):
    def __init__(self):
        super().__init__()
        self.predictions = []

    def on_validation_epoch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
    ) -> None:
        self.predictions = []
        return super().on_validation_epoch_start(trainer, pl_module)

    def on_validation_batch_end(  # type: ignore
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: ModelOutput,
        batch: TrainExample,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        dataloaders = trainer.val_dataloaders
        assert isinstance(dataloaders, DataLoader)
        dataset = dataloaders.dataset
        assert isinstance(dataset, LabeledDataset)
        clip_annotation = dataset.get_clip_annotation(batch_idx)

        clip_prediction = postprocess_model_outputs(
            outputs,
            clips=[clip_annotation.clip],
            classes=self.class_names,
            decoder=self.decoder,
            config=self.config.postprocessing,
        )[0]

        matches = match_predictions_and_annotations(
            clip_annotation,
            clip_prediction,
        )

        self.validation_predictions.extend(matches)
        return super().on_validation_batch_end(
            trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
        )
@@ -6,7 +6,7 @@ from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from batdetect2.models.typing import DetectionModel
from batdetect2.models.types import DetectionModel
from batdetect2.train.dataset import LabeledDataset
@@ -5,7 +5,7 @@ import torch.nn.functional as F
from pydantic import Field

from batdetect2.configs import BaseConfig
from batdetect2.models.typing import ModelOutput
from batdetect2.models.types import ModelOutput
from batdetect2.train.dataset import TrainExample

__all__ = [
@@ -17,11 +17,12 @@ from batdetect2.preprocess import (
    compute_spectrogram,
    load_clip_audio,
)
from batdetect2.train.labels import LabelConfig, generate_heatmaps
from batdetect2.train.targets import (
from batdetect2.targets import (
    LabelConfig,
    TargetConfig,
    build_target_encoder,
    build_sound_event_filter,
    build_target_encoder,
    generate_heatmaps,
    get_class_names,
)
@@ -91,7 +91,15 @@ convention = "numpy"

[tool.pyright]
include = ["batdetect2", "tests"]
venvPath = "."
venv = ".venv"
pythonVersion = "3.9"
pythonPlatform = "All"
exclude = [
    "batdetect2/detector/",
    "batdetect2/finetune",
    "batdetect2/utils",
    "batdetect2/plotting",
    "batdetect2/plot",
    "batdetect2/api",
    "batdetect2/evaluate/legacy",
    "batdetect2/train/legacy",
]
@@ -1,11 +1,18 @@
import uuid
from pathlib import Path
from typing import Callable, List, Optional
from typing import Callable, Iterable, List, Optional

import numpy as np
import pytest
import soundfile as sf
from soundevent import data
from soundevent import data, terms

from batdetect2.targets import (
    TargetConfig,
    build_target_encoder,
    call_type,
    get_class_names,
)


@pytest.fixture
@@ -107,3 +114,102 @@ def recording_factory(wav_factory: Callable[..., Path]):
    )

    return _recording_factory


@pytest.fixture
def recording(
    recording_factory: Callable[..., data.Recording],
) -> data.Recording:
    return recording_factory()


@pytest.fixture
def clip(recording: data.Recording) -> data.Clip:
    return data.Clip(recording=recording, start_time=0, end_time=0.5)


@pytest.fixture
def echolocation_call(recording: data.Recording) -> data.SoundEventAnnotation:
    return data.SoundEventAnnotation(
        sound_event=data.SoundEvent(
            geometry=data.BoundingBox(coordinates=[0.1, 67_000, 0.11, 73_000]),
            recording=recording,
        ),
        tags=[
            data.Tag(term=terms.scientific_name, value="Myotis myotis"),
            data.Tag(term=call_type, value="Echolocation"),
        ],
    )


@pytest.fixture
def generic_call(recording: data.Recording) -> data.SoundEventAnnotation:
    return data.SoundEventAnnotation(
        sound_event=data.SoundEvent(
            geometry=data.BoundingBox(
                coordinates=[0.34, 35_000, 0.348, 62_000]
            ),
            recording=recording,
        ),
        tags=[
            data.Tag(term=terms.order, value="Chiroptera"),
            data.Tag(term=call_type, value="Echolocation"),
        ],
    )


@pytest.fixture
def non_relevant_sound_event(
    recording: data.Recording,
) -> data.SoundEventAnnotation:
    return data.SoundEventAnnotation(
        sound_event=data.SoundEvent(
            geometry=data.BoundingBox(
                coordinates=[0.22, 50_000, 0.24, 58_000]
            ),
            recording=recording,
        ),
        tags=[
            data.Tag(
                term=terms.scientific_name,
                value="Muscardinus avellanarius",
            ),
        ],
    )


@pytest.fixture
def clip_annotation(
    clip: data.Clip,
    echolocation_call: data.SoundEventAnnotation,
    generic_call: data.SoundEventAnnotation,
    non_relevant_sound_event: data.SoundEventAnnotation,
) -> data.ClipAnnotation:
    return data.ClipAnnotation(
        clip=clip,
        sound_events=[
            echolocation_call,
            generic_call,
            non_relevant_sound_event,
        ],
    )


@pytest.fixture
def target_config() -> TargetConfig:
    return TargetConfig()


@pytest.fixture
def class_names(target_config: TargetConfig) -> List[str]:
    return get_class_names(target_config.classes)


@pytest.fixture
def encoder(
    target_config: TargetConfig,
) -> Callable[[Iterable[data.Tag]], Optional[str]]:
    return build_target_encoder(
        classes=target_config.classes,
        replacement_rules=target_config.replace,
    )
@@ -4,7 +4,7 @@ from pathlib import Path

from soundevent import data

from batdetect2.compat.data import load_annotation_project_from_dir
from batdetect2.data import BatDetect2FilesAnnotations, load_annotated_dataset

ROOT_DIR = Path(__file__).parent.parent.parent

@@ -12,8 +12,14 @@ ROOT_DIR = Path(__file__).parent.parent.parent
def test_load_example_annotation_project():
    path = ROOT_DIR / "example_data" / "anns"
    audio_dir = ROOT_DIR / "example_data" / "audio"
    project = load_annotation_project_from_dir(path, audio_dir=audio_dir)
    project = load_annotated_dataset(
        BatDetect2FilesAnnotations(
            name="test",
            audio_dir=audio_dir,
            annotations_dir=path,
        )
    )
    assert isinstance(project, data.AnnotationProject)
    assert project.name == str(path)
    assert project.name == "test"
    assert len(project.clip_annotations) == 3
    assert len(project.tasks) == 3
@@ -5,8 +5,8 @@ from typing import List
import numpy as np
import pytest

from batdetect2.compat.data import load_annotation_project_from_dir
from batdetect2.compat.params import get_training_preprocessing_config
from batdetect2.data import BatDetect2FilesAnnotations, load_annotated_dataset
from batdetect2.train.preprocess import generate_train_example


@@ -26,6 +26,8 @@ def test_can_generate_similar_training_inputs(
    old_parameters = json.loads((regression_dir / "params.json").read_text())
    config = get_training_preprocessing_config(old_parameters)

    assert config is not None

    for audio_file in example_audio_files:
        example_file = regression_dir / f"{audio_file.name}.npz"

@@ -36,9 +38,12 @@ def test_can_generate_similar_training_inputs(
    size_mask = dataset["size_mask"]
    class_mask = dataset["class_mask"]

    project = load_annotation_project_from_dir(
        example_anns_dir,
        audio_dir=example_audio_dir,
    project = load_annotated_dataset(
        BatDetect2FilesAnnotations(
            name="test",
            annotations_dir=example_anns_dir,
            audio_dir=example_audio_dir,
        )
    )

    clip_annotation = next(
@@ -47,7 +52,12 @@ def test_can_generate_similar_training_inputs(
        if ann.clip.recording.path == audio_file
    )

    new_dataset = generate_train_example(clip_annotation, config)
    new_dataset = generate_train_example(
        clip_annotation,
        preprocessing_config=config.preprocessing,
        target_config=config.target,
        label_config=config.labels,
    )
    new_spec = new_dataset["spectrogram"].values
    new_detection_mask = new_dataset["detection"].values
    new_size_mask = new_dataset["size"].values
tests/test_postprocessing/__init__.py (Normal file, 0 lines)

tests/test_postprocessing/test_arrays.py (Normal file, 43 lines)
@@ -0,0 +1,43 @@
from typing import List

import numpy as np
import torch
import xarray as xr
from soundevent import data

from batdetect2.modules import DetectorModel
from batdetect2.postprocess.arrays import to_xarray
from batdetect2.preprocess import preprocess_audio_clip


def test_this(clip: data.Clip, class_names: List[str]):
    spec = xr.DataArray(
        data=np.random.rand(100, 100),
        dims=["time", "frequency"],
        coords={
            "time": np.linspace(0, 100, 100, endpoint=False),
            "frequency": np.linspace(0, 100, 100, endpoint=False),
        },
    )

    model = DetectorModel()

    spec = preprocess_audio_clip(
        clip,
        config=model.config.preprocessing,
    )

    tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)

    outputs = model(tensor)

    arrays = to_xarray(
        outputs,
        start_time=clip.start_time,
        end_time=clip.end_time,
        class_names=class_names,
    )

    print(arrays)

    assert False
0  tests/test_targets/__init__.py  Normal file
175  tests/test_targets/test_terms.py  Normal file
@ -0,0 +1,175 @@
import pytest
import yaml
from soundevent import data

from batdetect2.targets import terms
from batdetect2.targets.terms import (
    TagInfo,
    TermRegistry,
    load_terms_from_config,
)


def test_term_registry_initialization():
    registry = TermRegistry()
    assert registry._terms == {}

    initial_terms = {
        "test_term": data.Term(name="test", label="Test", definition="test")
    }
    registry = TermRegistry(terms=initial_terms)
    assert registry._terms == initial_terms


def test_term_registry_add_term():
    registry = TermRegistry()
    term = data.Term(name="test", label="Test", definition="test")
    registry.add_term("test_key", term)
    assert registry._terms["test_key"] == term


def test_term_registry_get_term():
    registry = TermRegistry()
    term = data.Term(name="test", label="Test", definition="test")
    registry.add_term("test_key", term)
    retrieved_term = registry.get_term("test_key")
    assert retrieved_term == term


def test_term_registry_add_custom_term():
    registry = TermRegistry()
    term = registry.add_custom_term(
        "custom_key", name="custom", label="Custom", definition="A custom term"
    )
    assert registry._terms["custom_key"] == term
    assert term.name == "custom"
    assert term.label == "Custom"
    assert term.definition == "A custom term"


def test_term_registry_add_duplicate_term():
    registry = TermRegistry()
    term = data.Term(name="test", label="Test", definition="test")
    registry.add_term("test_key", term)
    with pytest.raises(KeyError):
        registry.add_term("test_key", term)


def test_term_registry_get_term_not_found():
    registry = TermRegistry()
    with pytest.raises(KeyError):
        registry.get_term("non_existent_key")


def test_term_registry_get_keys():
    registry = TermRegistry()
    term1 = data.Term(name="test1", label="Test1", definition="test")
    term2 = data.Term(name="test2", label="Test2", definition="test")
    registry.add_term("key1", term1)
    registry.add_term("key2", term2)
    keys = registry.get_keys()
    assert set(keys) == {"key1", "key2"}
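Read together, the tests above pin down the TermRegistry contract; a compact usage sketch:

    from soundevent import data

    from batdetect2.targets.terms import TermRegistry

    registry = TermRegistry()

    # Register an existing soundevent Term under a short lookup key.
    term = data.Term(name="test", label="Test", definition="test")
    registry.add_term("test_key", term)

    # Or build and register a term in one call.
    custom = registry.add_custom_term(
        "custom_key", name="custom", label="Custom", definition="A custom term"
    )

    assert registry.get_term("test_key") == term
    assert set(registry.get_keys()) == {"test_key", "custom_key"}
    # Re-registering an existing key raises KeyError, as does looking
    # up an unknown one.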
def test_get_term_from_key():
    term = terms.get_term_from_key("call_type")
    assert term == terms.call_type

    custom_registry = TermRegistry()
    custom_term = data.Term(name="custom", label="Custom", definition="test")
    custom_registry.add_term("custom_key", custom_term)
    term = terms.get_term_from_key("custom_key", term_registry=custom_registry)
    assert term == custom_term


def test_get_term_keys():
    keys = terms.get_term_keys()
    assert "call_type" in keys
    assert "individual" in keys
    assert terms.GENERIC_CLASS_KEY in keys

    custom_registry = TermRegistry()
    custom_term = data.Term(name="custom", label="Custom", definition="test")
    custom_registry.add_term("custom_key", custom_term)
    keys = terms.get_term_keys(term_registry=custom_registry)
    assert "custom_key" in keys


def test_tag_info_and_get_tag_from_info():
    tag_info = TagInfo(value="Myotis myotis", key="call_type")
    tag = terms.get_tag_from_info(tag_info)
    assert tag.value == "Myotis myotis"
    assert tag.term == terms.call_type


def test_get_tag_from_info_key_not_found():
    tag_info = TagInfo(value="test", key="non_existent_key")
    with pytest.raises(KeyError):
        terms.get_tag_from_info(tag_info)


def test_load_terms_from_config(tmp_path):
    config_data = {
        "terms": [
            {
                "key": "species",
                "name": "dwc:scientificName",
                "label": "Scientific Name",
            },
            {
                "key": "my_custom_term",
                "name": "soundevent:custom_term",
                "definition": "Describes a specific project attribute",
            },
        ]
    }
    config_file = tmp_path / "config.yaml"
    with open(config_file, "w") as f:
        yaml.dump(config_data, f)

    loaded_terms = load_terms_from_config(config_file)
    assert "species" in loaded_terms
    assert "my_custom_term" in loaded_terms
    assert loaded_terms["species"].name == "dwc:scientificName"
    assert loaded_terms["my_custom_term"].name == "soundevent:custom_term"


def test_load_terms_from_config_file_not_found():
    with pytest.raises(FileNotFoundError):
        load_terms_from_config("non_existent_file.yaml")


def test_load_terms_from_config_validation_error(tmp_path):
    config_data = {
        "terms": [
            {
                "key": "species",
                "uri": "dwc:scientificName",
                "label": 123,  # Invalid label type
            },
        ]
    }
    config_file = tmp_path / "config.yaml"
    with open(config_file, "w") as f:
        yaml.dump(config_data, f)

    with pytest.raises(ValueError):
        load_terms_from_config(config_file)


def test_load_terms_from_config_key_already_exists(tmp_path):
    config_data = {
        "terms": [
            {
                "key": "call_type",  # Duplicate of a built-in key
                "uri": "dwc:scientificName",
                "label": "Scientific Name",
            },
        ]
    }
    config_file = tmp_path / "config.yaml"
    with open(config_file, "w") as f:
        yaml.dump(config_data, f)

    with pytest.raises(KeyError):
        load_terms_from_config(config_file)
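Written out as the YAML file that load_terms_from_config expects, the valid configuration from test_load_terms_from_config above looks like this (the term names are the hypothetical values used in the test):

    terms:
      - key: species
        name: dwc:scientificName
        label: Scientific Name
      - key: my_custom_term
        name: soundevent:custom_term
        definition: Describes a specific project attribute

As the last two tests check, an entry that fails validation raises a ValueError, and a key that collides with an already registered term (such as the built-in call_type) raises a KeyError.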
@ -30,8 +30,18 @@ def test_mix_examples(
config = TrainPreprocessingConfig()

example1 = generate_train_example(clip_annotation_1, config)
example2 = generate_train_example(clip_annotation_2, config)
example1 = generate_train_example(
    clip_annotation_1,
    preprocessing_config=config.preprocessing,
    target_config=config.target,
    label_config=config.labels,
)
example2 = generate_train_example(
    clip_annotation_2,
    preprocessing_config=config.preprocessing,
    target_config=config.target,
    label_config=config.labels,
)

mixed = mix_examples(example1, example2, config=config.preprocessing)
@ -59,8 +69,18 @@ def test_mix_examples_of_different_durations(
config = TrainPreprocessingConfig()

example1 = generate_train_example(clip_annotation_1, config)
example2 = generate_train_example(clip_annotation_2, config)
example1 = generate_train_example(
    clip_annotation_1,
    preprocessing_config=config.preprocessing,
    target_config=config.target,
    label_config=config.labels,
)
example2 = generate_train_example(
    clip_annotation_2,
    preprocessing_config=config.preprocessing,
    target_config=config.target,
    label_config=config.labels,
)

mixed = mix_examples(example1, example2, config=config.preprocessing)
@ -78,7 +98,12 @@ def test_add_echo(
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
config = TrainPreprocessingConfig()
original = generate_train_example(clip_annotation_1, config)
original = generate_train_example(
    clip_annotation_1,
    preprocessing_config=config.preprocessing,
    target_config=config.target,
    label_config=config.labels,
)
with_echo = add_echo(original, config=config.preprocessing)

assert with_echo["spectrogram"].shape == original["spectrogram"].shape
@ -94,7 +119,12 @@ def test_selected_random_subclip_has_the_correct_width(
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
config = TrainPreprocessingConfig()
original = generate_train_example(clip_annotation_1, config)
original = generate_train_example(
    clip_annotation_1,
    preprocessing_config=config.preprocessing,
    target_config=config.target,
    label_config=config.labels,
)
subclip = select_subclip(original, width=100)

assert subclip["spectrogram"].shape[1] == 100
@ -107,7 +137,12 @@ def test_add_echo_after_subclip(
clip1 = data.Clip(recording=recording1, start_time=0, end_time=1)
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
config = TrainPreprocessingConfig()
original = generate_train_example(clip_annotation_1, config)
original = generate_train_example(
    clip_annotation_1,
    preprocessing_config=config.preprocessing,
    target_config=config.target,
    label_config=config.labels,
)

assert original.sizes["time"] > 512
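All five hunks in this file make the same mechanical signature change; the behaviour under test is the augmentation helpers themselves. A minimal sketch of their contracts, assuming example1 and example2 were built with generate_train_example as above and that mix_examples, add_echo and select_subclip are imported as in this test module (their module path is not shown in this excerpt):

    # Blend two training examples into a single augmented one.
    mixed = mix_examples(example1, example2, config=config.preprocessing)

    # Add an echo to the underlying audio; the spectrogram shape is preserved.
    with_echo = add_echo(example1, config=config.preprocessing)
    assert with_echo["spectrogram"].shape == example1["spectrogram"].shape

    # Select a random fixed-width window along the time axis.
    subclip = select_subclip(example1, width=100)
    assert subclip["spectrogram"].shape[1] == 100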
@ -4,7 +4,7 @@ import numpy as np
import xarray as xr
from soundevent import data

from batdetect2.train.labels import generate_heatmaps
from batdetect2.targets import generate_heatmaps

recording = data.Recording(
    samplerate=256_000,