mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Create separate targets module
This commit is contained in:
parent
acda71ea45
commit
b93d4c65c2
@ -8,9 +8,11 @@ from batdetect2.data import load_dataset_from_config
|
|||||||
from batdetect2.preprocess import (
|
from batdetect2.preprocess import (
|
||||||
load_preprocessing_config,
|
load_preprocessing_config,
|
||||||
)
|
)
|
||||||
from batdetect2.train import (
|
from batdetect2.targets import (
|
||||||
load_label_config,
|
load_label_config,
|
||||||
load_target_config,
|
load_target_config,
|
||||||
|
)
|
||||||
|
from batdetect2.train import (
|
||||||
preprocess_annotations,
|
preprocess_annotations,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -12,10 +12,13 @@ from batdetect2.preprocess import (
|
|||||||
STFTConfig,
|
STFTConfig,
|
||||||
)
|
)
|
||||||
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
|
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
|
||||||
from batdetect2.terms import TagInfo
|
from batdetect2.targets import (
|
||||||
from batdetect2.train.preprocess import (
|
|
||||||
HeatmapsConfig,
|
HeatmapsConfig,
|
||||||
|
TagInfo,
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
|
)
|
||||||
|
from batdetect2.targets.labels import LabelConfig
|
||||||
|
from batdetect2.train.preprocess import (
|
||||||
TrainPreprocessingConfig,
|
TrainPreprocessingConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -91,11 +94,13 @@ def get_training_preprocessing_config(
|
|||||||
for value in params["classes_to_ignore"]
|
for value in params["classes_to_ignore"]
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
heatmaps=HeatmapsConfig(
|
labels=LabelConfig(
|
||||||
position="bottom-left",
|
heatmaps=HeatmapsConfig(
|
||||||
time_scale=1 / time_bin_width,
|
position="bottom-left",
|
||||||
frequency_scale=1 / freq_bin_width,
|
time_scale=1 / time_bin_width,
|
||||||
sigma=params["target_sigma"],
|
frequency_scale=1 / freq_bin_width,
|
||||||
|
sigma=params["target_sigma"],
|
||||||
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,12 +1,18 @@
|
|||||||
from batdetect2.data.annotations import (
|
from batdetect2.data.annotations import (
|
||||||
AnnotatedDataset,
|
AnnotatedDataset,
|
||||||
|
AOEFAnnotations,
|
||||||
|
BatDetect2FilesAnnotations,
|
||||||
|
BatDetect2MergedAnnotations,
|
||||||
load_annotated_dataset,
|
load_annotated_dataset,
|
||||||
)
|
)
|
||||||
from batdetect2.data.data import load_dataset, load_dataset_from_config
|
from batdetect2.data.data import load_dataset, load_dataset_from_config
|
||||||
from batdetect2.data.types import Dataset
|
from batdetect2.data.types import Dataset
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"AOEFAnnotations",
|
||||||
"AnnotatedDataset",
|
"AnnotatedDataset",
|
||||||
|
"BatDetect2FilesAnnotations",
|
||||||
|
"BatDetect2MergedAnnotations",
|
||||||
"Dataset",
|
"Dataset",
|
||||||
"load_annotated_dataset",
|
"load_annotated_dataset",
|
||||||
"load_dataset",
|
"load_dataset",
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Union
|
from typing import Literal, Union
|
||||||
|
|
||||||
@ -32,5 +31,3 @@ AnnotationFormats = Union[
|
|||||||
BatDetect2AnnotationFile,
|
BatDetect2AnnotationFile,
|
||||||
AOEFAnnotationFile,
|
AOEFAnnotationFile,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,11 +1,9 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Union
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AnnotatedDataset",
|
"AnnotatedDataset",
|
||||||
"BatDetect2MergedAnnotations",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -37,5 +35,3 @@ class AnnotatedDataset(BaseConfig):
|
|||||||
name: str
|
name: str
|
||||||
audio_dir: Path
|
audio_dir: Path
|
||||||
description: str = ""
|
description: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@ -0,0 +1,9 @@
|
|||||||
|
from batdetect2.evaluate.evaluate import (
|
||||||
|
compute_error_auc,
|
||||||
|
match_predictions_and_annotations,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"compute_error_auc",
|
||||||
|
"match_predictions_and_annotations",
|
||||||
|
]
|
@ -1,13 +1,10 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
|
||||||
from sklearn.metrics import auc, roc_curve
|
from sklearn.metrics import auc, roc_curve
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.evaluation import match_geometries
|
from soundevent.evaluation import match_geometries
|
||||||
|
|
||||||
from batdetect2.train.targets import build_target_encoder, get_class_names
|
|
||||||
|
|
||||||
|
|
||||||
def match_predictions_and_annotations(
|
def match_predictions_and_annotations(
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
@ -51,20 +48,13 @@ def match_predictions_and_annotations(
|
|||||||
return matches
|
return matches
|
||||||
|
|
||||||
|
|
||||||
def build_evaluation_dataframe(matches: List[data.Match]) -> pd.DataFrame:
|
|
||||||
ret = []
|
|
||||||
|
|
||||||
for match in matches:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def compute_error_auc(op_str, gt, pred, prob):
|
def compute_error_auc(op_str, gt, pred, prob):
|
||||||
# classification error
|
# classification error
|
||||||
pred_int = (pred > prob).astype(np.int32)
|
pred_int = (pred > prob).astype(np.int32)
|
||||||
class_acc = (pred_int == gt).mean() * 100.0
|
class_acc = (pred_int == gt).mean() * 100.0
|
||||||
|
|
||||||
# ROC - area under curve
|
# ROC - area under curve
|
||||||
fpr, tpr, thresholds = roc_curve(gt, pred)
|
fpr, tpr, _ = roc_curve(gt, pred)
|
||||||
roc_auc = auc(fpr, tpr)
|
roc_auc = auc(fpr, tpr)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
@ -177,7 +167,7 @@ def compute_pre_rec(
|
|||||||
file_ids.append([pid] * valid_inds.sum())
|
file_ids.append([pid] * valid_inds.sum())
|
||||||
|
|
||||||
confidence = np.hstack(confidence)
|
confidence = np.hstack(confidence)
|
||||||
file_ids = np.hstack(file_ids).astype(np.int)
|
file_ids = np.hstack(file_ids).astype(int)
|
||||||
pred_boxes = np.vstack(pred_boxes)
|
pred_boxes = np.vstack(pred_boxes)
|
||||||
if len(pred_class) > 0:
|
if len(pred_class) > 0:
|
||||||
pred_class = np.hstack(pred_class)
|
pred_class = np.hstack(pred_class)
|
||||||
@ -197,8 +187,7 @@ def compute_pre_rec(
|
|||||||
|
|
||||||
# note, files with the incorrect duration will cause a problem
|
# note, files with the incorrect duration will cause a problem
|
||||||
if (gg["start_times"] > file_dur).sum() > 0:
|
if (gg["start_times"] > file_dur).sum() > 0:
|
||||||
print("Error: file duration incorrect for", gg["id"])
|
raise ValueError(f"Error: file duration incorrect for {gg['id']}")
|
||||||
assert False
|
|
||||||
|
|
||||||
boxes = np.vstack(
|
boxes = np.vstack(
|
||||||
(
|
(
|
||||||
@ -244,6 +233,8 @@ def compute_pre_rec(
|
|||||||
gt_id = file_ids[ind]
|
gt_id = file_ids[ind]
|
||||||
|
|
||||||
valid_det = False
|
valid_det = False
|
||||||
|
det_ind = 0
|
||||||
|
|
||||||
if gt_boxes[gt_id].shape[0] > 0:
|
if gt_boxes[gt_id].shape[0] > 0:
|
||||||
# compute overlap
|
# compute overlap
|
||||||
valid_det, det_ind = compute_affinity_1d(
|
valid_det, det_ind = compute_affinity_1d(
|
||||||
@ -273,7 +264,7 @@ def compute_pre_rec(
|
|||||||
# store threshold values - used for plotting
|
# store threshold values - used for plotting
|
||||||
conf_sorted = np.sort(confidence)[::-1][valid_inds]
|
conf_sorted = np.sort(confidence)[::-1][valid_inds]
|
||||||
thresholds = np.linspace(0.1, 0.9, 9)
|
thresholds = np.linspace(0.1, 0.9, 9)
|
||||||
thresholds_inds = np.zeros(len(thresholds), dtype=np.int)
|
thresholds_inds = np.zeros(len(thresholds), dtype=int)
|
||||||
for ii, tt in enumerate(thresholds):
|
for ii, tt in enumerate(thresholds):
|
||||||
thresholds_inds[ii] = np.argmin(conf_sorted > tt)
|
thresholds_inds[ii] = np.argmin(conf_sorted > tt)
|
||||||
thresholds_inds[thresholds_inds == 0] = -1
|
thresholds_inds[thresholds_inds == 0] = -1
|
||||||
@ -385,7 +376,7 @@ def compute_file_accuracy(gts, preds, num_classes):
|
|||||||
).mean(0)
|
).mean(0)
|
||||||
best_thresh = np.argmax(acc_per_thresh)
|
best_thresh = np.argmax(acc_per_thresh)
|
||||||
best_acc = acc_per_thresh[best_thresh]
|
best_acc = acc_per_thresh[best_thresh]
|
||||||
pred_valid = pred_valid_all[:, best_thresh].astype(np.int).tolist()
|
pred_valid = pred_valid_all[:, best_thresh].astype(int).tolist()
|
||||||
|
|
||||||
res = {}
|
res = {}
|
||||||
res["num_valid_files"] = len(gt_valid)
|
res["num_valid_files"] = len(gt_valid)
|
||||||
|
@ -1,92 +1,26 @@
|
|||||||
from enum import Enum
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
from soundevent.data import PathLike
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
|
||||||
from batdetect2.models.backbones import (
|
from batdetect2.models.backbones import (
|
||||||
Net2DFast,
|
Net2DFast,
|
||||||
Net2DFastNoAttn,
|
Net2DFastNoAttn,
|
||||||
Net2DFastNoCoordConv,
|
Net2DFastNoCoordConv,
|
||||||
Net2DPlain,
|
Net2DPlain,
|
||||||
)
|
)
|
||||||
|
from batdetect2.models.build import build_architecture
|
||||||
|
from batdetect2.models.config import ModelConfig, ModelType, load_model_config
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||||
from batdetect2.models.typing import BackboneModel
|
from batdetect2.models.types import BackboneModel, ModelOutput
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BBoxHead",
|
"BBoxHead",
|
||||||
|
"BackboneModel",
|
||||||
"ClassifierHead",
|
"ClassifierHead",
|
||||||
"ModelConfig",
|
"ModelConfig",
|
||||||
|
"ModelOutput",
|
||||||
"ModelType",
|
"ModelType",
|
||||||
"Net2DFast",
|
"Net2DFast",
|
||||||
"Net2DFastNoAttn",
|
"Net2DFastNoAttn",
|
||||||
"Net2DFastNoCoordConv",
|
"Net2DFastNoCoordConv",
|
||||||
|
"Net2DPlain",
|
||||||
|
"build_architecture",
|
||||||
"build_architecture",
|
"build_architecture",
|
||||||
"load_model_config",
|
"load_model_config",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class ModelType(str, Enum):
|
|
||||||
Net2DFast = "Net2DFast"
|
|
||||||
Net2DFastNoAttn = "Net2DFastNoAttn"
|
|
||||||
Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
|
|
||||||
Net2DPlain = "Net2DPlain"
|
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseConfig):
|
|
||||||
name: ModelType = ModelType.Net2DFast
|
|
||||||
input_height: int = 128
|
|
||||||
encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
|
|
||||||
bottleneck_channels: int = 256
|
|
||||||
decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
|
|
||||||
out_channels: int = 32
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_config(
|
|
||||||
path: PathLike, field: Optional[str] = None
|
|
||||||
) -> ModelConfig:
|
|
||||||
return load_config(path, schema=ModelConfig, field=field)
|
|
||||||
|
|
||||||
|
|
||||||
def build_architecture(
|
|
||||||
config: Optional[ModelConfig] = None,
|
|
||||||
) -> BackboneModel:
|
|
||||||
config = config or ModelConfig()
|
|
||||||
|
|
||||||
if config.name == ModelType.Net2DFast:
|
|
||||||
return Net2DFast(
|
|
||||||
input_height=config.input_height,
|
|
||||||
encoder_channels=config.encoder_channels,
|
|
||||||
bottleneck_channels=config.bottleneck_channels,
|
|
||||||
decoder_channels=config.decoder_channels,
|
|
||||||
out_channels=config.out_channels,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.name == ModelType.Net2DFastNoAttn:
|
|
||||||
return Net2DFastNoAttn(
|
|
||||||
input_height=config.input_height,
|
|
||||||
encoder_channels=config.encoder_channels,
|
|
||||||
bottleneck_channels=config.bottleneck_channels,
|
|
||||||
decoder_channels=config.decoder_channels,
|
|
||||||
out_channels=config.out_channels,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.name == ModelType.Net2DFastNoCoordConv:
|
|
||||||
return Net2DFastNoCoordConv(
|
|
||||||
input_height=config.input_height,
|
|
||||||
encoder_channels=config.encoder_channels,
|
|
||||||
bottleneck_channels=config.bottleneck_channels,
|
|
||||||
decoder_channels=config.decoder_channels,
|
|
||||||
out_channels=config.out_channels,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.name == ModelType.Net2DPlain:
|
|
||||||
return Net2DPlain(
|
|
||||||
input_height=config.input_height,
|
|
||||||
encoder_channels=config.encoder_channels,
|
|
||||||
bottleneck_channels=config.bottleneck_channels,
|
|
||||||
decoder_channels=config.decoder_channels,
|
|
||||||
out_channels=config.out_channels,
|
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(f"Unknown model type: {config.name}")
|
|
||||||
|
@ -12,12 +12,13 @@ from batdetect2.models.blocks import (
|
|||||||
UpscalingLayer,
|
UpscalingLayer,
|
||||||
VerticalConv,
|
VerticalConv,
|
||||||
)
|
)
|
||||||
from batdetect2.models.typing import BackboneModel
|
from batdetect2.models.types import BackboneModel
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Net2DFast",
|
"Net2DFast",
|
||||||
"Net2DFastNoAttn",
|
"Net2DFastNoAttn",
|
||||||
"Net2DFastNoCoordConv",
|
"Net2DFastNoCoordConv",
|
||||||
|
"Net2DPlain",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -165,7 +166,6 @@ def pad_adjust(
|
|||||||
spec: torch.Tensor,
|
spec: torch.Tensor,
|
||||||
factor: int = 32,
|
factor: int = 32,
|
||||||
) -> Tuple[torch.Tensor, int, int]:
|
) -> Tuple[torch.Tensor, int, int]:
|
||||||
print(spec.shape)
|
|
||||||
h, w = spec.shape[2:]
|
h, w = spec.shape[2:]
|
||||||
h_pad = -h % factor
|
h_pad = -h % factor
|
||||||
w_pad = -w % factor
|
w_pad = -w % factor
|
||||||
|
58
batdetect2/models/build.py
Normal file
58
batdetect2/models/build.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from batdetect2.models.backbones import (
|
||||||
|
Net2DFast,
|
||||||
|
Net2DFastNoAttn,
|
||||||
|
Net2DFastNoCoordConv,
|
||||||
|
Net2DPlain,
|
||||||
|
)
|
||||||
|
from batdetect2.models.config import ModelConfig, ModelType
|
||||||
|
from batdetect2.models.types import BackboneModel
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"build_architecture",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_architecture(
    config: Optional[ModelConfig] = None,
) -> BackboneModel:
    """Instantiate the backbone selected by ``config.name``.

    Parameters
    ----------
    config : Optional[ModelConfig]
        Model configuration. When omitted (or falsy) a default
        ``ModelConfig()`` is used.

    Returns
    -------
    BackboneModel
        The backbone built with ``input_height``, ``encoder_channels``,
        ``bottleneck_channels``, ``decoder_channels`` and ``out_channels``
        taken from the configuration.

    Raises
    ------
    ValueError
        If ``config.name`` does not match any known ``ModelType``.
    """
    config = config or ModelConfig()

    # Every backbone accepts the same keyword arguments, so only the class
    # to instantiate depends on the configured model type.
    known_backbones = (
        (ModelType.Net2DFast, Net2DFast),
        (ModelType.Net2DFastNoAttn, Net2DFastNoAttn),
        (ModelType.Net2DFastNoCoordConv, Net2DFastNoCoordConv),
        (ModelType.Net2DPlain, Net2DPlain),
    )

    for model_type, backbone_cls in known_backbones:
        if config.name == model_type:
            return backbone_cls(
                input_height=config.input_height,
                encoder_channels=config.encoder_channels,
                bottleneck_channels=config.bottleneck_channels,
                decoder_channels=config.decoder_channels,
                out_channels=config.out_channels,
            )

    raise ValueError(f"Unknown model type: {config.name}")
|
34
batdetect2/models/config.py
Normal file
34
batdetect2/models/config.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ModelType",
|
||||||
|
"ModelConfig",
|
||||||
|
"load_model_config",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ModelType(str, Enum):
    """Names of the available backbone architectures.

    Members are also plain strings, so they compare equal to their value
    (for example ``ModelType.Net2DFast == "Net2DFast"``).
    """

    Net2DFast = "Net2DFast"
    Net2DFastNoAttn = "Net2DFastNoAttn"
    Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
    Net2DPlain = "Net2DPlain"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelConfig(BaseConfig):
    """Configuration of the backbone architecture.

    Every field has a default, so ``ModelConfig()`` is a usable
    configuration on its own.
    """

    # Which backbone variant to build.
    name: ModelType = ModelType.Net2DFast

    # NOTE(review): presumably the height (frequency bins) of the input
    # spectrogram expected by the backbone - confirm against the backbones.
    input_height: int = 128

    # NOTE(review): presumably channel counts of the successive encoder
    # stages - confirm against the backbones.
    encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)

    bottleneck_channels: int = 256

    # NOTE(review): presumably channel counts of the successive decoder
    # stages - confirm against the backbones.
    decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)

    out_channels: int = 32
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_config(
    path: PathLike, field: Optional[str] = None
) -> ModelConfig:
    """Load a ``ModelConfig`` from a configuration file.

    Parameters
    ----------
    path : PathLike
        Location of the configuration file.
    field : Optional[str]
        Forwarded unchanged to ``load_config``.
        NOTE(review): presumably selects a sub-section of the file - confirm
        against ``load_config``.

    Returns
    -------
    ModelConfig
        The configuration parsed with ``ModelConfig`` as schema.
    """
    model_config = load_config(path, schema=ModelConfig, field=field)
    return model_config
|
@ -1,15 +0,0 @@
|
|||||||
import sys
|
|
||||||
from typing import Iterable, List, Literal, Sequence
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from batdetect2.models.blocks import ConvBlockUpF, ConvBlockUpStandard
|
|
||||||
|
|
||||||
if sys.version_info >= (3, 10):
|
|
||||||
from itertools import pairwise
|
|
||||||
else:
|
|
||||||
|
|
||||||
def pairwise(iterable: Sequence) -> Iterable:
|
|
||||||
for x, y in zip(iterable[:-1], iterable[1:]):
|
|
||||||
yield x, y
|
|
@ -1,15 +0,0 @@
|
|||||||
import sys
|
|
||||||
from typing import Iterable, List, Literal, Sequence
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from batdetect2.models.blocks import ConvBlockDownCoordF, ConvBlockDownStandard
|
|
||||||
|
|
||||||
if sys.version_info >= (3, 10):
|
|
||||||
from itertools import pairwise
|
|
||||||
else:
|
|
||||||
|
|
||||||
def pairwise(iterable: Sequence) -> Iterable:
|
|
||||||
for x, y in zip(iterable[:-1], iterable[1:]):
|
|
||||||
yield x, y
|
|
@ -7,31 +7,27 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.optim.adam import Adam
|
from torch.optim.adam import Adam
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.evaluate.evaluate import match_predictions_and_annotations
|
|
||||||
from batdetect2.models import (
|
from batdetect2.models import (
|
||||||
BBoxHead,
|
BBoxHead,
|
||||||
ClassifierHead,
|
ClassifierHead,
|
||||||
ModelConfig,
|
ModelConfig,
|
||||||
|
ModelOutput,
|
||||||
build_architecture,
|
build_architecture,
|
||||||
)
|
)
|
||||||
from batdetect2.models.typing import ModelOutput
|
|
||||||
from batdetect2.post_process import (
|
from batdetect2.post_process import (
|
||||||
PostprocessConfig,
|
PostprocessConfig,
|
||||||
postprocess_model_outputs,
|
postprocess_model_outputs,
|
||||||
)
|
)
|
||||||
from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
|
from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
|
||||||
from batdetect2.train.config import TrainingConfig
|
from batdetect2.targets import (
|
||||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
|
||||||
from batdetect2.train.losses import compute_loss
|
|
||||||
from batdetect2.train.targets import (
|
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
build_decoder,
|
build_decoder,
|
||||||
build_target_encoder,
|
build_target_encoder,
|
||||||
get_class_names,
|
get_class_names,
|
||||||
)
|
)
|
||||||
|
from batdetect2.train import TrainExample, TrainingConfig, compute_loss
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DetectorModel",
|
"DetectorModel",
|
||||||
@ -83,12 +79,9 @@ class DetectorModel(L.LightningModule):
|
|||||||
replacement_rules=self.config.targets.replace,
|
replacement_rules=self.config.targets.replace,
|
||||||
)
|
)
|
||||||
self.decoder = build_decoder(self.config.targets.classes)
|
self.decoder = build_decoder(self.config.targets.classes)
|
||||||
|
|
||||||
self.validation_predictions = []
|
|
||||||
|
|
||||||
self.example_input_array = torch.randn([1, 1, 128, 512])
|
self.example_input_array = torch.randn([1, 1, 128, 512])
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore
|
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||||
features = self.backbone(spec)
|
features = self.backbone(spec)
|
||||||
detection_probs, classification_probs = self.classifier(features)
|
detection_probs, classification_probs = self.classifier(features)
|
||||||
size_preds = self.bbox(features)
|
size_preds = self.bbox(features)
|
||||||
@ -130,27 +123,6 @@ class DetectorModel(L.LightningModule):
|
|||||||
self.log("val/loss/size", losses.total, logger=True)
|
self.log("val/loss/size", losses.total, logger=True)
|
||||||
self.log("val/loss/classification", losses.total, logger=True)
|
self.log("val/loss/classification", losses.total, logger=True)
|
||||||
|
|
||||||
dataloaders = self.trainer.val_dataloaders
|
|
||||||
assert isinstance(dataloaders, DataLoader)
|
|
||||||
dataset = dataloaders.dataset
|
|
||||||
assert isinstance(dataset, LabeledDataset)
|
|
||||||
clip_annotation = dataset.get_clip_annotation(batch_idx)
|
|
||||||
|
|
||||||
clip_prediction = postprocess_model_outputs(
|
|
||||||
outputs,
|
|
||||||
clips=[clip_annotation.clip],
|
|
||||||
classes=self.class_names,
|
|
||||||
decoder=self.decoder,
|
|
||||||
config=self.config.postprocessing,
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
matches = match_predictions_and_annotations(
|
|
||||||
clip_annotation,
|
|
||||||
clip_prediction,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.validation_predictions.extend(matches)
|
|
||||||
|
|
||||||
def on_validation_epoch_end(self) -> None:
|
def on_validation_epoch_end(self) -> None:
|
||||||
self.validation_predictions.clear()
|
self.validation_predictions.clear()
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ from soundevent import data
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.models.typing import ModelOutput
|
from batdetect2.models.types import ModelOutput
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PostprocessConfig",
|
"PostprocessConfig",
|
||||||
|
0
batdetect2/postprocess/__init__.py
Normal file
0
batdetect2/postprocess/__init__.py
Normal file
73
batdetect2/postprocess/arrays.py
Normal file
73
batdetect2/postprocess/arrays.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
import numpy as np
|
||||||
|
import xarray as xr
|
||||||
|
from soundevent.arrays import Dimensions
|
||||||
|
|
||||||
|
from batdetect2.models import ModelOutput
|
||||||
|
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||||
|
|
||||||
|
|
||||||
|
def to_xarray(
    output: ModelOutput,
    start_time: float,
    end_time: float,
    class_names: list[str],
    min_freq: float = MIN_FREQ,
    max_freq: float = MAX_FREQ,
) -> xr.Dataset:
    """Convert the output of the model for one clip into an xarray dataset.

    The returned dataset holds three variables: ``detection``, ``size`` and
    ``classes``. The model features are not included.

    Parameters
    ----------
    output : ModelOutput
        Model output. Either non-batched or with a leading batch axis of
        size 1, which is removed.
    start_time : float
        First value of the time coordinate.
    end_time : float
        Exclusive upper bound of the time coordinate.
    class_names : list[str]
        Labels for the ``category`` coordinate; must have one entry per
        channel of ``output.class_probs``.
    min_freq : float
        First value of the frequency coordinate.
    max_freq : float
        Exclusive upper bound of the frequency coordinate.

    Returns
    -------
    xr.Dataset
        Dataset with ``detection``, ``size`` and ``classes`` variables.

    Raises
    ------
    ValueError
        If the output has a batch axis larger than 1, or if the number of
        class channels differs from ``len(class_names)``.
    """
    detection = output.detection_probs
    size = output.size_preds
    classes = output.class_probs

    if len(detection.shape) == 4:
        if detection.shape[0] != 1:
            raise ValueError(
                "Expected a non-batched output or a batch of size 1, instead "
                f"got an input of shape {detection.shape}"
            )

        detection = detection.squeeze(dim=0)
        size = size.squeeze(dim=0)
        classes = classes.squeeze(dim=0)

    # NOTE(review): the first spatial axis is labelled as time and the second
    # as frequency; verify this matches the axis order the model emits.
    _, width, height = detection.shape

    times = np.linspace(start_time, end_time, width, endpoint=False)
    freqs = np.linspace(min_freq, max_freq, height, endpoint=False)

    if classes.shape[0] != len(class_names):
        raise ValueError(
            "The number of classes does not coincide with the number of class "
            "names provided: "
            f"({classes.shape[0] = }) != ({len(class_names) = })"
        )

    spatial_dims = [Dimensions.time.value, Dimensions.frequency.value]

    return xr.Dataset(
        data_vars={
            "detection": (
                spatial_dims,
                # Drop the channel axis; move to host memory before
                # converting so tensors on an accelerator are supported.
                detection.squeeze(dim=0).detach().cpu().numpy(),
            ),
            "size": (
                ["dimension", *spatial_dims],
                # Must be the size predictions (previously the detection
                # tensor was passed here by mistake).
                size.detach().cpu().numpy(),
            ),
            "classes": (
                ["category", *spatial_dims],
                classes.detach().cpu().numpy(),
            ),
        },
        coords={
            Dimensions.time.value: times,
            Dimensions.frequency.value: freqs,
            "dimension": ["width", "height"],
            "category": class_names,
        },
    )
|
32
batdetect2/postprocess/config.py
Normal file
32
batdetect2/postprocess/config.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PostprocessConfig",
|
||||||
|
"load_postprocess_config",
|
||||||
|
]
|
||||||
|
|
||||||
|
NMS_KERNEL_SIZE = 9
|
||||||
|
DETECTION_THRESHOLD = 0.01
|
||||||
|
TOP_K_PER_SEC = 200
|
||||||
|
|
||||||
|
|
||||||
|
class PostprocessConfig(BaseConfig):
    """Configuration for postprocessing model outputs."""

    # Must be strictly positive.
    nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)

    # Must be non-negative.
    detection_threshold: float = Field(default=DETECTION_THRESHOLD, ge=0)

    # NOTE(review): presumably frequency bounds in Hz - confirm with callers.
    min_freq: int = Field(default=10000, gt=0)
    max_freq: int = Field(default=120000, gt=0)

    # Must be strictly positive.
    top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
|
||||||
|
|
||||||
|
|
||||||
|
def load_postprocess_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> PostprocessConfig:
    """Load a ``PostprocessConfig`` from a configuration file.

    Parameters
    ----------
    path : data.PathLike
        Location of the configuration file.
    field : Optional[str]
        Forwarded unchanged to ``load_config``.
        NOTE(review): presumably selects a sub-section of the file - confirm
        against ``load_config``.

    Returns
    -------
    PostprocessConfig
        The configuration parsed with ``PostprocessConfig`` as schema.
    """
    postprocess_config = load_config(
        path,
        schema=PostprocessConfig,
        field=field,
    )
    return postprocess_config
|
50
batdetect2/postprocess/non_max_supression.py
Normal file
50
batdetect2/postprocess/non_max_supression.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
from typing import Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
NMS_KERNEL_SIZE = 9


def non_max_suppression(
    tensor: torch.Tensor,
    kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
) -> torch.Tensor:
    """Run non-maximum suppression on a tensor.

    This function removes values from the input tensor that are not local
    maxima in the neighborhood of the given kernel size.

    All non-maximum values are set to zero.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor, in a layout accepted by
        ``torch.nn.functional.max_pool2d``.
    kernel_size : Union[int, Tuple[int, int]], optional
        Size of the neighborhood to consider for non-maximum suppression.
        If an integer is given, the neighborhood will be a square of the
        given size. If a tuple is given, the neighborhood will be a
        rectangle with the given height and width. Each size must be a
        positive odd number so the neighborhood is centred on each element.

    Returns
    -------
    torch.Tensor
        Tensor with non-maximum suppressed values.

    Raises
    ------
    ValueError
        If any kernel size is not a positive odd integer.
    """
    if isinstance(kernel_size, int):
        kernel_size_h = kernel_size
        kernel_size_w = kernel_size
    else:
        kernel_size_h, kernel_size_w = kernel_size

    # With an even kernel the symmetric padding below yields a pooled map one
    # element smaller than the input, so the comparison with the input would
    # either fail or silently broadcast to wrong results.
    for size in (kernel_size_h, kernel_size_w):
        if size < 1 or size % 2 == 0:
            raise ValueError(
                "kernel_size must be made of positive odd integers, "
                f"got {kernel_size}"
            )

    pad_h = (kernel_size_h - 1) // 2
    pad_w = (kernel_size_w - 1) // 2

    # Stride 1 with "same" padding: every position holds the maximum of the
    # neighborhood centred on it.
    hmax = torch.nn.functional.max_pool2d(
        tensor,
        (kernel_size_h, kernel_size_w),
        stride=1,
        padding=(pad_h, pad_w),
    )

    # An element is kept only when it equals its neighborhood maximum.
    keep = (hmax == tensor).float()
    return tensor * keep
|
17
batdetect2/postprocess/types.py
Normal file
17
batdetect2/postprocess/types.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from typing import Dict, NamedTuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BatDetect2Prediction",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class BatDetect2Prediction(NamedTuple):
    """A single prediction: time/frequency bounds, scores and features."""

    # Bounds of the prediction along time.
    # NOTE(review): presumably seconds - confirm with the producer.
    start_time: float
    end_time: float

    # Bounds of the prediction along frequency.
    # NOTE(review): presumably Hz - confirm with the producer.
    low_freq: float
    high_freq: float

    detection_score: float

    # One score per class, keyed by a string (presumably the class name).
    class_scores: Dict[str, float]

    features: np.ndarray
|
@ -15,6 +15,8 @@ from batdetect2.preprocess.config import (
|
|||||||
load_preprocessing_config,
|
load_preprocessing_config,
|
||||||
)
|
)
|
||||||
from batdetect2.preprocess.spectrogram import (
|
from batdetect2.preprocess.spectrogram import (
|
||||||
|
MAX_FREQ,
|
||||||
|
MIN_FREQ,
|
||||||
AmplitudeScaleConfig,
|
AmplitudeScaleConfig,
|
||||||
FrequencyConfig,
|
FrequencyConfig,
|
||||||
LogScaleConfig,
|
LogScaleConfig,
|
||||||
@ -31,6 +33,8 @@ __all__ = [
|
|||||||
"AudioConfig",
|
"AudioConfig",
|
||||||
"FrequencyConfig",
|
"FrequencyConfig",
|
||||||
"LogScaleConfig",
|
"LogScaleConfig",
|
||||||
|
"MAX_FREQ",
|
||||||
|
"MIN_FREQ",
|
||||||
"PcenScaleConfig",
|
"PcenScaleConfig",
|
||||||
"PreprocessingConfig",
|
"PreprocessingConfig",
|
||||||
"ResampleConfig",
|
"ResampleConfig",
|
||||||
|
@ -31,7 +31,7 @@ def load_file_audio(
|
|||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
config: Optional[AudioConfig] = None,
|
config: Optional[AudioConfig] = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
dtype: DTypeLike = np.float32,
|
dtype: DTypeLike = np.float32, # type: ignore
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
recording = data.Recording.from_file(path)
|
recording = data.Recording.from_file(path)
|
||||||
return load_recording_audio(
|
return load_recording_audio(
|
||||||
@ -46,7 +46,7 @@ def load_recording_audio(
|
|||||||
recording: data.Recording,
|
recording: data.Recording,
|
||||||
config: Optional[AudioConfig] = None,
|
config: Optional[AudioConfig] = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
dtype: DTypeLike = np.float32,
|
dtype: DTypeLike = np.float32, # type: ignore
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
clip = data.Clip(
|
clip = data.Clip(
|
||||||
recording=recording,
|
recording=recording,
|
||||||
@ -65,7 +65,7 @@ def load_clip_audio(
|
|||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
config: Optional[AudioConfig] = None,
|
config: Optional[AudioConfig] = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
dtype: DTypeLike = np.float32,
|
dtype: DTypeLike = np.float32, # type: ignore
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
config = config or AudioConfig()
|
config = config or AudioConfig()
|
||||||
|
|
||||||
@ -122,7 +122,7 @@ def resample_audio(
|
|||||||
wav: xr.DataArray,
|
wav: xr.DataArray,
|
||||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
mode: str = "poly",
|
mode: str = "poly",
|
||||||
dtype: DTypeLike = np.float32,
|
dtype: DTypeLike = np.float32, # type: ignore
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
if "time" not in wav.dims:
|
if "time" not in wav.dims:
|
||||||
raise ValueError("Audio must have a time dimension")
|
raise ValueError("Audio must have a time dimension")
|
||||||
|
@ -11,6 +11,20 @@ from soundevent.arrays import operations as ops
|
|||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"STFTConfig",
|
||||||
|
"FrequencyConfig",
|
||||||
|
"LogScaleConfig",
|
||||||
|
"PcenScaleConfig",
|
||||||
|
"AmplitudeScaleConfig",
|
||||||
|
"Scales",
|
||||||
|
"SpectrogramConfig",
|
||||||
|
"compute_spectrogram",
|
||||||
|
]
|
||||||
|
|
||||||
|
MIN_FREQ = 10_000
|
||||||
|
MAX_FREQ = 120_000
|
||||||
|
|
||||||
|
|
||||||
class STFTConfig(BaseConfig):
|
class STFTConfig(BaseConfig):
|
||||||
window_duration: float = Field(default=0.002, gt=0)
|
window_duration: float = Field(default=0.002, gt=0)
|
||||||
@ -70,7 +84,7 @@ class SpectrogramConfig(BaseConfig):
|
|||||||
def compute_spectrogram(
|
def compute_spectrogram(
|
||||||
wav: xr.DataArray,
|
wav: xr.DataArray,
|
||||||
config: Optional[SpectrogramConfig] = None,
|
config: Optional[SpectrogramConfig] = None,
|
||||||
dtype: DTypeLike = np.float32,
|
dtype: DTypeLike = np.float32, # type: ignore
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
config = config or SpectrogramConfig()
|
config = config or SpectrogramConfig()
|
||||||
|
|
||||||
@ -124,7 +138,7 @@ def stft(
|
|||||||
window_duration: float,
|
window_duration: float,
|
||||||
window_overlap: float,
|
window_overlap: float,
|
||||||
window_fn: str = "hann",
|
window_fn: str = "hann",
|
||||||
dtype: DTypeLike = np.float32,
|
dtype: DTypeLike = np.float32, # type: ignore
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
start_time, end_time = arrays.get_dim_range(wave, dim="time")
|
start_time, end_time = arrays.get_dim_range(wave, dim="time")
|
||||||
step = arrays.get_dim_step(wave, dim="time")
|
step = arrays.get_dim_step(wave, dim="time")
|
||||||
@ -190,7 +204,7 @@ def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
|
|||||||
def scale_spectrogram(
|
def scale_spectrogram(
|
||||||
spec: xr.DataArray,
|
spec: xr.DataArray,
|
||||||
scale: Scales,
|
scale: Scales,
|
||||||
dtype: DTypeLike = np.float32,
|
dtype: DTypeLike = np.float32, # type: ignore
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
if scale.name == "log":
|
if scale.name == "log":
|
||||||
return scale_log(spec, dtype=dtype)
|
return scale_log(spec, dtype=dtype)
|
||||||
@ -230,7 +244,7 @@ def scale_pcen(
|
|||||||
|
|
||||||
def scale_log(
|
def scale_log(
|
||||||
spec: xr.DataArray,
|
spec: xr.DataArray,
|
||||||
dtype: DTypeLike = np.float32,
|
dtype: DTypeLike = np.float32, # type: ignore
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
samplerate = spec.attrs["original_samplerate"]
|
samplerate = spec.attrs["original_samplerate"]
|
||||||
nfft = spec.attrs["nfft"]
|
nfft = spec.attrs["nfft"]
|
||||||
|
49
batdetect2/targets/__init__.py
Normal file
49
batdetect2/targets/__init__.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
"""Module that helps define what the training targets are.
|
||||||
|
|
||||||
|
The goal of this module is to configure how raw sound event annotations (tags)
|
||||||
|
are processed to determine which events are relevant for training and what
|
||||||
|
specific class label each relevant event should receive. Also, define how
|
||||||
|
predicted class labels map back to the original tag system for interpretation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from batdetect2.targets.labels import (
|
||||||
|
HeatmapsConfig,
|
||||||
|
LabelConfig,
|
||||||
|
generate_heatmaps,
|
||||||
|
load_label_config,
|
||||||
|
)
|
||||||
|
from batdetect2.targets.targets import (
|
||||||
|
TargetConfig,
|
||||||
|
build_decoder,
|
||||||
|
build_sound_event_filter,
|
||||||
|
build_target_encoder,
|
||||||
|
filter_sound_event,
|
||||||
|
get_class_names,
|
||||||
|
load_target_config,
|
||||||
|
)
|
||||||
|
from batdetect2.targets.terms import (
|
||||||
|
TagInfo,
|
||||||
|
TermInfo,
|
||||||
|
call_type,
|
||||||
|
get_tag_from_info,
|
||||||
|
individual,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"HeatmapsConfig",
|
||||||
|
"LabelConfig",
|
||||||
|
"TagInfo",
|
||||||
|
"TargetConfig",
|
||||||
|
"TermInfo",
|
||||||
|
"build_decoder",
|
||||||
|
"build_sound_event_filter",
|
||||||
|
"build_target_encoder",
|
||||||
|
"call_type",
|
||||||
|
"filter_sound_event",
|
||||||
|
"generate_heatmaps",
|
||||||
|
"get_class_names",
|
||||||
|
"get_tag_from_info",
|
||||||
|
"individual",
|
||||||
|
"load_label_config",
|
||||||
|
"load_target_config",
|
||||||
|
]
|
@ -7,7 +7,7 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.terms import TagInfo, get_tag_from_info
|
from batdetect2.targets.terms import TagInfo, get_tag_from_info
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TargetConfig",
|
"TargetConfig",
|
370
batdetect2/targets/terms.py
Normal file
370
batdetect2/targets/terms.py
Normal file
@ -0,0 +1,370 @@
|
|||||||
|
"""Manages the vocabulary (Terms and Tags) for defining training targets.
|
||||||
|
|
||||||
|
This module provides the necessary tools to declare, register, and manage the
|
||||||
|
set of `soundevent.data.Term` objects used throughout the `batdetect2.targets`
|
||||||
|
sub-package. It establishes a consistent vocabulary for filtering,
|
||||||
|
transforming, and classifying sound events based on their annotations (Tags).
|
||||||
|
|
||||||
|
The core component is the `TermRegistry`, which maps unique string keys
|
||||||
|
(aliases) to specific `Term` definitions. This allows users to refer to complex
|
||||||
|
terms using simple, consistent keys in configuration files and code.
|
||||||
|
|
||||||
|
Terms can be pre-defined, loaded from the `soundevent.terms` library, defined
|
||||||
|
programmatically, or loaded from external configuration files (e.g., YAML).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from inspect import getmembers
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from soundevent import data, terms
|
||||||
|
|
||||||
|
from batdetect2.configs import load_config
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"call_type",
|
||||||
|
"individual",
|
||||||
|
"get_tag_from_info",
|
||||||
|
"TermInfo",
|
||||||
|
"TagInfo",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Registry key under which the `generic_class` term is registered below.
# `TagInfo.key` falls back to this value, so a tag declared without an
# explicit key refers to the generic class term.
GENERIC_CLASS_KEY = "class"


call_type = data.Term(
    name="soundevent:call_type",
    label="Call Type",
    definition=(
        "A broad categorization of animal vocalizations based on their "
        "intended function or purpose (e.g., social, distress, mating, "
        "territorial, echolocation)."
    ),
)
"""Term describing the functional category of a vocalization."""

individual = data.Term(
    name="soundevent:individual",
    label="Individual",
    definition=(
        "An id for an individual animal. In the context of bioacoustic "
        "annotation, this term is used to label vocalizations that are "
        "attributed to a specific individual."
    ),
)
"""Term for tags that attribute a vocalization to one individual animal."""

generic_class = data.Term(
    name="soundevent:class",
    label="Class",
    definition=(
        "A generic term representing the name of a class within a "
        "classification model. Its specific meaning is determined by "
        "the model's application."
    ),
)
"""Term standing for the class label produced by a classification model."""
|
||||||
|
|
||||||
|
|
||||||
|
class TermRegistry:
    """Registry that maps unique string keys to Term definitions.

    The registry is the single place where the vocabulary of terms used by
    the target definitions is stored. Terms are registered under short
    string keys and are later retrieved through those same keys.
    """

    def __init__(self, terms: Optional[Dict[str, data.Term]] = None):
        """Create a registry, optionally pre-populated with terms.

        Args:
            terms: Optional mapping of key to Term used as the initial
                content. The mapping is copied, so registering further
                terms never modifies the dictionary passed by the caller.
                Defaults to an empty registry.
        """
        # A shallow copy keeps the registry's state independent from the
        # caller's dictionary (the Term objects themselves are shared).
        self._terms = dict(terms) if terms else {}

    def add_term(self, key: str, term: data.Term) -> None:
        """Register an existing Term object under the given key.

        Args:
            key: The unique string key to associate with the term.
            term: The soundevent.data.Term object to register.

        Raises:
            KeyError: If a term is already registered under `key`. The
                message names the conflicting key.
        """
        if key in self._terms:
            # Name the offending key: when many terms are loaded at once a
            # generic message gives no hint about which one collided.
            raise KeyError(
                f"A term with the key '{key}' already exists in the registry."
            )

        self._terms[key] = term

    def get_term(self, key: str) -> data.Term:
        """Return the term registered under the given key.

        Args:
            key: The unique string key of the term to retrieve.

        Returns:
            The soundevent.data.Term object registered under `key`.

        Raises:
            KeyError: If no term is registered under `key`. The message
                names the missing key and points to the key listing helper.
        """
        try:
            return self._terms[key]
        except KeyError as err:
            raise KeyError(
                "No term found for key "
                f"'{key}'. Ensure it is registered or loaded. "
                "Use `get_term_keys()` to list available terms."
            ) from err

    def add_custom_term(
        self,
        key: str,
        name: Optional[str] = None,
        uri: Optional[str] = None,
        label: Optional[str] = None,
        definition: Optional[str] = None,
    ) -> data.Term:
        """Build a new Term from plain attributes and register it.

        Useful when terms are declared directly in code or read from a
        configuration file that only provides attribute values.

        Missing (or empty) attributes fall back to defaults: `key` is used
        for `name` and `label`, and "Unknown" is used for `definition`.

        Args:
            key: The unique string key for the new term.
            name: The name for the new term (defaults to `key`).
            uri: The URI for the new term (optional).
            label: The display label for the new term (defaults to `key`).
            definition: The definition for the new term (defaults to
                "Unknown").

        Returns:
            The newly created and registered soundevent.data.Term object.

        Raises:
            KeyError: If a term is already registered under `key`.
        """
        term = data.Term(
            name=name or key,
            label=label or key,
            uri=uri,
            definition=definition or "Unknown",
        )
        self.add_term(key, term)
        return term

    def get_keys(self) -> List[str]:
        """Return the keys of all registered terms.

        Returns:
            A new list with the registered keys, in registration order.
        """
        return list(self._terms)
|
||||||
|
|
||||||
|
|
||||||
|
registry = TermRegistry(
    terms={
        # Every `Term` instance exposed as a member of `soundevent.terms`,
        # keyed by its member name.
        **dict(
            getmembers(terms, lambda member: isinstance(member, data.Term))
        ),
        # Terms declared in this module. They are listed last, so they take
        # precedence if `soundevent.terms` ever exposes the same key.
        "call_type": call_type,
        "individual": individual,
        GENERIC_CLASS_KEY: generic_class,
    }
)
"""Default, module-level TermRegistry instance.

It starts out holding the terms found in `soundevent.terms` together with
the terms declared in this module (`call_type`, `individual` and
`generic_class`). The helper functions of this module fall back to this
registry whenever no other registry is passed explicitly.
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_term_from_key(
    key: str,
    term_registry: TermRegistry = registry,
) -> data.Term:
    """Look up a term by its key in a registry.

    Thin wrapper around `TermRegistry.get_term` that falls back to the
    module-level `registry` when no registry is given.

    Args:
        key: The unique key of the term to retrieve.
        term_registry: The TermRegistry instance to search. Defaults to
            the global `registry`.

    Returns:
        The soundevent.data.Term object registered under `key`.

    Raises:
        KeyError: If `key` is not registered in the given registry.
    """
    term = term_registry.get_term(key)
    return term
|
||||||
|
|
||||||
|
|
||||||
|
def get_term_keys(term_registry: TermRegistry = registry) -> List[str]:
    """List the keys registered in a registry.

    Thin wrapper around `TermRegistry.get_keys` that falls back to the
    module-level `registry` when no registry is given.

    Args:
        term_registry: The TermRegistry instance to query. Defaults to
            the global `registry`.

    Returns:
        A list with the keys of all terms registered in the registry.
    """
    keys = term_registry.get_keys()
    return keys
|
||||||
|
|
||||||
|
|
||||||
|
class TagInfo(BaseModel):
    """Declarative description of a single Tag.

    Intended for use in configuration (e.g., YAML) wherever a tag has to be
    named: filtering rules, target class definitions, or the tags attached
    to an output class. The tag's term is not embedded; it is referenced
    through the key under which the term is registered in a TermRegistry.

    Attributes:
        value: The value of the tag (e.g., "Myotis myotis", "Echolocation").
        key: The registry key (alias) of the term this tag belongs to.
            Defaults to `GENERIC_CLASS_KEY` ("class"), i.e. the tag is
            treated as a generic classification label unless stated
            otherwise.
    """

    value: str
    key: str = GENERIC_CLASS_KEY
|
||||||
|
|
||||||
|
|
||||||
|
def get_tag_from_info(
    tag_info: TagInfo,
    term_registry: TermRegistry = registry,
) -> data.Tag:
    """Build a soundevent.data.Tag from a TagInfo description.

    The term is resolved by looking up `tag_info.key` in the given registry
    and is then combined with `tag_info.value` into a Tag.

    Args:
        tag_info: The TagInfo object holding the tag value and term key.
        term_registry: The TermRegistry instance used to resolve the term.
            Defaults to the global `registry`.

    Returns:
        A soundevent.data.Tag with the resolved term and the given value.

    Raises:
        KeyError: If `tag_info.key` is not registered in the registry.
    """
    resolved_term = term_registry.get_term(tag_info.key)
    return data.Tag(term=resolved_term, value=tag_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TermInfo(BaseModel):
    """Declarative description of a Term, as written in a config file.

    Lets users declare custom terms in configuration (e.g., YAML) so they
    can be added to a TermRegistry. The fields correspond one-to-one to the
    parameters of `TermRegistry.add_custom_term`, which is also where the
    fallback values for omitted fields are applied.

    Attributes:
        key: The unique key (alias) under which the term is registered and
            later referenced.
        label: Optional display label. `key` is used when omitted.
        name: Optional formal name. `key` is used when omitted.
        uri: Optional URI identifying the term (e.g., in a standard
            vocabulary).
        definition: Optional textual definition. "Unknown" is used when
            omitted.
    """

    key: str
    label: Optional[str] = None
    name: Optional[str] = None
    uri: Optional[str] = None
    definition: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TermConfig(BaseModel):
    """Pydantic schema for a list of term definitions in a config file.

    Describes the section of a configuration file (e.g., YAML) that lists
    the custom terms to register.

    Example YAML structure:
    ```yaml
    terms:
      - key: species
        uri: dwc:scientificName
        label: Scientific Name
      - key: my_custom_term
        name: My Custom Term
        definition: Describes a specific project attribute.
      # ... more TermInfo definitions
    ```

    Attributes:
        terms: The TermInfo entries to register, one per term. Defaults to
            an empty list.
    """

    terms: List[TermInfo] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def load_terms_from_config(
    path: data.PathLike,
    field: Optional[str] = None,
    term_registry: TermRegistry = registry,
) -> Dict[str, data.Term]:
    """Read term definitions from a config file and register them.

    The file is loaded through `load_config` with the TermConfig schema.
    Each TermInfo entry found is then registered in the given registry via
    `TermRegistry.add_custom_term`, in the order in which it is listed.

    Args:
        path: The path to the configuration file.
        field: Optional name of the section of the file that holds the
            term definitions; it is forwarded unchanged to `load_config`.
            When None, the file content itself is expected to match the
            TermConfig schema.
        term_registry: The TermRegistry instance that receives the loaded
            terms. Defaults to the global `registry`.

    Returns:
        A dictionary mapping the key of each newly registered term to the
        corresponding Term object.

    Raises:
        KeyError: If a key from the config is already present in the
            registry. Terms listed before the conflicting entry remain
            registered.

    Any error raised by `load_config` while reading or validating the file
    propagates unchanged.
    """
    term_config = load_config(path, schema=TermConfig, field=field)

    loaded_terms = {}
    for info in term_config.terms:
        loaded_terms[info.key] = term_registry.add_custom_term(
            info.key,
            name=info.name,
            uri=info.uri,
            label=info.label,
            definition=info.definition,
        )
    return loaded_terms
|
@ -1,88 +0,0 @@
|
|||||||
from inspect import getmembers
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from soundevent import data, terms
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"call_type",
|
|
||||||
"individual",
|
|
||||||
"get_term_from_info",
|
|
||||||
"get_tag_from_info",
|
|
||||||
"TermInfo",
|
|
||||||
"TagInfo",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class TermInfo(BaseModel):
|
|
||||||
label: Optional[str]
|
|
||||||
name: Optional[str]
|
|
||||||
uri: Optional[str]
|
|
||||||
|
|
||||||
|
|
||||||
class TagInfo(BaseModel):
|
|
||||||
value: str
|
|
||||||
term: Optional[TermInfo] = None
|
|
||||||
key: Optional[str] = None
|
|
||||||
label: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
call_type = data.Term(
|
|
||||||
name="soundevent:call_type",
|
|
||||||
label="Call Type",
|
|
||||||
definition="A broad categorization of animal vocalizations based on their intended function or purpose (e.g., social, distress, mating, territorial, echolocation).",
|
|
||||||
)
|
|
||||||
|
|
||||||
individual = data.Term(
|
|
||||||
name="soundevent:individual",
|
|
||||||
label="Individual",
|
|
||||||
definition="An id for an individual animal. In the context of bioacoustic annotation, this term is used to label vocalizations that are attributed to a specific individual.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
ALL_TERMS = [
|
|
||||||
*getmembers(terms, lambda x: isinstance(x, data.Term)),
|
|
||||||
call_type,
|
|
||||||
individual,
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def get_term_from_info(term_info: TermInfo) -> data.Term:
|
|
||||||
for term in ALL_TERMS:
|
|
||||||
if term_info.name and term_info.name == term.name:
|
|
||||||
return term
|
|
||||||
|
|
||||||
if term_info.label and term_info.label == term.label:
|
|
||||||
return term
|
|
||||||
|
|
||||||
if term_info.uri and term_info.uri == term.uri:
|
|
||||||
return term
|
|
||||||
|
|
||||||
if term_info.name is None:
|
|
||||||
if term_info.label is None:
|
|
||||||
raise ValueError("At least one of name or label must be provided.")
|
|
||||||
|
|
||||||
term_info.name = (
|
|
||||||
f"soundevent:{term_info.label.lower().replace(' ', '_')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if term_info.label is None:
|
|
||||||
term_info.label = term_info.name
|
|
||||||
|
|
||||||
return data.Term(
|
|
||||||
name=term_info.name,
|
|
||||||
label=term_info.label,
|
|
||||||
uri=term_info.uri,
|
|
||||||
definition="Unknown",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_tag_from_info(tag_info: TagInfo) -> data.Tag:
|
|
||||||
if tag_info.term:
|
|
||||||
term = get_term_from_info(tag_info.term)
|
|
||||||
elif tag_info.key:
|
|
||||||
term = data.term_from_key(tag_info.key)
|
|
||||||
else:
|
|
||||||
raise ValueError("Either term or key must be provided in tag info.")
|
|
||||||
|
|
||||||
return data.Tag(term=term, value=tag_info.value)
|
|
@ -17,37 +17,26 @@ from batdetect2.train.dataset import (
|
|||||||
TrainExample,
|
TrainExample,
|
||||||
get_preprocessed_files,
|
get_preprocessed_files,
|
||||||
)
|
)
|
||||||
from batdetect2.train.labels import LabelConfig, load_label_config
|
from batdetect2.train.losses import compute_loss
|
||||||
from batdetect2.train.preprocess import (
|
from batdetect2.train.preprocess import (
|
||||||
generate_train_example,
|
generate_train_example,
|
||||||
preprocess_annotations,
|
preprocess_annotations,
|
||||||
)
|
)
|
||||||
from batdetect2.train.targets import (
|
|
||||||
TagInfo,
|
|
||||||
TargetConfig,
|
|
||||||
build_target_encoder,
|
|
||||||
load_target_config,
|
|
||||||
)
|
|
||||||
from batdetect2.train.train import TrainerConfig, load_trainer_config, train
|
from batdetect2.train.train import TrainerConfig, load_trainer_config, train
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AugmentationsConfig",
|
"AugmentationsConfig",
|
||||||
"LabelConfig",
|
|
||||||
"LabeledDataset",
|
"LabeledDataset",
|
||||||
"SubclipConfig",
|
"SubclipConfig",
|
||||||
"TagInfo",
|
|
||||||
"TargetConfig",
|
|
||||||
"TrainExample",
|
"TrainExample",
|
||||||
"TrainerConfig",
|
"TrainerConfig",
|
||||||
"TrainingConfig",
|
"TrainingConfig",
|
||||||
"add_echo",
|
"add_echo",
|
||||||
"augment_example",
|
"augment_example",
|
||||||
"build_target_encoder",
|
"compute_loss",
|
||||||
"generate_train_example",
|
"generate_train_example",
|
||||||
"get_preprocessed_files",
|
"get_preprocessed_files",
|
||||||
"load_agumentation_config",
|
"load_agumentation_config",
|
||||||
"load_label_config",
|
|
||||||
"load_target_config",
|
|
||||||
"load_train_config",
|
"load_train_config",
|
||||||
"load_trainer_config",
|
"load_trainer_config",
|
||||||
"mask_frequency",
|
"mask_frequency",
|
||||||
|
55
batdetect2/train/callbacks.py
Normal file
55
batdetect2/train/callbacks.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
from lightning import LightningModule, Trainer
|
||||||
|
from lightning.pytorch.callbacks import Callback
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from batdetect2.evaluate import match_predictions_and_annotations
|
||||||
|
from batdetect2.post_process import postprocess_model_outputs
|
||||||
|
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||||
|
from batdetect2.types import ModelOutput
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationMetrics(Callback):
    """Lightning callback that collects prediction/annotation matches.

    During validation, the model outputs of every batch are post-processed
    into clip predictions and matched against the annotations of the
    corresponding clip. The matches are accumulated in `self.predictions`,
    which is emptied at the start of every validation epoch.
    """

    def __init__(self):
        """Create the callback with an empty list of matches."""
        super().__init__()
        self.predictions = []

    def on_validation_epoch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
    ) -> None:
        """Discard the matches collected during the previous epoch."""
        self.predictions = []
        return super().on_validation_epoch_start(trainer, pl_module)

    def on_validation_batch_end(  # type: ignore
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: ModelOutput,
        batch: TrainExample,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Match the predictions of one validation batch to its annotations.

        Args:
            trainer: The trainer running validation; its validation
                dataloader gives access to the underlying dataset.
            pl_module: The module being validated.
            outputs: The model outputs for the batch.
            batch: The validation batch.
            batch_idx: Index of the batch within the validation dataloader.
            dataloader_idx: Index of the validation dataloader.
        """
        dataloaders = trainer.val_dataloaders
        assert isinstance(dataloaders, DataLoader)
        dataset = dataloaders.dataset
        assert isinstance(dataset, LabeledDataset)

        # NOTE(review): the batch index is used as a dataset index, which
        # presumably only holds for an unshuffled validation loader with a
        # batch size of one -- confirm.
        clip_annotation = dataset.get_clip_annotation(batch_idx)

        # NOTE(review): `class_names`, `decoder` and `config` are never
        # assigned on this callback, so these attribute lookups fail unless
        # they are set elsewhere. They presumably have to be passed in at
        # construction time or read from `pl_module` -- confirm and wire up.
        clip_prediction = postprocess_model_outputs(
            outputs,
            clips=[clip_annotation.clip],
            classes=self.class_names,
            decoder=self.decoder,
            config=self.config.postprocessing,
        )[0]

        matches = match_predictions_and_annotations(
            clip_annotation,
            clip_prediction,
        )

        # Accumulate into the list created in `__init__` and reset at the
        # start of each epoch.
        self.predictions.extend(matches)
        return super().on_validation_batch_end(
            trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
        )
|
@ -6,7 +6,7 @@ from torch.optim import Adam
|
|||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.models.typing import DetectionModel
|
from batdetect2.models.types import DetectionModel
|
||||||
from batdetect2.train.dataset import LabeledDataset
|
from batdetect2.train.dataset import LabeledDataset
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.models.typing import ModelOutput
|
from batdetect2.models.types import ModelOutput
|
||||||
from batdetect2.train.dataset import TrainExample
|
from batdetect2.train.dataset import TrainExample
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -17,11 +17,12 @@ from batdetect2.preprocess import (
|
|||||||
compute_spectrogram,
|
compute_spectrogram,
|
||||||
load_clip_audio,
|
load_clip_audio,
|
||||||
)
|
)
|
||||||
from batdetect2.train.labels import LabelConfig, generate_heatmaps
|
from batdetect2.targets import (
|
||||||
from batdetect2.train.targets import (
|
LabelConfig,
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
build_target_encoder,
|
|
||||||
build_sound_event_filter,
|
build_sound_event_filter,
|
||||||
|
build_target_encoder,
|
||||||
|
generate_heatmaps,
|
||||||
get_class_names,
|
get_class_names,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -91,7 +91,15 @@ convention = "numpy"
|
|||||||
|
|
||||||
[tool.pyright]
|
[tool.pyright]
|
||||||
include = ["batdetect2", "tests"]
|
include = ["batdetect2", "tests"]
|
||||||
venvPath = "."
|
|
||||||
venv = ".venv"
|
|
||||||
pythonVersion = "3.9"
|
pythonVersion = "3.9"
|
||||||
pythonPlatform = "All"
|
pythonPlatform = "All"
|
||||||
|
exclude = [
|
||||||
|
"batdetect2/detector/",
|
||||||
|
"batdetect2/finetune",
|
||||||
|
"batdetect2/utils",
|
||||||
|
"batdetect2/plotting",
|
||||||
|
"batdetect2/plot",
|
||||||
|
"batdetect2/api",
|
||||||
|
"batdetect2/evaluate/legacy",
|
||||||
|
"batdetect2/train/legacy",
|
||||||
|
]
|
||||||
|
@ -1,11 +1,18 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, List, Optional
|
from typing import Callable, Iterable, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
from soundevent import data
|
from soundevent import data, terms
|
||||||
|
|
||||||
|
from batdetect2.targets import (
|
||||||
|
TargetConfig,
|
||||||
|
build_target_encoder,
|
||||||
|
call_type,
|
||||||
|
get_class_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -107,3 +114,102 @@ def recording_factory(wav_factory: Callable[..., Path]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return _recording_factory
|
return _recording_factory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def recording(
    recording_factory: Callable[..., data.Recording],
) -> data.Recording:
    """Provide one recording created with the factory's default arguments."""
    default_recording = recording_factory()
    return default_recording
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def clip(recording: data.Recording) -> data.Clip:
    """Provide a clip of the test recording spanning times 0 to 0.5."""
    return data.Clip(
        recording=recording,
        start_time=0,
        end_time=0.5,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def echolocation_call(recording: data.Recording) -> data.SoundEventAnnotation:
    """Provide an annotation tagged as a Myotis myotis echolocation call."""
    sound_event = data.SoundEvent(
        geometry=data.BoundingBox(coordinates=[0.1, 67_000, 0.11, 73_000]),
        recording=recording,
    )
    species_tag = data.Tag(term=terms.scientific_name, value="Myotis myotis")
    call_type_tag = data.Tag(term=call_type, value="Echolocation")
    return data.SoundEventAnnotation(
        sound_event=sound_event,
        tags=[species_tag, call_type_tag],
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def generic_call(recording: data.Recording) -> data.SoundEventAnnotation:
    """Provide an echolocation call tagged only at order level (Chiroptera)."""
    sound_event = data.SoundEvent(
        geometry=data.BoundingBox(coordinates=[0.34, 35_000, 0.348, 62_000]),
        recording=recording,
    )
    order_tag = data.Tag(term=terms.order, value="Chiroptera")
    call_type_tag = data.Tag(term=call_type, value="Echolocation")
    return data.SoundEventAnnotation(
        sound_event=sound_event,
        tags=[order_tag, call_type_tag],
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def non_relevant_sound_event(
    recording: data.Recording,
) -> data.SoundEventAnnotation:
    """Provide an annotation tagged with the species Muscardinus avellanarius.

    Its only tag is the scientific name; it carries no call type tag.
    """
    sound_event = data.SoundEvent(
        geometry=data.BoundingBox(coordinates=[0.22, 50_000, 0.24, 58_000]),
        recording=recording,
    )
    species_tag = data.Tag(
        term=terms.scientific_name,
        value="Muscardinus avellanarius",
    )
    return data.SoundEventAnnotation(
        sound_event=sound_event,
        tags=[species_tag],
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def clip_annotation(
    clip: data.Clip,
    echolocation_call: data.SoundEventAnnotation,
    generic_call: data.SoundEventAnnotation,
    non_relevant_sound_event: data.SoundEventAnnotation,
) -> data.ClipAnnotation:
    """Provide a clip annotation holding the three sound-event fixtures.

    The sound events are listed in the order: echolocation call, generic
    call, non-relevant sound event.
    """
    annotated_events = [
        echolocation_call,
        generic_call,
        non_relevant_sound_event,
    ]
    return data.ClipAnnotation(clip=clip, sound_events=annotated_events)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def target_config() -> TargetConfig:
    """Provide a ``TargetConfig`` built entirely from its defaults."""
    default_config = TargetConfig()
    return default_config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def class_names(target_config: TargetConfig) -> List[str]:
    """Provide the class names derived from ``target_config.classes``."""
    configured_classes = target_config.classes
    return get_class_names(configured_classes)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def encoder(
    target_config: TargetConfig,
) -> Callable[[Iterable[data.Tag]], Optional[str]]:
    """Provide a target encoder built from ``target_config``.

    The encoder is created from the config's ``classes`` and uses its
    ``replace`` entries as replacement rules.
    """
    encoder_options = {
        "classes": target_config.classes,
        "replacement_rules": target_config.replace,
    }
    return build_target_encoder(**encoder_options)
|
||||||
|
@ -4,7 +4,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.compat.data import load_annotation_project_from_dir
|
from batdetect2.data import BatDetect2FilesAnnotations, load_annotated_dataset
|
||||||
|
|
||||||
ROOT_DIR = Path(__file__).parent.parent.parent
|
ROOT_DIR = Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
@ -12,8 +12,14 @@ ROOT_DIR = Path(__file__).parent.parent.parent
|
|||||||
def test_load_example_annotation_project():
|
def test_load_example_annotation_project():
|
||||||
path = ROOT_DIR / "example_data" / "anns"
|
path = ROOT_DIR / "example_data" / "anns"
|
||||||
audio_dir = ROOT_DIR / "example_data" / "audio"
|
audio_dir = ROOT_DIR / "example_data" / "audio"
|
||||||
project = load_annotation_project_from_dir(path, audio_dir=audio_dir)
|
project = load_annotated_dataset(
|
||||||
|
BatDetect2FilesAnnotations(
|
||||||
|
name="test",
|
||||||
|
audio_dir=audio_dir,
|
||||||
|
annotations_dir=path,
|
||||||
|
)
|
||||||
|
)
|
||||||
assert isinstance(project, data.AnnotationProject)
|
assert isinstance(project, data.AnnotationProject)
|
||||||
assert project.name == str(path)
|
assert project.name == "test"
|
||||||
assert len(project.clip_annotations) == 3
|
assert len(project.clip_annotations) == 3
|
||||||
assert len(project.tasks) == 3
|
assert len(project.tasks) == 3
|
||||||
|
@ -5,8 +5,8 @@ from typing import List
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from batdetect2.compat.data import load_annotation_project_from_dir
|
|
||||||
from batdetect2.compat.params import get_training_preprocessing_config
|
from batdetect2.compat.params import get_training_preprocessing_config
|
||||||
|
from batdetect2.data import BatDetect2FilesAnnotations, load_annotated_dataset
|
||||||
from batdetect2.train.preprocess import generate_train_example
|
from batdetect2.train.preprocess import generate_train_example
|
||||||
|
|
||||||
|
|
||||||
@ -26,6 +26,8 @@ def test_can_generate_similar_training_inputs(
|
|||||||
old_parameters = json.loads((regression_dir / "params.json").read_text())
|
old_parameters = json.loads((regression_dir / "params.json").read_text())
|
||||||
config = get_training_preprocessing_config(old_parameters)
|
config = get_training_preprocessing_config(old_parameters)
|
||||||
|
|
||||||
|
assert config is not None
|
||||||
|
|
||||||
for audio_file in example_audio_files:
|
for audio_file in example_audio_files:
|
||||||
example_file = regression_dir / f"{audio_file.name}.npz"
|
example_file = regression_dir / f"{audio_file.name}.npz"
|
||||||
|
|
||||||
@ -36,9 +38,12 @@ def test_can_generate_similar_training_inputs(
|
|||||||
size_mask = dataset["size_mask"]
|
size_mask = dataset["size_mask"]
|
||||||
class_mask = dataset["class_mask"]
|
class_mask = dataset["class_mask"]
|
||||||
|
|
||||||
project = load_annotation_project_from_dir(
|
project = load_annotated_dataset(
|
||||||
example_anns_dir,
|
BatDetect2FilesAnnotations(
|
||||||
audio_dir=example_audio_dir,
|
name="test",
|
||||||
|
annotations_dir=example_anns_dir,
|
||||||
|
audio_dir=example_audio_dir,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
clip_annotation = next(
|
clip_annotation = next(
|
||||||
@ -47,7 +52,12 @@ def test_can_generate_similar_training_inputs(
|
|||||||
if ann.clip.recording.path == audio_file
|
if ann.clip.recording.path == audio_file
|
||||||
)
|
)
|
||||||
|
|
||||||
new_dataset = generate_train_example(clip_annotation, config)
|
new_dataset = generate_train_example(
|
||||||
|
clip_annotation,
|
||||||
|
preprocessing_config=config.preprocessing,
|
||||||
|
target_config=config.target,
|
||||||
|
label_config=config.labels,
|
||||||
|
)
|
||||||
new_spec = new_dataset["spectrogram"].values
|
new_spec = new_dataset["spectrogram"].values
|
||||||
new_detection_mask = new_dataset["detection"].values
|
new_detection_mask = new_dataset["detection"].values
|
||||||
new_size_mask = new_dataset["size"].values
|
new_size_mask = new_dataset["size"].values
|
||||||
|
0
tests/test_postprocessing/__init__.py
Normal file
0
tests/test_postprocessing/__init__.py
Normal file
43
tests/test_postprocessing/test_arrays.py
Normal file
43
tests/test_postprocessing/test_arrays.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import xarray as xr
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.modules import DetectorModel
|
||||||
|
from batdetect2.postprocess.arrays import to_xarray
|
||||||
|
from batdetect2.preprocess import preprocess_audio_clip
|
||||||
|
|
||||||
|
|
||||||
|
def test_this(clip: data.Clip, class_names: List[str]):
    """Smoke-test converting detector outputs to xarray for a single clip.

    The clip is preprocessed with the model's own preprocessing config, run
    through a freshly constructed ``DetectorModel``, and the outputs are
    passed to ``to_xarray`` together with the clip bounds and class names.

    Earlier revisions built a throw-away random spectrogram that was
    overwritten before use, printed the result and then failed
    unconditionally with ``assert False``; those debugging leftovers are
    removed so the test can actually pass.
    """
    model = DetectorModel()

    spec = preprocess_audio_clip(
        clip,
        config=model.config.preprocessing,
    )

    # Prepend two singleton axes before calling the model
    # (presumably batch and channel -- TODO confirm against DetectorModel).
    tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)

    outputs = model(tensor)

    arrays = to_xarray(
        outputs,
        start_time=clip.start_time,
        end_time=clip.end_time,
        class_names=class_names,
    )

    # NOTE(review): the structure returned by ``to_xarray`` is not visible
    # from this test, so only the weakest useful check is made here; tighten
    # it (dims, coords, class axis) once the return contract is confirmed.
    assert arrays is not None
|
0
tests/test_targets/__init__.py
Normal file
0
tests/test_targets/__init__.py
Normal file
175
tests/test_targets/test_terms.py
Normal file
175
tests/test_targets/test_terms.py
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.targets import terms
|
||||||
|
from batdetect2.targets.terms import (
|
||||||
|
TagInfo,
|
||||||
|
TermRegistry,
|
||||||
|
load_terms_from_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_term_registry_initialization():
    """A registry is empty by default and holds the mapping passed as ``terms``."""
    empty_registry = TermRegistry()
    assert empty_registry._terms == {}

    seed_term = data.Term(name="test", label="Test", definition="test")
    seeded_terms = {"test_term": seed_term}
    populated_registry = TermRegistry(terms=seeded_terms)
    assert populated_registry._terms == seeded_terms
|
||||||
|
|
||||||
|
|
||||||
|
def test_term_registry_add_term():
    """``add_term`` stores the term under the given key."""
    new_term = data.Term(name="test", label="Test", definition="test")
    reg = TermRegistry()
    reg.add_term("test_key", new_term)
    stored = reg._terms["test_key"]
    assert stored == new_term
|
||||||
|
|
||||||
|
|
||||||
|
def test_term_registry_get_term():
    """``get_term`` returns the term previously added under the same key."""
    expected = data.Term(name="test", label="Test", definition="test")
    reg = TermRegistry()
    reg.add_term("test_key", expected)
    assert reg.get_term("test_key") == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_term_registry_add_custom_term():
    """``add_custom_term`` registers and returns a term with the given fields."""
    reg = TermRegistry()
    created = reg.add_custom_term(
        "custom_key",
        name="custom",
        label="Custom",
        definition="A custom term",
    )
    assert reg._terms["custom_key"] == created
    assert created.name == "custom"
    assert created.label == "Custom"
    assert created.definition == "A custom term"
|
||||||
|
|
||||||
|
|
||||||
|
def test_term_registry_add_duplicate_term():
    """Adding a second term under an existing key raises ``KeyError``."""
    duplicate = data.Term(name="test", label="Test", definition="test")
    reg = TermRegistry()
    reg.add_term("test_key", duplicate)

    # The second registration under the same key must be rejected.
    with pytest.raises(KeyError):
        reg.add_term("test_key", duplicate)
|
||||||
|
|
||||||
|
|
||||||
|
def test_term_registry_get_term_not_found():
    """Looking up a key that was never added raises ``KeyError``."""
    empty_registry = TermRegistry()
    missing_key = "non_existent_key"
    with pytest.raises(KeyError):
        empty_registry.get_term(missing_key)
|
||||||
|
|
||||||
|
|
||||||
|
def test_term_registry_get_keys():
    """``get_keys`` reports every key that has been registered."""
    reg = TermRegistry()
    first = data.Term(name="test1", label="Test1", definition="test")
    second = data.Term(name="test2", label="Test2", definition="test")
    for key, term in (("key1", first), ("key2", second)):
        reg.add_term(key, term)

    # Compare as a set: only membership is checked, not ordering.
    assert set(reg.get_keys()) == {"key1", "key2"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_term_from_key():
    """``get_term_from_key`` resolves keys in the default or a given registry."""
    # Without an explicit registry the built-in "call_type" key is resolved.
    default_lookup = terms.get_term_from_key("call_type")
    assert default_lookup == terms.call_type

    # With an explicit registry the lookup is served from that registry.
    custom_term = data.Term(name="custom", label="Custom", definition="test")
    custom_registry = TermRegistry()
    custom_registry.add_term("custom_key", custom_term)
    custom_lookup = terms.get_term_from_key(
        "custom_key",
        term_registry=custom_registry,
    )
    assert custom_lookup == custom_term
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_term_keys():
    """``get_term_keys`` lists keys of the default or a given registry."""
    default_keys = terms.get_term_keys()
    for expected in ("call_type", "individual", terms.GENERIC_CLASS_KEY):
        assert expected in default_keys

    custom_term = data.Term(name="custom", label="Custom", definition="test")
    custom_registry = TermRegistry()
    custom_registry.add_term("custom_key", custom_term)
    custom_keys = terms.get_term_keys(term_registry=custom_registry)
    assert "custom_key" in custom_keys
|
||||||
|
|
||||||
|
|
||||||
|
def test_tag_info_and_get_tag_from_info():
    """A ``TagInfo`` is converted to a tag with the same value and the keyed term."""
    info = TagInfo(value="Myotis myotis", key="call_type")
    resolved = terms.get_tag_from_info(info)
    assert resolved.value == "Myotis myotis"
    assert resolved.term == terms.call_type
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_tag_from_info_key_not_found():
    """Converting a ``TagInfo`` whose key is unknown raises ``KeyError``."""
    unknown_info = TagInfo(value="test", key="non_existent_key")
    with pytest.raises(KeyError):
        terms.get_tag_from_info(unknown_info)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_terms_from_config(tmp_path):
    """Terms declared in a YAML config file are returned under their keys."""
    species_entry = {
        "key": "species",
        "name": "dwc:scientificName",
        "label": "Scientific Name",
    }
    custom_entry = {
        "key": "my_custom_term",
        "name": "soundevent:custom_term",
        "definition": "Describes a specific project attribute",
    }
    config_file = tmp_path / "config.yaml"
    with open(config_file, "w") as handle:
        yaml.dump({"terms": [species_entry, custom_entry]}, handle)

    loaded_terms = load_terms_from_config(config_file)

    expected_names = {
        "species": "dwc:scientificName",
        "my_custom_term": "soundevent:custom_term",
    }
    # First make sure both keys are present, then check each term's name.
    for key in expected_names:
        assert key in loaded_terms
    for key, expected_name in expected_names.items():
        assert loaded_terms[key].name == expected_name
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_terms_from_config_file_not_found():
    """Loading terms from a path that does not exist raises ``FileNotFoundError``."""
    missing_path = "non_existent_file.yaml"
    with pytest.raises(FileNotFoundError):
        load_terms_from_config(missing_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_terms_from_config_validation_error(tmp_path):
    """A config entry with a non-string label makes loading raise ``ValueError``."""
    bad_entry = {
        "key": "species",
        "uri": "dwc:scientificName",
        "label": 123,
    }  # Invalid label type
    config_file = tmp_path / "config.yaml"
    with open(config_file, "w") as handle:
        yaml.dump({"terms": [bad_entry]}, handle)

    with pytest.raises(ValueError):
        load_terms_from_config(config_file)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_terms_from_config_key_already_exists(tmp_path):
    """A config entry reusing the existing "call_type" key raises ``KeyError``."""
    clashing_entry = {
        "key": "call_type",
        "uri": "dwc:scientificName",
        "label": "Scientific Name",
    }  # Duplicate key
    config_file = tmp_path / "config.yaml"
    with open(config_file, "w") as handle:
        yaml.dump({"terms": [clashing_entry]}, handle)

    with pytest.raises(KeyError):
        load_terms_from_config(config_file)
|
@ -30,8 +30,18 @@ def test_mix_examples(
|
|||||||
|
|
||||||
config = TrainPreprocessingConfig()
|
config = TrainPreprocessingConfig()
|
||||||
|
|
||||||
example1 = generate_train_example(clip_annotation_1, config)
|
example1 = generate_train_example(
|
||||||
example2 = generate_train_example(clip_annotation_2, config)
|
clip_annotation_1,
|
||||||
|
preprocessing_config=config.preprocessing,
|
||||||
|
target_config=config.target,
|
||||||
|
label_config=config.labels,
|
||||||
|
)
|
||||||
|
example2 = generate_train_example(
|
||||||
|
clip_annotation_2,
|
||||||
|
preprocessing_config=config.preprocessing,
|
||||||
|
target_config=config.target,
|
||||||
|
label_config=config.labels,
|
||||||
|
)
|
||||||
|
|
||||||
mixed = mix_examples(example1, example2, config=config.preprocessing)
|
mixed = mix_examples(example1, example2, config=config.preprocessing)
|
||||||
|
|
||||||
@ -59,8 +69,18 @@ def test_mix_examples_of_different_durations(
|
|||||||
|
|
||||||
config = TrainPreprocessingConfig()
|
config = TrainPreprocessingConfig()
|
||||||
|
|
||||||
example1 = generate_train_example(clip_annotation_1, config)
|
example1 = generate_train_example(
|
||||||
example2 = generate_train_example(clip_annotation_2, config)
|
clip_annotation_1,
|
||||||
|
preprocessing_config=config.preprocessing,
|
||||||
|
target_config=config.target,
|
||||||
|
label_config=config.labels,
|
||||||
|
)
|
||||||
|
example2 = generate_train_example(
|
||||||
|
clip_annotation_2,
|
||||||
|
preprocessing_config=config.preprocessing,
|
||||||
|
target_config=config.target,
|
||||||
|
label_config=config.labels,
|
||||||
|
)
|
||||||
|
|
||||||
mixed = mix_examples(example1, example2, config=config.preprocessing)
|
mixed = mix_examples(example1, example2, config=config.preprocessing)
|
||||||
|
|
||||||
@ -78,7 +98,12 @@ def test_add_echo(
|
|||||||
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
||||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
||||||
config = TrainPreprocessingConfig()
|
config = TrainPreprocessingConfig()
|
||||||
original = generate_train_example(clip_annotation_1, config)
|
original = generate_train_example(
|
||||||
|
clip_annotation_1,
|
||||||
|
preprocessing_config=config.preprocessing,
|
||||||
|
target_config=config.target,
|
||||||
|
label_config=config.labels,
|
||||||
|
)
|
||||||
with_echo = add_echo(original, config=config.preprocessing)
|
with_echo = add_echo(original, config=config.preprocessing)
|
||||||
|
|
||||||
assert with_echo["spectrogram"].shape == original["spectrogram"].shape
|
assert with_echo["spectrogram"].shape == original["spectrogram"].shape
|
||||||
@ -94,7 +119,12 @@ def test_selected_random_subclip_has_the_correct_width(
|
|||||||
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
||||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
||||||
config = TrainPreprocessingConfig()
|
config = TrainPreprocessingConfig()
|
||||||
original = generate_train_example(clip_annotation_1, config)
|
original = generate_train_example(
|
||||||
|
clip_annotation_1,
|
||||||
|
preprocessing_config=config.preprocessing,
|
||||||
|
target_config=config.target,
|
||||||
|
label_config=config.labels,
|
||||||
|
)
|
||||||
subclip = select_subclip(original, width=100)
|
subclip = select_subclip(original, width=100)
|
||||||
|
|
||||||
assert subclip["spectrogram"].shape[1] == 100
|
assert subclip["spectrogram"].shape[1] == 100
|
||||||
@ -107,7 +137,12 @@ def test_add_echo_after_subclip(
|
|||||||
clip1 = data.Clip(recording=recording1, start_time=0, end_time=1)
|
clip1 = data.Clip(recording=recording1, start_time=0, end_time=1)
|
||||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
||||||
config = TrainPreprocessingConfig()
|
config = TrainPreprocessingConfig()
|
||||||
original = generate_train_example(clip_annotation_1, config)
|
original = generate_train_example(
|
||||||
|
clip_annotation_1,
|
||||||
|
preprocessing_config=config.preprocessing,
|
||||||
|
target_config=config.target,
|
||||||
|
label_config=config.labels,
|
||||||
|
)
|
||||||
|
|
||||||
assert original.sizes["time"] > 512
|
assert original.sizes["time"] > 512
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ import numpy as np
|
|||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.train.labels import generate_heatmaps
|
from batdetect2.targets import generate_heatmaps
|
||||||
|
|
||||||
recording = data.Recording(
|
recording = data.Recording(
|
||||||
samplerate=256_000,
|
samplerate=256_000,
|
||||||
|
Loading…
Reference in New Issue
Block a user