mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Compare commits
8 Commits
db2ad11743
...
cd4955d4f3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cd4955d4f3 | ||
|
|
c73984b213 | ||
|
|
d8d2e5a2c2 | ||
|
|
b056d7d28d | ||
|
|
95a884ea16 | ||
|
|
b7ae526071 | ||
|
|
cf6d0d1ccc | ||
|
|
709b6355c2 |
@ -133,7 +133,7 @@ When you need to specify a tag, you typically use a structure with two fields:
|
|||||||
**It defaults to `class`** if you omit it, which is common when defining the main target classes.
|
**It defaults to `class`** if you omit it, which is common when defining the main target classes.
|
||||||
- `value`: The specific _value_ of the tag (e.g., `Myotis daubentonii`, `Good`, `Rain`).
|
- `value`: The specific _value_ of the tag (e.g., `Myotis daubentonii`, `Good`, `Rain`).
|
||||||
|
|
||||||
**Example YAML Configuration using TagInfo (e.g., inside a filter rule):**
|
**Example YAML Configuration (e.g., inside a filter rule):**
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
# ... inside a filtering configuration section ...
|
# ... inside a filtering configuration section ...
|
||||||
|
|||||||
@ -1,44 +1,47 @@
|
|||||||
targets:
|
targets:
|
||||||
classes:
|
detection_target:
|
||||||
classes:
|
name: bat
|
||||||
- name: myomys
|
match_if:
|
||||||
tags:
|
name: all_of
|
||||||
- value: Myotis mystacinus
|
conditions:
|
||||||
- name: pippip
|
- name: has_tag
|
||||||
tags:
|
tag: { key: event, value: Echolocation }
|
||||||
- value: Pipistrellus pipistrellus
|
- name: not
|
||||||
- name: eptser
|
condition:
|
||||||
tags:
|
name: has_tag
|
||||||
- value: Eptesicus serotinus
|
tag: { key: class, value: Unknown }
|
||||||
- name: rhifer
|
assign_tags:
|
||||||
tags:
|
|
||||||
- value: Rhinolophus ferrumequinum
|
|
||||||
roi:
|
|
||||||
name: anchor_bbox
|
|
||||||
anchor: top-left
|
|
||||||
generic_class:
|
|
||||||
- key: class
|
- key: class
|
||||||
value: Bat
|
value: Bat
|
||||||
|
|
||||||
filtering:
|
classification_targets:
|
||||||
rules:
|
- name: myomys
|
||||||
- match_type: all
|
tags:
|
||||||
tags:
|
- key: class
|
||||||
- key: event
|
value: Myotis mystacinus
|
||||||
value: Echolocation
|
- name: pippip
|
||||||
- match_type: exclude
|
tags:
|
||||||
tags:
|
- key: class
|
||||||
- key: class
|
value: Pipistrellus pipistrellus
|
||||||
value: Unknown
|
- name: eptser
|
||||||
|
tags:
|
||||||
|
- key: class
|
||||||
|
value: Eptesicus serotinus
|
||||||
|
- name: rhifer
|
||||||
|
tags:
|
||||||
|
- key: class
|
||||||
|
value: Rhinolophus ferrumequinum
|
||||||
|
|
||||||
|
roi:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: top-left
|
||||||
|
|
||||||
preprocess:
|
preprocess:
|
||||||
audio:
|
audio:
|
||||||
|
samplerate: 256000
|
||||||
resample:
|
resample:
|
||||||
samplerate: 256000
|
enabled: True
|
||||||
method: "poly"
|
method: "poly"
|
||||||
scale: false
|
|
||||||
center: true
|
|
||||||
duration: null
|
|
||||||
|
|
||||||
spectrogram:
|
spectrogram:
|
||||||
stft:
|
stft:
|
||||||
@ -48,66 +51,66 @@ preprocess:
|
|||||||
frequencies:
|
frequencies:
|
||||||
max_freq: 120000
|
max_freq: 120000
|
||||||
min_freq: 10000
|
min_freq: 10000
|
||||||
pcen:
|
|
||||||
time_constant: 0.1
|
|
||||||
gain: 0.98
|
|
||||||
bias: 2
|
|
||||||
power: 0.5
|
|
||||||
scale: "amplitude"
|
|
||||||
size:
|
size:
|
||||||
height: 128
|
height: 128
|
||||||
resize_factor: 0.5
|
resize_factor: 0.5
|
||||||
spectral_mean_substraction: true
|
transforms:
|
||||||
peak_normalize: false
|
- name: pcen
|
||||||
|
time_constant: 0.1
|
||||||
|
gain: 0.98
|
||||||
|
bias: 2
|
||||||
|
power: 0.5
|
||||||
|
- name: spectral_mean_substraction
|
||||||
|
|
||||||
postprocess:
|
postprocess:
|
||||||
nms_kernel_size: 9
|
nms_kernel_size: 9
|
||||||
detection_threshold: 0.01
|
detection_threshold: 0.01
|
||||||
min_freq: 10000
|
|
||||||
max_freq: 120000
|
|
||||||
top_k_per_sec: 200
|
top_k_per_sec: 200
|
||||||
|
|
||||||
labels:
|
|
||||||
sigma: 3
|
|
||||||
|
|
||||||
model:
|
model:
|
||||||
input_height: 128
|
input_height: 128
|
||||||
in_channels: 1
|
in_channels: 1
|
||||||
out_channels: 32
|
out_channels: 32
|
||||||
encoder:
|
encoder:
|
||||||
layers:
|
layers:
|
||||||
- block_type: FreqCoordConvDown
|
- name: FreqCoordConvDown
|
||||||
out_channels: 32
|
out_channels: 32
|
||||||
- block_type: FreqCoordConvDown
|
- name: FreqCoordConvDown
|
||||||
out_channels: 64
|
out_channels: 64
|
||||||
- block_type: LayerGroup
|
- name: LayerGroup
|
||||||
layers:
|
layers:
|
||||||
- block_type: FreqCoordConvDown
|
- name: FreqCoordConvDown
|
||||||
out_channels: 128
|
out_channels: 128
|
||||||
- block_type: ConvBlock
|
- name: ConvBlock
|
||||||
out_channels: 256
|
out_channels: 256
|
||||||
bottleneck:
|
bottleneck:
|
||||||
channels: 256
|
channels: 256
|
||||||
layers:
|
layers:
|
||||||
- block_type: SelfAttention
|
- name: SelfAttention
|
||||||
attention_channels: 256
|
attention_channels: 256
|
||||||
decoder:
|
decoder:
|
||||||
layers:
|
layers:
|
||||||
- block_type: FreqCoordConvUp
|
- name: FreqCoordConvUp
|
||||||
out_channels: 64
|
out_channels: 64
|
||||||
- block_type: FreqCoordConvUp
|
- name: FreqCoordConvUp
|
||||||
out_channels: 32
|
out_channels: 32
|
||||||
- block_type: LayerGroup
|
- name: LayerGroup
|
||||||
layers:
|
layers:
|
||||||
- block_type: FreqCoordConvUp
|
- name: FreqCoordConvUp
|
||||||
out_channels: 32
|
out_channels: 32
|
||||||
- block_type: ConvBlock
|
- name: ConvBlock
|
||||||
out_channels: 32
|
out_channels: 32
|
||||||
|
|
||||||
train:
|
train:
|
||||||
learning_rate: 0.001
|
learning_rate: 0.001
|
||||||
t_max: 100
|
t_max: 100
|
||||||
|
|
||||||
|
labels:
|
||||||
|
sigma: 3
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
max_epochs: 40
|
||||||
|
|
||||||
dataloaders:
|
dataloaders:
|
||||||
train:
|
train:
|
||||||
batch_size: 8
|
batch_size: 8
|
||||||
@ -115,7 +118,7 @@ train:
|
|||||||
shuffle: True
|
shuffle: True
|
||||||
|
|
||||||
val:
|
val:
|
||||||
batch_size: 8
|
batch_size: 1
|
||||||
num_workers: 2
|
num_workers: 2
|
||||||
|
|
||||||
loss:
|
loss:
|
||||||
@ -134,32 +137,34 @@ train:
|
|||||||
|
|
||||||
logger:
|
logger:
|
||||||
logger_type: csv
|
logger_type: csv
|
||||||
save_dir: outputs/log/
|
# save_dir: outputs/log/
|
||||||
name: logs
|
# name: logs
|
||||||
|
|
||||||
augmentations:
|
augmentations:
|
||||||
steps:
|
enabled: true
|
||||||
- augmentation_type: mix_audio
|
audio:
|
||||||
|
- name: mix_audio
|
||||||
probability: 0.2
|
probability: 0.2
|
||||||
min_weight: 0.3
|
min_weight: 0.3
|
||||||
max_weight: 0.7
|
max_weight: 0.7
|
||||||
- augmentation_type: add_echo
|
- name: add_echo
|
||||||
probability: 0.2
|
probability: 0.2
|
||||||
max_delay: 0.005
|
max_delay: 0.005
|
||||||
min_weight: 0.0
|
min_weight: 0.0
|
||||||
max_weight: 1.0
|
max_weight: 1.0
|
||||||
- augmentation_type: scale_volume
|
spectrogram:
|
||||||
|
- name: scale_volume
|
||||||
probability: 0.2
|
probability: 0.2
|
||||||
min_scaling: 0.0
|
min_scaling: 0.0
|
||||||
max_scaling: 2.0
|
max_scaling: 2.0
|
||||||
- augmentation_type: warp
|
- name: warp
|
||||||
probability: 0.2
|
probability: 0.2
|
||||||
delta: 0.04
|
delta: 0.04
|
||||||
- augmentation_type: mask_time
|
- name: mask_time
|
||||||
probability: 0.2
|
probability: 0.2
|
||||||
max_perc: 0.05
|
max_perc: 0.05
|
||||||
max_masks: 3
|
max_masks: 3
|
||||||
- augmentation_type: mask_freq
|
- name: mask_freq
|
||||||
probability: 0.2
|
probability: 0.2
|
||||||
max_perc: 0.10
|
max_perc: 0.10
|
||||||
max_masks: 3
|
max_masks: 3
|
||||||
|
|||||||
12
justfile
12
justfile
@ -92,19 +92,11 @@ clean-build:
|
|||||||
clean: clean-build clean-pyc clean-test clean-docs
|
clean: clean-build clean-pyc clean-test clean-docs
|
||||||
|
|
||||||
# Examples
|
# Examples
|
||||||
# Preprocess example data.
|
|
||||||
example-preprocess OPTIONS="":
|
|
||||||
batdetect2 preprocess \
|
|
||||||
--base-dir . \
|
|
||||||
--dataset-field datasets.train \
|
|
||||||
--config example_data/config.yaml \
|
|
||||||
{{OPTIONS}} \
|
|
||||||
example_data/config.yaml example_data/preprocessed
|
|
||||||
|
|
||||||
# Train on example data.
|
# Train on example data.
|
||||||
example-train OPTIONS="":
|
example-train OPTIONS="":
|
||||||
batdetect2 train \
|
batdetect2 train \
|
||||||
--val-dir example_data/preprocessed \
|
--val-dataset example_data/dataset.yaml \
|
||||||
--config example_data/config.yaml \
|
--config example_data/config.yaml \
|
||||||
{{OPTIONS}} \
|
{{OPTIONS}} \
|
||||||
example_data/preprocessed
|
example_data/dataset.yaml
|
||||||
|
|||||||
@ -17,7 +17,7 @@ dependencies = [
|
|||||||
"torch>=1.13.1,<2.5.0",
|
"torch>=1.13.1,<2.5.0",
|
||||||
"torchaudio>=1.13.1,<2.5.0",
|
"torchaudio>=1.13.1,<2.5.0",
|
||||||
"torchvision>=0.14.0",
|
"torchvision>=0.14.0",
|
||||||
"soundevent[audio,geometry,plot]>=2.8.1",
|
"soundevent[audio,geometry,plot]>=2.9.1",
|
||||||
"click>=8.1.7",
|
"click>=8.1.7",
|
||||||
"netcdf4>=1.6.5",
|
"netcdf4>=1.6.5",
|
||||||
"tqdm>=4.66.2",
|
"tqdm>=4.66.2",
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
from batdetect2.cli.compat import detect
|
from batdetect2.cli.compat import detect
|
||||||
from batdetect2.cli.data import data
|
from batdetect2.cli.data import data
|
||||||
from batdetect2.cli.preprocess import preprocess
|
from batdetect2.cli.evaluate import evaluate_command
|
||||||
from batdetect2.cli.train import train_command
|
from batdetect2.cli.train import train_command
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -9,7 +9,7 @@ __all__ = [
|
|||||||
"detect",
|
"detect",
|
||||||
"data",
|
"data",
|
||||||
"train_command",
|
"train_command",
|
||||||
"preprocess",
|
"evaluate_command",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
63
src/batdetect2/cli/evaluate.py
Normal file
63
src/batdetect2/cli/evaluate.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import click
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from batdetect2.cli.base import cli
|
||||||
|
from batdetect2.data import load_dataset_from_config
|
||||||
|
from batdetect2.evaluate.evaluate import evaluate
|
||||||
|
from batdetect2.train.lightning import load_model_from_checkpoint
|
||||||
|
|
||||||
|
__all__ = ["evaluate_command"]
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command(name="evaluate")
|
||||||
|
@click.argument("model-path", type=click.Path(exists=True))
|
||||||
|
@click.argument("test_dataset", type=click.Path(exists=True))
|
||||||
|
@click.option("--output-dir", type=click.Path())
|
||||||
|
@click.option("--workers", type=int)
|
||||||
|
@click.option(
|
||||||
|
"-v",
|
||||||
|
"--verbose",
|
||||||
|
count=True,
|
||||||
|
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
||||||
|
)
|
||||||
|
def evaluate_command(
|
||||||
|
model_path: Path,
|
||||||
|
test_dataset: Path,
|
||||||
|
output_dir: Optional[Path] = None,
|
||||||
|
workers: Optional[int] = None,
|
||||||
|
verbose: int = 0,
|
||||||
|
):
|
||||||
|
logger.remove()
|
||||||
|
if verbose == 0:
|
||||||
|
log_level = "WARNING"
|
||||||
|
elif verbose == 1:
|
||||||
|
log_level = "INFO"
|
||||||
|
else:
|
||||||
|
log_level = "DEBUG"
|
||||||
|
logger.add(sys.stderr, level=log_level)
|
||||||
|
|
||||||
|
logger.info("Initiating evaluation process...")
|
||||||
|
|
||||||
|
test_annotations = load_dataset_from_config(test_dataset)
|
||||||
|
logger.debug(
|
||||||
|
"Loaded {num_annotations} test examples",
|
||||||
|
num_annotations=len(test_annotations),
|
||||||
|
)
|
||||||
|
|
||||||
|
model, train_config = load_model_from_checkpoint(model_path)
|
||||||
|
|
||||||
|
df, results = evaluate(
|
||||||
|
model,
|
||||||
|
test_annotations,
|
||||||
|
config=train_config,
|
||||||
|
num_workers=workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(results)
|
||||||
|
|
||||||
|
if output_dir:
|
||||||
|
df.to_csv(output_dir / "results.csv")
|
||||||
@ -1,142 +0,0 @@
|
|||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import click
|
|
||||||
import yaml
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from batdetect2.cli.base import cli
|
|
||||||
from batdetect2.data import load_dataset_from_config
|
|
||||||
from batdetect2.train.preprocess import (
|
|
||||||
TrainPreprocessConfig,
|
|
||||||
load_train_preprocessing_config,
|
|
||||||
preprocess_dataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = ["preprocess"]
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
|
||||||
@click.argument(
|
|
||||||
"dataset_config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
)
|
|
||||||
@click.argument(
|
|
||||||
"output",
|
|
||||||
type=click.Path(),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--dataset-field",
|
|
||||||
type=str,
|
|
||||||
help=(
|
|
||||||
"Specifies the key to access the dataset information within the "
|
|
||||||
"dataset configuration file, if the information is nested inside a "
|
|
||||||
"dictionary. If the dataset information is at the top level of the "
|
|
||||||
"config file, you don't need to specify this."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--base-dir",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help=(
|
|
||||||
"The main directory where your audio recordings and annotation "
|
|
||||||
"files are stored. This helps the program find your data, "
|
|
||||||
"especially if the paths in your dataset configuration file "
|
|
||||||
"are relative."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help=(
|
|
||||||
"Path to the configuration file. This file tells "
|
|
||||||
"the program how to prepare your audio data before training, such "
|
|
||||||
"as resampling or applying filters."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--config-field",
|
|
||||||
type=str,
|
|
||||||
help=(
|
|
||||||
"If the preprocessing settings are inside a nested dictionary "
|
|
||||||
"within the preprocessing configuration file, specify the key "
|
|
||||||
"here to access them. If the preprocessing settings are at the "
|
|
||||||
"top level, you don't need to specify this."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--num-workers",
|
|
||||||
type=int,
|
|
||||||
help=(
|
|
||||||
"The maximum number of computer cores to use when processing "
|
|
||||||
"your audio data. Using more cores can speed up the preprocessing, "
|
|
||||||
"but don't use more than your computer has available. By default, "
|
|
||||||
"the program will use all available cores."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"-v",
|
|
||||||
"--verbose",
|
|
||||||
count=True,
|
|
||||||
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
|
||||||
)
|
|
||||||
def preprocess(
|
|
||||||
dataset_config: Path,
|
|
||||||
output: Path,
|
|
||||||
base_dir: Optional[Path] = None,
|
|
||||||
config: Optional[Path] = None,
|
|
||||||
config_field: Optional[str] = None,
|
|
||||||
num_workers: Optional[int] = None,
|
|
||||||
dataset_field: Optional[str] = None,
|
|
||||||
verbose: int = 0,
|
|
||||||
):
|
|
||||||
logger.remove()
|
|
||||||
if verbose == 0:
|
|
||||||
log_level = "WARNING"
|
|
||||||
elif verbose == 1:
|
|
||||||
log_level = "INFO"
|
|
||||||
else:
|
|
||||||
log_level = "DEBUG"
|
|
||||||
logger.add(sys.stderr, level=log_level)
|
|
||||||
|
|
||||||
logger.info("Starting preprocessing.")
|
|
||||||
|
|
||||||
output = Path(output)
|
|
||||||
logger.info("Will save outputs to {output}", output=output)
|
|
||||||
|
|
||||||
base_dir = base_dir or Path.cwd()
|
|
||||||
logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
|
|
||||||
|
|
||||||
if config:
|
|
||||||
logger.info(
|
|
||||||
"Loading preprocessing config from: {config}", config=config
|
|
||||||
)
|
|
||||||
|
|
||||||
conf = (
|
|
||||||
load_train_preprocessing_config(config, field=config_field)
|
|
||||||
if config is not None
|
|
||||||
else TrainPreprocessConfig()
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
"Preprocessing config:\n{conf}",
|
|
||||||
conf=yaml.dump(conf.model_dump()),
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = load_dataset_from_config(
|
|
||||||
dataset_config,
|
|
||||||
field=dataset_field,
|
|
||||||
base_dir=base_dir,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Loaded {num_examples} annotated clips from the configured dataset",
|
|
||||||
num_examples=len(dataset),
|
|
||||||
)
|
|
||||||
|
|
||||||
preprocess_dataset(
|
|
||||||
dataset,
|
|
||||||
conf,
|
|
||||||
output=output,
|
|
||||||
max_workers=num_workers,
|
|
||||||
)
|
|
||||||
@ -20,6 +20,8 @@ __all__ = ["train_command"]
|
|||||||
@click.argument("train_dataset", type=click.Path(exists=True))
|
@click.argument("train_dataset", type=click.Path(exists=True))
|
||||||
@click.option("--val-dataset", type=click.Path(exists=True))
|
@click.option("--val-dataset", type=click.Path(exists=True))
|
||||||
@click.option("--model-path", type=click.Path(exists=True))
|
@click.option("--model-path", type=click.Path(exists=True))
|
||||||
|
@click.option("--ckpt-dir", type=click.Path(exists=True))
|
||||||
|
@click.option("--log-dir", type=click.Path(exists=True))
|
||||||
@click.option("--config", type=click.Path(exists=True))
|
@click.option("--config", type=click.Path(exists=True))
|
||||||
@click.option("--config-field", type=str)
|
@click.option("--config-field", type=str)
|
||||||
@click.option("--train-workers", type=int)
|
@click.option("--train-workers", type=int)
|
||||||
@ -34,6 +36,8 @@ def train_command(
|
|||||||
train_dataset: Path,
|
train_dataset: Path,
|
||||||
val_dataset: Optional[Path] = None,
|
val_dataset: Optional[Path] = None,
|
||||||
model_path: Optional[Path] = None,
|
model_path: Optional[Path] = None,
|
||||||
|
ckpt_dir: Optional[Path] = None,
|
||||||
|
log_dir: Optional[Path] = None,
|
||||||
config: Optional[Path] = None,
|
config: Optional[Path] = None,
|
||||||
config_field: Optional[str] = None,
|
config_field: Optional[str] = None,
|
||||||
train_workers: int = 0,
|
train_workers: int = 0,
|
||||||
@ -83,4 +87,6 @@ def train_command(
|
|||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
train_workers=train_workers,
|
train_workers=train_workers,
|
||||||
val_workers=val_workers,
|
val_workers=val_workers,
|
||||||
|
log_dir=log_dir,
|
||||||
|
checkpoint_dir=ckpt_dir,
|
||||||
)
|
)
|
||||||
|
|||||||
61
src/batdetect2/data/_core.py
Normal file
61
src/batdetect2/data/_core.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
from typing import Generic, Protocol, Type, TypeVar
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Registry",
|
||||||
|
]
|
||||||
|
|
||||||
|
T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
|
||||||
|
T_Type = TypeVar("T_Type", covariant=True)
|
||||||
|
|
||||||
|
|
||||||
|
class LogicProtocol(Generic[T_Config, T_Type], Protocol):
|
||||||
|
"""A generic protocol for the logic classes (conditions or transforms)."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: T_Config) -> T_Type: ...
|
||||||
|
|
||||||
|
|
||||||
|
T_Proto = TypeVar("T_Proto", bound=LogicProtocol)
|
||||||
|
|
||||||
|
|
||||||
|
class Registry(Generic[T_Type]):
|
||||||
|
"""A generic class to create and manage a registry of items."""
|
||||||
|
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self._name = name
|
||||||
|
self._registry = {}
|
||||||
|
|
||||||
|
def register(self, config_cls: Type[T_Config]):
|
||||||
|
"""A decorator factory to register a new item."""
|
||||||
|
fields = config_cls.model_fields
|
||||||
|
|
||||||
|
if "name" not in fields:
|
||||||
|
raise ValueError("Configuration object must have a 'name' field.")
|
||||||
|
|
||||||
|
name = fields["name"].default
|
||||||
|
|
||||||
|
if not isinstance(name, str):
|
||||||
|
raise ValueError("'name' field must be a string literal.")
|
||||||
|
|
||||||
|
def decorator(logic_cls: Type[T_Proto]) -> Type[T_Proto]:
|
||||||
|
self._registry[name] = logic_cls
|
||||||
|
return logic_cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def build(self, config: BaseModel) -> T_Type:
|
||||||
|
"""Builds a logic instance from a config object."""
|
||||||
|
|
||||||
|
name = getattr(config, "name") # noqa: B009
|
||||||
|
|
||||||
|
if name is None:
|
||||||
|
raise ValueError("Config does not have a name field")
|
||||||
|
|
||||||
|
if name not in self._registry:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"No {self._name} with name '{name}' is registered."
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._registry[name].from_config(config)
|
||||||
@ -14,8 +14,9 @@ format-specific loading function to retrieve the annotations as a standard
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
from typing import Annotated, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.data.annotations.aoef import (
|
from batdetect2.data.annotations.aoef import (
|
||||||
@ -42,10 +43,13 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
AnnotationFormats = Union[
|
AnnotationFormats = Annotated[
|
||||||
BatDetect2MergedAnnotations,
|
Union[
|
||||||
BatDetect2FilesAnnotations,
|
BatDetect2MergedAnnotations,
|
||||||
AOEFAnnotations,
|
BatDetect2FilesAnnotations,
|
||||||
|
AOEFAnnotations,
|
||||||
|
],
|
||||||
|
Field(discriminator="format"),
|
||||||
]
|
]
|
||||||
"""Type Alias representing all supported data source configurations.
|
"""Type Alias representing all supported data source configurations.
|
||||||
|
|
||||||
|
|||||||
@ -8,8 +8,6 @@ from typing import Callable, List, Optional, Union
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.targets import get_term_from_key
|
|
||||||
|
|
||||||
PathLike = Union[Path, str, os.PathLike]
|
PathLike = Union[Path, str, os.PathLike]
|
||||||
|
|
||||||
__all__ = []
|
__all__ = []
|
||||||
@ -92,15 +90,15 @@ def annotation_to_sound_event(
|
|||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(
|
data.Tag(
|
||||||
term=get_term_from_key(label_key),
|
key=label_key, # type: ignore
|
||||||
value=annotation.label,
|
value=annotation.label,
|
||||||
),
|
),
|
||||||
data.Tag(
|
data.Tag(
|
||||||
term=get_term_from_key(event_key),
|
key=event_key, # type: ignore
|
||||||
value=annotation.event,
|
value=annotation.event,
|
||||||
),
|
),
|
||||||
data.Tag(
|
data.Tag(
|
||||||
term=get_term_from_key(individual_key),
|
key=individual_key, # type: ignore
|
||||||
value=str(annotation.individual),
|
value=str(annotation.individual),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -125,7 +123,7 @@ def file_annotation_to_clip(
|
|||||||
time_expansion=file_annotation.time_exp,
|
time_expansion=file_annotation.time_exp,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(
|
data.Tag(
|
||||||
term=get_term_from_key(label_key),
|
key=label_key, # type: ignore
|
||||||
value=file_annotation.label,
|
value=file_annotation.label,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
@ -157,7 +155,8 @@ def file_annotation_to_clip_annotation(
|
|||||||
notes=notes,
|
notes=notes,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(
|
data.Tag(
|
||||||
term=get_term_from_key(label_key), value=file_annotation.label
|
key=label_key, # type: ignore
|
||||||
|
value=file_annotation.label,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
sound_events=[
|
sound_events=[
|
||||||
|
|||||||
287
src/batdetect2/data/conditions.py
Normal file
287
src/batdetect2/data/conditions.py
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Annotated, List, Literal, Sequence, Union
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig
|
||||||
|
from batdetect2.data._core import Registry
|
||||||
|
|
||||||
|
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
|
||||||
|
|
||||||
|
_conditions: Registry[SoundEventCondition] = Registry("condition")
|
||||||
|
|
||||||
|
|
||||||
|
class HasTagConfig(BaseConfig):
|
||||||
|
name: Literal["has_tag"] = "has_tag"
|
||||||
|
tag: data.Tag
|
||||||
|
|
||||||
|
|
||||||
|
@_conditions.register(HasTagConfig)
|
||||||
|
class HasTag:
|
||||||
|
def __init__(self, tag: data.Tag):
|
||||||
|
self.tag = tag
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, sound_event_annotation: data.SoundEventAnnotation
|
||||||
|
) -> bool:
|
||||||
|
return self.tag in sound_event_annotation.tags
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: HasTagConfig):
|
||||||
|
return cls(tag=config.tag)
|
||||||
|
|
||||||
|
|
||||||
|
class HasAllTagsConfig(BaseConfig):
|
||||||
|
name: Literal["has_all_tags"] = "has_all_tags"
|
||||||
|
tags: List[data.Tag]
|
||||||
|
|
||||||
|
|
||||||
|
@_conditions.register(HasAllTagsConfig)
|
||||||
|
class HasAllTags:
|
||||||
|
def __init__(self, tags: List[data.Tag]):
|
||||||
|
if not tags:
|
||||||
|
raise ValueError("Need to specify at least one tag")
|
||||||
|
|
||||||
|
self.tags = set(tags)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, sound_event_annotation: data.SoundEventAnnotation
|
||||||
|
) -> bool:
|
||||||
|
return self.tags.issubset(sound_event_annotation.tags)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: HasAllTagsConfig):
|
||||||
|
return cls(tags=config.tags)
|
||||||
|
|
||||||
|
|
||||||
|
class HasAnyTagConfig(BaseConfig):
|
||||||
|
name: Literal["has_any_tag"] = "has_any_tag"
|
||||||
|
tags: List[data.Tag]
|
||||||
|
|
||||||
|
|
||||||
|
@_conditions.register(HasAnyTagConfig)
|
||||||
|
class HasAnyTag:
|
||||||
|
def __init__(self, tags: List[data.Tag]):
|
||||||
|
if not tags:
|
||||||
|
raise ValueError("Need to specify at least one tag")
|
||||||
|
|
||||||
|
self.tags = set(tags)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, sound_event_annotation: data.SoundEventAnnotation
|
||||||
|
) -> bool:
|
||||||
|
return bool(self.tags.intersection(sound_event_annotation.tags))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: HasAnyTagConfig):
|
||||||
|
return cls(tags=config.tags)
|
||||||
|
|
||||||
|
|
||||||
|
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
|
||||||
|
|
||||||
|
|
||||||
|
class DurationConfig(BaseConfig):
|
||||||
|
name: Literal["duration"] = "duration"
|
||||||
|
operator: Operator
|
||||||
|
seconds: float
|
||||||
|
|
||||||
|
|
||||||
|
def _build_comparator(
|
||||||
|
operator: Operator, value: float
|
||||||
|
) -> Callable[[float], bool]:
|
||||||
|
if operator == "gt":
|
||||||
|
return lambda x: x > value
|
||||||
|
|
||||||
|
if operator == "gte":
|
||||||
|
return lambda x: x >= value
|
||||||
|
|
||||||
|
if operator == "lt":
|
||||||
|
return lambda x: x < value
|
||||||
|
|
||||||
|
if operator == "lte":
|
||||||
|
return lambda x: x <= value
|
||||||
|
|
||||||
|
if operator == "eq":
|
||||||
|
return lambda x: x == value
|
||||||
|
|
||||||
|
raise ValueError(f"Invalid operator {operator}")
|
||||||
|
|
||||||
|
|
||||||
|
@_conditions.register(DurationConfig)
|
||||||
|
class Duration:
|
||||||
|
def __init__(self, operator: Operator, seconds: float):
|
||||||
|
self.operator = operator
|
||||||
|
self.seconds = seconds
|
||||||
|
self._comparator = _build_comparator(self.operator, self.seconds)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
) -> bool:
|
||||||
|
geometry = sound_event_annotation.sound_event.geometry
|
||||||
|
|
||||||
|
if geometry is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
start_time, _, end_time, _ = compute_bounds(geometry)
|
||||||
|
duration = end_time - start_time
|
||||||
|
|
||||||
|
return self._comparator(duration)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: DurationConfig):
|
||||||
|
return cls(operator=config.operator, seconds=config.seconds)
|
||||||
|
|
||||||
|
|
||||||
|
class FrequencyConfig(BaseConfig):
|
||||||
|
name: Literal["frequency"] = "frequency"
|
||||||
|
boundary: Literal["low", "high"]
|
||||||
|
operator: Operator
|
||||||
|
hertz: float
|
||||||
|
|
||||||
|
|
||||||
|
@_conditions.register(FrequencyConfig)
|
||||||
|
class Frequency:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
operator: Operator,
|
||||||
|
boundary: Literal["low", "high"],
|
||||||
|
hertz: float,
|
||||||
|
):
|
||||||
|
self.operator = operator
|
||||||
|
self.hertz = hertz
|
||||||
|
self.boundary = boundary
|
||||||
|
self._comparator = _build_comparator(self.operator, self.hertz)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
) -> bool:
|
||||||
|
geometry = sound_event_annotation.sound_event.geometry
|
||||||
|
|
||||||
|
if geometry is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Automatically false if geometry does not have a frequency range
|
||||||
|
if isinstance(geometry, (data.TimeInterval, data.TimeStamp)):
|
||||||
|
return False
|
||||||
|
|
||||||
|
_, low_freq, _, high_freq = compute_bounds(geometry)
|
||||||
|
|
||||||
|
if self.boundary == "low":
|
||||||
|
return self._comparator(low_freq)
|
||||||
|
|
||||||
|
return self._comparator(high_freq)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: FrequencyConfig):
|
||||||
|
return cls(
|
||||||
|
operator=config.operator,
|
||||||
|
boundary=config.boundary,
|
||||||
|
hertz=config.hertz,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AllOfConfig(BaseConfig):
|
||||||
|
name: Literal["all_of"] = "all_of"
|
||||||
|
conditions: Sequence["SoundEventConditionConfig"]
|
||||||
|
|
||||||
|
|
||||||
|
@_conditions.register(AllOfConfig)
|
||||||
|
class AllOf:
|
||||||
|
def __init__(self, conditions: List[SoundEventCondition]):
|
||||||
|
self.conditions = conditions
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, sound_event_annotation: data.SoundEventAnnotation
|
||||||
|
) -> bool:
|
||||||
|
return all(c(sound_event_annotation) for c in self.conditions)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: AllOfConfig):
|
||||||
|
conditions = [
|
||||||
|
build_sound_event_condition(cond) for cond in config.conditions
|
||||||
|
]
|
||||||
|
return cls(conditions)
|
||||||
|
|
||||||
|
|
||||||
|
class AnyOfConfig(BaseConfig):
|
||||||
|
name: Literal["any_of"] = "any_of"
|
||||||
|
conditions: List["SoundEventConditionConfig"]
|
||||||
|
|
||||||
|
|
||||||
|
@_conditions.register(AnyOfConfig)
|
||||||
|
class AnyOf:
|
||||||
|
def __init__(self, conditions: List[SoundEventCondition]):
|
||||||
|
self.conditions = conditions
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, sound_event_annotation: data.SoundEventAnnotation
|
||||||
|
) -> bool:
|
||||||
|
return any(c(sound_event_annotation) for c in self.conditions)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: AnyOfConfig):
|
||||||
|
conditions = [
|
||||||
|
build_sound_event_condition(cond) for cond in config.conditions
|
||||||
|
]
|
||||||
|
return cls(conditions)
|
||||||
|
|
||||||
|
|
||||||
|
class NotConfig(BaseConfig):
|
||||||
|
name: Literal["not"] = "not"
|
||||||
|
condition: "SoundEventConditionConfig"
|
||||||
|
|
||||||
|
|
||||||
|
@_conditions.register(NotConfig)
|
||||||
|
class Not:
|
||||||
|
def __init__(self, condition: SoundEventCondition):
|
||||||
|
self.condition = condition
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, sound_event_annotation: data.SoundEventAnnotation
|
||||||
|
) -> bool:
|
||||||
|
return not self.condition(sound_event_annotation)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: NotConfig):
|
||||||
|
condition = build_sound_event_condition(config.condition)
|
||||||
|
return cls(condition)
|
||||||
|
|
||||||
|
|
||||||
|
SoundEventConditionConfig = Annotated[
|
||||||
|
Union[
|
||||||
|
HasTagConfig,
|
||||||
|
HasAllTagsConfig,
|
||||||
|
HasAnyTagConfig,
|
||||||
|
DurationConfig,
|
||||||
|
FrequencyConfig,
|
||||||
|
AllOfConfig,
|
||||||
|
AnyOfConfig,
|
||||||
|
NotConfig,
|
||||||
|
],
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_sound_event_condition(
|
||||||
|
config: SoundEventConditionConfig,
|
||||||
|
) -> SoundEventCondition:
|
||||||
|
return _conditions.build(config)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_clip_annotation(
|
||||||
|
clip_annotation: data.ClipAnnotation,
|
||||||
|
condition: SoundEventCondition,
|
||||||
|
) -> data.ClipAnnotation:
|
||||||
|
return clip_annotation.model_copy(
|
||||||
|
update=dict(
|
||||||
|
sound_events=[
|
||||||
|
sound_event
|
||||||
|
for sound_event in clip_annotation.sound_events
|
||||||
|
if condition(sound_event)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
@ -19,7 +19,7 @@ The core components are:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -31,6 +31,17 @@ from batdetect2.data.annotations import (
|
|||||||
AnnotationFormats,
|
AnnotationFormats,
|
||||||
load_annotated_dataset,
|
load_annotated_dataset,
|
||||||
)
|
)
|
||||||
|
from batdetect2.data.conditions import (
|
||||||
|
SoundEventConditionConfig,
|
||||||
|
build_sound_event_condition,
|
||||||
|
filter_clip_annotation,
|
||||||
|
)
|
||||||
|
from batdetect2.data.transforms import (
|
||||||
|
ApplyAll,
|
||||||
|
SoundEventTransformConfig,
|
||||||
|
build_sound_event_transform,
|
||||||
|
transform_clip_annotation,
|
||||||
|
)
|
||||||
from batdetect2.targets.terms import data_source
|
from batdetect2.targets.terms import data_source
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -52,79 +63,68 @@ sources.
|
|||||||
|
|
||||||
|
|
||||||
class DatasetConfig(BaseConfig):
|
class DatasetConfig(BaseConfig):
|
||||||
"""Configuration model defining the structure of a BatDetect2 dataset.
|
"""Configuration model defining the structure of a BatDetect2 dataset."""
|
||||||
|
|
||||||
This class is typically loaded from a YAML file and describes the components
|
|
||||||
of the dataset, including metadata and a list of data sources.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
name : str
|
|
||||||
A descriptive name for the dataset (e.g., "UK_Bats_Project_2024").
|
|
||||||
description : str
|
|
||||||
A longer description of the dataset's contents, origin, purpose, etc.
|
|
||||||
sources : List[AnnotationFormats]
|
|
||||||
A list defining the different data sources contributing to this
|
|
||||||
dataset. Each item in the list must conform to one of the Pydantic
|
|
||||||
models defined in the `AnnotationFormats` type union. The specific
|
|
||||||
model used for each source is determined by the mandatory `format`
|
|
||||||
field within the source's configuration, allowing BatDetect2 to use the
|
|
||||||
correct parser for different annotation styles.
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
sources: List[
|
sources: List[AnnotationFormats]
|
||||||
Annotated[AnnotationFormats, Field(..., discriminator="format")]
|
|
||||||
]
|
sound_event_filter: Optional[SoundEventConditionConfig] = None
|
||||||
|
sound_event_transforms: List[SoundEventTransformConfig] = Field(
|
||||||
|
default_factory=list
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(
|
def load_dataset(
|
||||||
dataset: DatasetConfig,
|
config: DatasetConfig,
|
||||||
base_dir: Optional[Path] = None,
|
base_dir: Optional[Path] = None,
|
||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
"""Load all clip annotations from the sources defined in a DatasetConfig.
|
"""Load all clip annotations from the sources defined in a DatasetConfig."""
|
||||||
|
|
||||||
Iterates through each data source specified in the `dataset_config`,
|
|
||||||
delegates the loading and parsing of that source's annotations to
|
|
||||||
`batdetect2.data.annotations.load_annotated_dataset` (which handles
|
|
||||||
different data formats), and aggregates all resulting `ClipAnnotation`
|
|
||||||
objects into a single flat list.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
dataset_config : DatasetConfig
|
|
||||||
The configuration object describing the dataset and its sources.
|
|
||||||
base_dir : Path, optional
|
|
||||||
An optional base directory path. If provided, relative paths for
|
|
||||||
metadata files or data directories within the `dataset_config`'s
|
|
||||||
sources might be resolved relative to this directory. Defaults to None.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Dataset (List[data.ClipAnnotation])
|
|
||||||
A flat list containing all loaded `ClipAnnotation` metadata objects
|
|
||||||
from all specified sources.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
Exception
|
|
||||||
Can raise various exceptions during the delegated loading process
|
|
||||||
(`load_annotated_dataset`) if files are not found, cannot be parsed
|
|
||||||
according to the specified format, or other I/O errors occur.
|
|
||||||
"""
|
|
||||||
clip_annotations = []
|
clip_annotations = []
|
||||||
for source in dataset.sources:
|
|
||||||
|
condition = (
|
||||||
|
build_sound_event_condition(config.sound_event_filter)
|
||||||
|
if config.sound_event_filter is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
transform = (
|
||||||
|
ApplyAll(
|
||||||
|
[
|
||||||
|
build_sound_event_transform(step)
|
||||||
|
for step in config.sound_event_transforms
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if config.sound_event_transforms
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
for source in config.sources:
|
||||||
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
|
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Loaded {num_examples} from dataset source '{source_name}'",
|
"Loaded {num_examples} from dataset source '{source_name}'",
|
||||||
num_examples=len(annotated_source.clip_annotations),
|
num_examples=len(annotated_source.clip_annotations),
|
||||||
source_name=source.name,
|
source_name=source.name,
|
||||||
)
|
)
|
||||||
clip_annotations.extend(
|
|
||||||
insert_source_tag(clip_annotation, source)
|
for clip_annotation in annotated_source.clip_annotations:
|
||||||
for clip_annotation in annotated_source.clip_annotations
|
clip_annotation = insert_source_tag(clip_annotation, source)
|
||||||
)
|
|
||||||
|
if condition is not None:
|
||||||
|
clip_annotation = filter_clip_annotation(
|
||||||
|
clip_annotation,
|
||||||
|
condition,
|
||||||
|
)
|
||||||
|
|
||||||
|
if transform is not None:
|
||||||
|
clip_annotation = transform_clip_annotation(
|
||||||
|
clip_annotation,
|
||||||
|
transform,
|
||||||
|
)
|
||||||
|
|
||||||
|
clip_annotations.append(clip_annotation)
|
||||||
|
|
||||||
return clip_annotations
|
return clip_annotations
|
||||||
|
|
||||||
|
|
||||||
@ -161,7 +161,6 @@ def insert_source_tag(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO: add documentation
|
|
||||||
def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
|
def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
|
||||||
return load_config(path=path, schema=DatasetConfig, field=field)
|
return load_config(path=path, schema=DatasetConfig, field=field)
|
||||||
|
|
||||||
|
|||||||
@ -10,16 +10,8 @@ from batdetect2.typing.targets import TargetProtocol
|
|||||||
def iterate_over_sound_events(
|
def iterate_over_sound_events(
|
||||||
dataset: Dataset,
|
dataset: Dataset,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
apply_filter: bool = True,
|
|
||||||
apply_transform: bool = True,
|
|
||||||
exclude_generic: bool = True,
|
|
||||||
) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]:
|
) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]:
|
||||||
"""Iterate over sound events in a dataset, applying filtering and
|
"""Iterate over sound events in a dataset.
|
||||||
transformations.
|
|
||||||
|
|
||||||
This generator function processes sound event annotations from a given
|
|
||||||
dataset, allowing for optional filtering, transformation, and exclusion of
|
|
||||||
unclassifiable (generic) events based on the provided target definitions.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -29,18 +21,6 @@ def iterate_over_sound_events(
|
|||||||
targets : TargetProtocol
|
targets : TargetProtocol
|
||||||
An object implementing the `TargetProtocol`, which provides methods
|
An object implementing the `TargetProtocol`, which provides methods
|
||||||
for filtering, transforming, and encoding sound events.
|
for filtering, transforming, and encoding sound events.
|
||||||
apply_filter : bool, optional
|
|
||||||
If True, sound events will be filtered using `targets.filter()`.
|
|
||||||
Only events for which `targets.filter()` returns True will be yielded.
|
|
||||||
Defaults to True.
|
|
||||||
apply_transform : bool, optional
|
|
||||||
If True, sound events will be transformed using `targets.transform()`
|
|
||||||
before being yielded. Defaults to True.
|
|
||||||
exclude_generic : bool, optional
|
|
||||||
If True, sound events that result in a `None` class name after
|
|
||||||
`targets.encode()` will be excluded. This is typically used to
|
|
||||||
filter out events that cannot be mapped to a specific target class.
|
|
||||||
Defaults to True.
|
|
||||||
|
|
||||||
Yields
|
Yields
|
||||||
------
|
------
|
||||||
@ -63,17 +43,9 @@ def iterate_over_sound_events(
|
|||||||
"""
|
"""
|
||||||
for clip_annotation in dataset:
|
for clip_annotation in dataset:
|
||||||
for sound_event_annotation in clip_annotation.sound_events:
|
for sound_event_annotation in clip_annotation.sound_events:
|
||||||
if apply_filter:
|
if not targets.filter(sound_event_annotation):
|
||||||
if not targets.filter(sound_event_annotation):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if apply_transform:
|
|
||||||
sound_event_annotation = targets.transform(
|
|
||||||
sound_event_annotation
|
|
||||||
)
|
|
||||||
|
|
||||||
class_name = targets.encode_class(sound_event_annotation)
|
|
||||||
if class_name is None and exclude_generic:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
class_name = targets.encode_class(sound_event_annotation)
|
||||||
|
|
||||||
yield class_name, sound_event_annotation
|
yield class_name, sound_event_annotation
|
||||||
|
|||||||
250
src/batdetect2/data/transforms.py
Normal file
250
src/batdetect2/data/transforms.py
Normal file
@ -0,0 +1,250 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Annotated, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig
|
||||||
|
from batdetect2.data._core import Registry
|
||||||
|
from batdetect2.data.conditions import (
|
||||||
|
SoundEventCondition,
|
||||||
|
SoundEventConditionConfig,
|
||||||
|
build_sound_event_condition,
|
||||||
|
)
|
||||||
|
|
||||||
|
SoundEventTransform = Callable[
|
||||||
|
[data.SoundEventAnnotation],
|
||||||
|
data.SoundEventAnnotation,
|
||||||
|
]
|
||||||
|
|
||||||
|
_transforms: Registry[SoundEventTransform] = Registry("transform")
|
||||||
|
|
||||||
|
|
||||||
|
class SetFrequencyBoundConfig(BaseConfig):
|
||||||
|
name: Literal["set_frequency"] = "set_frequency"
|
||||||
|
boundary: Literal["low", "high"] = "low"
|
||||||
|
hertz: float
|
||||||
|
|
||||||
|
|
||||||
|
@_transforms.register(SetFrequencyBoundConfig)
|
||||||
|
class SetFrequencyBound:
|
||||||
|
def __init__(self, hertz: float, boundary: Literal["low", "high"] = "low"):
|
||||||
|
self.hertz = hertz
|
||||||
|
self.boundary = boundary
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
) -> data.SoundEventAnnotation:
|
||||||
|
sound_event = sound_event_annotation.sound_event
|
||||||
|
geometry = sound_event.geometry
|
||||||
|
|
||||||
|
if geometry is None:
|
||||||
|
return sound_event_annotation
|
||||||
|
|
||||||
|
if not isinstance(geometry, data.BoundingBox):
|
||||||
|
return sound_event_annotation
|
||||||
|
|
||||||
|
start_time, low_freq, end_time, high_freq = geometry.coordinates
|
||||||
|
|
||||||
|
if self.boundary == "low":
|
||||||
|
low_freq = self.hertz
|
||||||
|
high_freq = max(high_freq, low_freq)
|
||||||
|
|
||||||
|
elif self.boundary == "high":
|
||||||
|
high_freq = self.hertz
|
||||||
|
low_freq = min(high_freq, low_freq)
|
||||||
|
|
||||||
|
geometry = data.BoundingBox(
|
||||||
|
coordinates=[start_time, low_freq, end_time, high_freq],
|
||||||
|
)
|
||||||
|
|
||||||
|
sound_event = sound_event.model_copy(update=dict(geometry=geometry))
|
||||||
|
return sound_event_annotation.model_copy(
|
||||||
|
update=dict(sound_event=sound_event)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: SetFrequencyBoundConfig):
|
||||||
|
return cls(hertz=config.hertz, boundary=config.boundary)
|
||||||
|
|
||||||
|
|
||||||
|
class ApplyIfConfig(BaseConfig):
|
||||||
|
name: Literal["apply_if"] = "apply_if"
|
||||||
|
transform: "SoundEventTransformConfig"
|
||||||
|
condition: SoundEventConditionConfig
|
||||||
|
|
||||||
|
|
||||||
|
@_transforms.register(ApplyIfConfig)
|
||||||
|
class ApplyIf:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
condition: SoundEventCondition,
|
||||||
|
transform: SoundEventTransform,
|
||||||
|
):
|
||||||
|
self.condition = condition
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
) -> data.SoundEventAnnotation:
|
||||||
|
if not self.condition(sound_event_annotation):
|
||||||
|
return sound_event_annotation
|
||||||
|
|
||||||
|
return self.transform(sound_event_annotation)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: ApplyIfConfig):
|
||||||
|
transform = build_sound_event_transform(config.transform)
|
||||||
|
condition = build_sound_event_condition(config.condition)
|
||||||
|
return cls(condition=condition, transform=transform)
|
||||||
|
|
||||||
|
|
||||||
|
class ReplaceTagConfig(BaseConfig):
|
||||||
|
name: Literal["replace_tag"] = "replace_tag"
|
||||||
|
original: data.Tag
|
||||||
|
replacement: data.Tag
|
||||||
|
|
||||||
|
|
||||||
|
@_transforms.register(ReplaceTagConfig)
|
||||||
|
class ReplaceTag:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
original: data.Tag,
|
||||||
|
replacement: data.Tag,
|
||||||
|
):
|
||||||
|
self.original = original
|
||||||
|
self.replacement = replacement
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
) -> data.SoundEventAnnotation:
|
||||||
|
tags = []
|
||||||
|
|
||||||
|
for tag in sound_event_annotation.tags:
|
||||||
|
if tag == self.original:
|
||||||
|
tags.append(self.replacement)
|
||||||
|
else:
|
||||||
|
tags.append(tag)
|
||||||
|
|
||||||
|
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: ReplaceTagConfig):
|
||||||
|
return cls(original=config.original, replacement=config.replacement)
|
||||||
|
|
||||||
|
|
||||||
|
class MapTagValueConfig(BaseConfig):
|
||||||
|
name: Literal["map_tag_value"] = "map_tag_value"
|
||||||
|
tag_key: str
|
||||||
|
value_mapping: Dict[str, str]
|
||||||
|
target_key: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@_transforms.register(MapTagValueConfig)
|
||||||
|
class MapTagValue:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tag_key: str,
|
||||||
|
value_mapping: Dict[str, str],
|
||||||
|
target_key: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self.tag_key = tag_key
|
||||||
|
self.value_mapping = value_mapping
|
||||||
|
self.target_key = target_key
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
) -> data.SoundEventAnnotation:
|
||||||
|
tags = []
|
||||||
|
|
||||||
|
for tag in sound_event_annotation.tags:
|
||||||
|
if tag.key != self.tag_key:
|
||||||
|
tags.append(tag)
|
||||||
|
continue
|
||||||
|
|
||||||
|
value = self.value_mapping.get(tag.value)
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
tags.append(tag)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if self.target_key is None:
|
||||||
|
tags.append(tag.model_copy(update=dict(value=value)))
|
||||||
|
else:
|
||||||
|
tags.append(
|
||||||
|
data.Tag(
|
||||||
|
key=self.target_key, # type: ignore
|
||||||
|
value=value,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: MapTagValueConfig):
|
||||||
|
return cls(
|
||||||
|
tag_key=config.tag_key,
|
||||||
|
value_mapping=config.value_mapping,
|
||||||
|
target_key=config.target_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ApplyAllConfig(BaseConfig):
|
||||||
|
name: Literal["apply_all"] = "apply_all"
|
||||||
|
steps: List["SoundEventTransformConfig"] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@_transforms.register(ApplyAllConfig)
|
||||||
|
class ApplyAll:
|
||||||
|
def __init__(self, steps: List[SoundEventTransform]):
|
||||||
|
self.steps = steps
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
) -> data.SoundEventAnnotation:
|
||||||
|
for step in self.steps:
|
||||||
|
sound_event_annotation = step(sound_event_annotation)
|
||||||
|
|
||||||
|
return sound_event_annotation
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: ApplyAllConfig):
|
||||||
|
steps = [build_sound_event_transform(step) for step in config.steps]
|
||||||
|
return cls(steps)
|
||||||
|
|
||||||
|
|
||||||
|
SoundEventTransformConfig = Annotated[
|
||||||
|
Union[
|
||||||
|
SetFrequencyBoundConfig,
|
||||||
|
ReplaceTagConfig,
|
||||||
|
MapTagValueConfig,
|
||||||
|
ApplyIfConfig,
|
||||||
|
ApplyAllConfig,
|
||||||
|
],
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_sound_event_transform(
|
||||||
|
config: SoundEventTransformConfig,
|
||||||
|
) -> SoundEventTransform:
|
||||||
|
return _transforms.build(config)
|
||||||
|
|
||||||
|
|
||||||
|
def transform_clip_annotation(
|
||||||
|
clip_annotation: data.ClipAnnotation,
|
||||||
|
transform: SoundEventTransform,
|
||||||
|
) -> data.ClipAnnotation:
|
||||||
|
return clip_annotation.model_copy(
|
||||||
|
update=dict(
|
||||||
|
sound_events=[
|
||||||
|
transform(sound_event)
|
||||||
|
for sound_event in clip_annotation.sound_events
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
62
src/batdetect2/evaluate/dataframe.py
Normal file
62
src/batdetect2/evaluate/dataframe.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
|
from batdetect2.typing.evaluate import MatchEvaluation
|
||||||
|
|
||||||
|
|
||||||
|
def extract_matches_dataframe(matches: List[MatchEvaluation]) -> pd.DataFrame:
|
||||||
|
data = []
|
||||||
|
|
||||||
|
for match in matches:
|
||||||
|
gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
|
||||||
|
pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None
|
||||||
|
|
||||||
|
sound_event_annotation = match.sound_event_annotation
|
||||||
|
|
||||||
|
if sound_event_annotation is not None:
|
||||||
|
geometry = sound_event_annotation.sound_event.geometry
|
||||||
|
assert geometry is not None
|
||||||
|
gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
|
||||||
|
compute_bounds(geometry)
|
||||||
|
)
|
||||||
|
|
||||||
|
if match.pred_geometry is not None:
|
||||||
|
pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
|
||||||
|
compute_bounds(match.pred_geometry)
|
||||||
|
)
|
||||||
|
|
||||||
|
data.append(
|
||||||
|
{
|
||||||
|
("recording", "uuid"): match.clip.recording.uuid,
|
||||||
|
("clip", "uuid"): match.clip.uuid,
|
||||||
|
("clip", "start_time"): match.clip.start_time,
|
||||||
|
("clip", "end_time"): match.clip.end_time,
|
||||||
|
("gt", "uuid"): match.sound_event_annotation.uuid
|
||||||
|
if match.sound_event_annotation is not None
|
||||||
|
else None,
|
||||||
|
("gt", "class"): match.gt_class,
|
||||||
|
("gt", "det"): match.gt_det,
|
||||||
|
("gt", "start_time"): gt_start_time,
|
||||||
|
("gt", "end_time"): gt_end_time,
|
||||||
|
("gt", "low_freq"): gt_low_freq,
|
||||||
|
("gt", "high_freq"): gt_high_freq,
|
||||||
|
("pred", "score"): match.pred_score,
|
||||||
|
("pred", "class"): match.pred_class,
|
||||||
|
("pred", "class_score"): match.pred_class_score,
|
||||||
|
("pred", "start_time"): pred_start_time,
|
||||||
|
("pred", "end_time"): pred_end_time,
|
||||||
|
("pred", "low_freq"): pred_low_freq,
|
||||||
|
("pred", "high_freq"): pred_high_freq,
|
||||||
|
("match", "affinity"): match.affinity,
|
||||||
|
**{
|
||||||
|
("pred_class_score", key): value
|
||||||
|
for key, value in match.pred_class_scores.items()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
df.columns = pd.MultiIndex.from_tuples(df.columns) # type: ignore
|
||||||
|
return df
|
||||||
100
src/batdetect2/evaluate/evaluate.py
Normal file
100
src/batdetect2/evaluate/evaluate.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.evaluate.dataframe import extract_matches_dataframe
|
||||||
|
from batdetect2.evaluate.match import match_all_predictions
|
||||||
|
from batdetect2.evaluate.metrics import (
|
||||||
|
ClassificationAccuracy,
|
||||||
|
ClassificationMeanAveragePrecision,
|
||||||
|
DetectionAveragePrecision,
|
||||||
|
)
|
||||||
|
from batdetect2.models import Model
|
||||||
|
from batdetect2.plotting.clips import build_audio_loader
|
||||||
|
from batdetect2.postprocess import get_raw_predictions
|
||||||
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
from batdetect2.targets import build_targets
|
||||||
|
from batdetect2.train.config import FullTrainingConfig
|
||||||
|
from batdetect2.train.dataset import ValidationDataset
|
||||||
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
|
from batdetect2.train.train import build_val_loader
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(
|
||||||
|
model: Model,
|
||||||
|
test_annotations: List[data.ClipAnnotation],
|
||||||
|
config: Optional[FullTrainingConfig] = None,
|
||||||
|
num_workers: Optional[int] = None,
|
||||||
|
) -> Tuple[pd.DataFrame, dict]:
|
||||||
|
config = config or FullTrainingConfig()
|
||||||
|
|
||||||
|
audio_loader = build_audio_loader(config.preprocess.audio)
|
||||||
|
|
||||||
|
preprocessor = build_preprocessor(config.preprocess)
|
||||||
|
|
||||||
|
targets = build_targets(config.targets)
|
||||||
|
|
||||||
|
labeller = build_clip_labeler(
|
||||||
|
targets,
|
||||||
|
min_freq=preprocessor.min_freq,
|
||||||
|
max_freq=preprocessor.max_freq,
|
||||||
|
config=config.train.labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
loader = build_val_loader(
|
||||||
|
test_annotations,
|
||||||
|
audio_loader=audio_loader,
|
||||||
|
labeller=labeller,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
config=config.train,
|
||||||
|
num_workers=num_workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset: ValidationDataset = loader.dataset # type: ignore
|
||||||
|
|
||||||
|
clip_annotations = []
|
||||||
|
predictions = []
|
||||||
|
|
||||||
|
for batch in loader:
|
||||||
|
outputs = model.detector(batch.spec)
|
||||||
|
|
||||||
|
clip_annotations = [
|
||||||
|
dataset.clip_annotations[int(example_idx)]
|
||||||
|
for example_idx in batch.idx
|
||||||
|
]
|
||||||
|
|
||||||
|
predictions = get_raw_predictions(
|
||||||
|
outputs,
|
||||||
|
clips=[
|
||||||
|
clip_annotation.clip for clip_annotation in clip_annotations
|
||||||
|
],
|
||||||
|
targets=targets,
|
||||||
|
postprocessor=model.postprocessor,
|
||||||
|
)
|
||||||
|
|
||||||
|
clip_annotations.extend(clip_annotations)
|
||||||
|
predictions.extend(predictions)
|
||||||
|
|
||||||
|
matches = match_all_predictions(
|
||||||
|
clip_annotations,
|
||||||
|
predictions,
|
||||||
|
targets=targets,
|
||||||
|
config=config.evaluation.match,
|
||||||
|
)
|
||||||
|
|
||||||
|
df = extract_matches_dataframe(matches)
|
||||||
|
|
||||||
|
metrics = [
|
||||||
|
DetectionAveragePrecision(),
|
||||||
|
ClassificationMeanAveragePrecision(class_names=targets.class_names),
|
||||||
|
ClassificationAccuracy(class_names=targets.class_names),
|
||||||
|
]
|
||||||
|
|
||||||
|
results = {
|
||||||
|
name: value
|
||||||
|
for metric in metrics
|
||||||
|
for name, value in metric(matches).items()
|
||||||
|
}
|
||||||
|
|
||||||
|
return df, results
|
||||||
@ -1,6 +1,5 @@
|
|||||||
from collections.abc import Callable, Iterable, Mapping
|
from collections.abc import Callable, Iterable, Mapping
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import partial
|
|
||||||
from typing import List, Literal, Optional, Protocol, Tuple
|
from typing import List, Literal, Optional, Protocol, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -9,7 +8,6 @@ from soundevent import data
|
|||||||
from soundevent.evaluation import compute_affinity
|
from soundevent.evaluation import compute_affinity
|
||||||
from soundevent.evaluation import match_geometries as optimal_match
|
from soundevent.evaluation import match_geometries as optimal_match
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
from torch.multiprocessing import Pool
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
@ -284,7 +282,7 @@ def match_sound_events_and_raw_predictions(
|
|||||||
config = config or MatchConfig()
|
config = config or MatchConfig()
|
||||||
|
|
||||||
target_sound_events = [
|
target_sound_events = [
|
||||||
targets.transform(sound_event_annotation)
|
sound_event_annotation
|
||||||
for sound_event_annotation in clip_annotation.sound_events
|
for sound_event_annotation in clip_annotation.sound_events
|
||||||
if targets.filter(sound_event_annotation)
|
if targets.filter(sound_event_annotation)
|
||||||
and sound_event_annotation.sound_event.geometry is not None
|
and sound_event_annotation.sound_event.geometry is not None
|
||||||
@ -430,17 +428,19 @@ def match_all_predictions(
|
|||||||
config: Optional[MatchConfig] = None,
|
config: Optional[MatchConfig] = None,
|
||||||
) -> List[MatchEvaluation]:
|
) -> List[MatchEvaluation]:
|
||||||
logger.info("Matching all annotations and predictions...")
|
logger.info("Matching all annotations and predictions...")
|
||||||
with Pool() as p:
|
return [
|
||||||
all_matches = p.starmap(
|
match
|
||||||
partial(
|
for clip_annotation, raw_predictions in zip(
|
||||||
match_sound_events_and_raw_predictions,
|
clip_annotations,
|
||||||
targets=targets,
|
predictions,
|
||||||
config=config,
|
|
||||||
),
|
|
||||||
zip(clip_annotations, predictions),
|
|
||||||
)
|
)
|
||||||
|
for match in match_sound_events_and_raw_predictions(
|
||||||
return [match for matches in all_matches for match in matches]
|
clip_annotation,
|
||||||
|
raw_predictions,
|
||||||
|
targets=targets,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@ -29,7 +29,6 @@ provided here.
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lightning import LightningModule
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
@ -68,7 +67,10 @@ from batdetect2.postprocess import PostprocessConfig, build_postprocessor
|
|||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.targets import TargetConfig, build_targets
|
from batdetect2.targets import TargetConfig, build_targets
|
||||||
from batdetect2.typing.models import DetectionModel
|
from batdetect2.typing.models import DetectionModel
|
||||||
from batdetect2.typing.postprocess import DetectionsArray, PostprocessorProtocol
|
from batdetect2.typing.postprocess import (
|
||||||
|
DetectionsTensor,
|
||||||
|
PostprocessorProtocol,
|
||||||
|
)
|
||||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
@ -102,7 +104,16 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class Model(LightningModule):
|
class ModelConfig(BaseConfig):
|
||||||
|
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
||||||
|
preprocess: PreprocessingConfig = Field(
|
||||||
|
default_factory=PreprocessingConfig
|
||||||
|
)
|
||||||
|
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||||
|
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
|
|
||||||
|
|
||||||
|
class Model(torch.nn.Module):
|
||||||
detector: DetectionModel
|
detector: DetectionModel
|
||||||
preprocessor: PreprocessorProtocol
|
preprocessor: PreprocessorProtocol
|
||||||
postprocessor: PostprocessorProtocol
|
postprocessor: PostprocessorProtocol
|
||||||
@ -114,43 +125,39 @@ class Model(LightningModule):
|
|||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
postprocessor: PostprocessorProtocol,
|
postprocessor: PostprocessorProtocol,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
|
config: ModelConfig,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.detector = detector
|
self.detector = detector
|
||||||
self.preprocessor = preprocessor
|
self.preprocessor = preprocessor
|
||||||
self.postprocessor = postprocessor
|
self.postprocessor = postprocessor
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
self.save_hyperparameters()
|
self.config = config
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> List[DetectionsArray]:
|
def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
|
||||||
spec = self.preprocessor(wav)
|
spec = self.preprocessor(wav)
|
||||||
outputs = self.detector(spec)
|
outputs = self.detector(spec)
|
||||||
return self.postprocessor(outputs)
|
return self.postprocessor(outputs)
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseConfig):
|
|
||||||
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
|
||||||
preprocess: PreprocessingConfig = Field(
|
|
||||||
default_factory=PreprocessingConfig
|
|
||||||
)
|
|
||||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
|
||||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
|
||||||
|
|
||||||
|
|
||||||
def build_model(config: Optional[ModelConfig] = None):
|
def build_model(config: Optional[ModelConfig] = None):
|
||||||
config = config or ModelConfig()
|
config = config or ModelConfig()
|
||||||
|
|
||||||
targets = build_targets(config=config.targets)
|
targets = build_targets(config=config.targets)
|
||||||
|
|
||||||
preprocessor = build_preprocessor(config=config.preprocess)
|
preprocessor = build_preprocessor(config=config.preprocess)
|
||||||
|
|
||||||
postprocessor = build_postprocessor(
|
postprocessor = build_postprocessor(
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
config=config.postprocess,
|
config=config.postprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
detector = build_detector(
|
detector = build_detector(
|
||||||
num_classes=len(targets.class_names),
|
num_classes=len(targets.class_names),
|
||||||
config=config.model,
|
config=config.model,
|
||||||
)
|
)
|
||||||
return Model(
|
return Model(
|
||||||
|
config=config,
|
||||||
detector=detector,
|
detector=detector,
|
||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
|
|||||||
@ -56,7 +56,7 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class SelfAttentionConfig(BaseConfig):
|
class SelfAttentionConfig(BaseConfig):
|
||||||
block_type: Literal["SelfAttention"] = "SelfAttention"
|
name: Literal["SelfAttention"] = "SelfAttention"
|
||||||
attention_channels: int
|
attention_channels: int
|
||||||
temperature: float = 1
|
temperature: float = 1
|
||||||
|
|
||||||
@ -178,7 +178,7 @@ class SelfAttention(nn.Module):
|
|||||||
class ConvConfig(BaseConfig):
|
class ConvConfig(BaseConfig):
|
||||||
"""Configuration for a basic ConvBlock."""
|
"""Configuration for a basic ConvBlock."""
|
||||||
|
|
||||||
block_type: Literal["ConvBlock"] = "ConvBlock"
|
name: Literal["ConvBlock"] = "ConvBlock"
|
||||||
"""Discriminator field indicating the block type."""
|
"""Discriminator field indicating the block type."""
|
||||||
|
|
||||||
out_channels: int
|
out_channels: int
|
||||||
@ -300,7 +300,7 @@ class VerticalConv(nn.Module):
|
|||||||
class FreqCoordConvDownConfig(BaseConfig):
|
class FreqCoordConvDownConfig(BaseConfig):
|
||||||
"""Configuration for a FreqCoordConvDownBlock."""
|
"""Configuration for a FreqCoordConvDownBlock."""
|
||||||
|
|
||||||
block_type: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
|
name: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
|
||||||
"""Discriminator field indicating the block type."""
|
"""Discriminator field indicating the block type."""
|
||||||
|
|
||||||
out_channels: int
|
out_channels: int
|
||||||
@ -390,7 +390,7 @@ class FreqCoordConvDownBlock(nn.Module):
|
|||||||
class StandardConvDownConfig(BaseConfig):
|
class StandardConvDownConfig(BaseConfig):
|
||||||
"""Configuration for a StandardConvDownBlock."""
|
"""Configuration for a StandardConvDownBlock."""
|
||||||
|
|
||||||
block_type: Literal["StandardConvDown"] = "StandardConvDown"
|
name: Literal["StandardConvDown"] = "StandardConvDown"
|
||||||
"""Discriminator field indicating the block type."""
|
"""Discriminator field indicating the block type."""
|
||||||
|
|
||||||
out_channels: int
|
out_channels: int
|
||||||
@ -460,7 +460,7 @@ class StandardConvDownBlock(nn.Module):
|
|||||||
class FreqCoordConvUpConfig(BaseConfig):
|
class FreqCoordConvUpConfig(BaseConfig):
|
||||||
"""Configuration for a FreqCoordConvUpBlock."""
|
"""Configuration for a FreqCoordConvUpBlock."""
|
||||||
|
|
||||||
block_type: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
|
name: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
|
||||||
"""Discriminator field indicating the block type."""
|
"""Discriminator field indicating the block type."""
|
||||||
|
|
||||||
out_channels: int
|
out_channels: int
|
||||||
@ -569,7 +569,7 @@ class FreqCoordConvUpBlock(nn.Module):
|
|||||||
class StandardConvUpConfig(BaseConfig):
|
class StandardConvUpConfig(BaseConfig):
|
||||||
"""Configuration for a StandardConvUpBlock."""
|
"""Configuration for a StandardConvUpBlock."""
|
||||||
|
|
||||||
block_type: Literal["StandardConvUp"] = "StandardConvUp"
|
name: Literal["StandardConvUp"] = "StandardConvUp"
|
||||||
"""Discriminator field indicating the block type."""
|
"""Discriminator field indicating the block type."""
|
||||||
|
|
||||||
out_channels: int
|
out_channels: int
|
||||||
@ -664,13 +664,13 @@ LayerConfig = Annotated[
|
|||||||
SelfAttentionConfig,
|
SelfAttentionConfig,
|
||||||
"LayerGroupConfig",
|
"LayerGroupConfig",
|
||||||
],
|
],
|
||||||
Field(discriminator="block_type"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configuration models."""
|
"""Type alias for the discriminated union of block configuration models."""
|
||||||
|
|
||||||
|
|
||||||
class LayerGroupConfig(BaseConfig):
|
class LayerGroupConfig(BaseConfig):
|
||||||
block_type: Literal["LayerGroup"] = "LayerGroup"
|
name: Literal["LayerGroup"] = "LayerGroup"
|
||||||
layers: List[LayerConfig]
|
layers: List[LayerConfig]
|
||||||
|
|
||||||
|
|
||||||
@ -686,7 +686,7 @@ def build_layer_from_config(
|
|||||||
parameters derived from the config and the current pipeline state
|
parameters derived from the config and the current pipeline state
|
||||||
(`input_height`, `in_channels`).
|
(`input_height`, `in_channels`).
|
||||||
|
|
||||||
It uses the `block_type` field within the `config` object to determine
|
It uses the `name` field within the `config` object to determine
|
||||||
which block class to instantiate.
|
which block class to instantiate.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -698,7 +698,7 @@ def build_layer_from_config(
|
|||||||
config : LayerConfig
|
config : LayerConfig
|
||||||
A Pydantic configuration object for the desired block (e.g., an
|
A Pydantic configuration object for the desired block (e.g., an
|
||||||
instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
|
instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
|
||||||
by its `block_type` field.
|
by its `name` field.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -711,11 +711,11 @@ def build_layer_from_config(
|
|||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
NotImplementedError
|
NotImplementedError
|
||||||
If the `config.block_type` does not correspond to a known block type.
|
If the `config.name` does not correspond to a known block type.
|
||||||
ValueError
|
ValueError
|
||||||
If parameters derived from the config are invalid for the block.
|
If parameters derived from the config are invalid for the block.
|
||||||
"""
|
"""
|
||||||
if config.block_type == "ConvBlock":
|
if config.name == "ConvBlock":
|
||||||
return (
|
return (
|
||||||
ConvBlock(
|
ConvBlock(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
@ -727,7 +727,7 @@ def build_layer_from_config(
|
|||||||
input_height,
|
input_height,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.block_type == "FreqCoordConvDown":
|
if config.name == "FreqCoordConvDown":
|
||||||
return (
|
return (
|
||||||
FreqCoordConvDownBlock(
|
FreqCoordConvDownBlock(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
@ -740,7 +740,7 @@ def build_layer_from_config(
|
|||||||
input_height // 2,
|
input_height // 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.block_type == "StandardConvDown":
|
if config.name == "StandardConvDown":
|
||||||
return (
|
return (
|
||||||
StandardConvDownBlock(
|
StandardConvDownBlock(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
@ -752,7 +752,7 @@ def build_layer_from_config(
|
|||||||
input_height // 2,
|
input_height // 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.block_type == "FreqCoordConvUp":
|
if config.name == "FreqCoordConvUp":
|
||||||
return (
|
return (
|
||||||
FreqCoordConvUpBlock(
|
FreqCoordConvUpBlock(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
@ -765,7 +765,7 @@ def build_layer_from_config(
|
|||||||
input_height * 2,
|
input_height * 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.block_type == "StandardConvUp":
|
if config.name == "StandardConvUp":
|
||||||
return (
|
return (
|
||||||
StandardConvUpBlock(
|
StandardConvUpBlock(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
@ -777,7 +777,7 @@ def build_layer_from_config(
|
|||||||
input_height * 2,
|
input_height * 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.block_type == "SelfAttention":
|
if config.name == "SelfAttention":
|
||||||
return (
|
return (
|
||||||
SelfAttention(
|
SelfAttention(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
@ -788,7 +788,7 @@ def build_layer_from_config(
|
|||||||
input_height,
|
input_height,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.block_type == "LayerGroup":
|
if config.name == "LayerGroup":
|
||||||
current_channels = in_channels
|
current_channels = in_channels
|
||||||
current_height = input_height
|
current_height = input_height
|
||||||
|
|
||||||
@ -804,4 +804,4 @@ def build_layer_from_config(
|
|||||||
|
|
||||||
return nn.Sequential(*blocks), current_channels, current_height
|
return nn.Sequential(*blocks), current_channels, current_height
|
||||||
|
|
||||||
raise NotImplementedError(f"Unknown block type {config.block_type}")
|
raise NotImplementedError(f"Unknown block type {config.name}")
|
||||||
|
|||||||
@ -128,7 +128,7 @@ class Bottleneck(nn.Module):
|
|||||||
|
|
||||||
BottleneckLayerConfig = Annotated[
|
BottleneckLayerConfig = Annotated[
|
||||||
Union[SelfAttentionConfig,],
|
Union[SelfAttentionConfig,],
|
||||||
Field(discriminator="block_type"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
||||||
|
|
||||||
|
|||||||
@ -47,7 +47,7 @@ DecoderLayerConfig = Annotated[
|
|||||||
StandardConvUpConfig,
|
StandardConvUpConfig,
|
||||||
LayerGroupConfig,
|
LayerGroupConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="block_type"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
||||||
|
|
||||||
@ -63,7 +63,7 @@ class DecoderConfig(BaseConfig):
|
|||||||
layers : List[DecoderLayerConfig]
|
layers : List[DecoderLayerConfig]
|
||||||
An ordered list of configuration objects, each defining one layer or
|
An ordered list of configuration objects, each defining one layer or
|
||||||
block in the decoder sequence. Each item must be a valid block
|
block in the decoder sequence. Each item must be a valid block
|
||||||
config including a `block_type` field and necessary parameters like
|
config including a `name` field and necessary parameters like
|
||||||
`out_channels`. Input channels for each layer are inferred sequentially.
|
`out_channels`. Input channels for each layer are inferred sequentially.
|
||||||
The list must contain at least one layer.
|
The list must contain at least one layer.
|
||||||
"""
|
"""
|
||||||
@ -249,9 +249,9 @@ def build_decoder(
|
|||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
If `in_channels` or `input_height` are not positive, or if the layer
|
If `in_channels` or `input_height` are not positive, or if the layer
|
||||||
configuration is invalid (e.g., empty list, unknown `block_type`).
|
configuration is invalid (e.g., empty list, unknown `name`).
|
||||||
NotImplementedError
|
NotImplementedError
|
||||||
If `build_layer_from_config` encounters an unknown `block_type`.
|
If `build_layer_from_config` encounters an unknown `name`.
|
||||||
"""
|
"""
|
||||||
config = config or DEFAULT_DECODER_CONFIG
|
config = config or DEFAULT_DECODER_CONFIG
|
||||||
|
|
||||||
|
|||||||
@ -49,7 +49,7 @@ EncoderLayerConfig = Annotated[
|
|||||||
StandardConvDownConfig,
|
StandardConvDownConfig,
|
||||||
LayerGroupConfig,
|
LayerGroupConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="block_type"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configs usable in Encoder."""
|
"""Type alias for the discriminated union of block configs usable in Encoder."""
|
||||||
|
|
||||||
@ -66,7 +66,7 @@ class EncoderConfig(BaseConfig):
|
|||||||
An ordered list of configuration objects, each defining one layer or
|
An ordered list of configuration objects, each defining one layer or
|
||||||
block in the encoder sequence. Each item must be a valid block config
|
block in the encoder sequence. Each item must be a valid block config
|
||||||
(e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
|
(e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
|
||||||
`StandardConvDownConfig`) including a `block_type` field and necessary
|
`StandardConvDownConfig`) including a `name` field and necessary
|
||||||
parameters like `out_channels`. Input channels for each layer are
|
parameters like `out_channels`. Input channels for each layer are
|
||||||
inferred sequentially. The list must contain at least one layer.
|
inferred sequentially. The list must contain at least one layer.
|
||||||
"""
|
"""
|
||||||
@ -287,9 +287,9 @@ def build_encoder(
|
|||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
If `in_channels` or `input_height` are not positive, or if the layer
|
If `in_channels` or `input_height` are not positive, or if the layer
|
||||||
configuration is invalid (e.g., empty list, unknown `block_type`).
|
configuration is invalid (e.g., empty list, unknown `name`).
|
||||||
NotImplementedError
|
NotImplementedError
|
||||||
If `build_layer_from_config` encounters an unknown `block_type`.
|
If `build_layer_from_config` encounters an unknown `name`.
|
||||||
"""
|
"""
|
||||||
if in_channels <= 0 or input_height <= 0:
|
if in_channels <= 0 or input_height <= 0:
|
||||||
raise ValueError("in_channels and input_height must be positive.")
|
raise ValueError("in_channels and input_height must be positive.")
|
||||||
|
|||||||
@ -1,6 +1,14 @@
|
|||||||
"""Computes spectrograms from audio waveforms with configurable parameters."""
|
"""Computes spectrograms from audio waveforms with configurable parameters."""
|
||||||
|
|
||||||
from typing import Annotated, Callable, List, Literal, Optional, Union
|
from typing import (
|
||||||
|
Annotated,
|
||||||
|
Callable,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -306,7 +314,7 @@ class SpectrogramConfig(BaseConfig):
|
|||||||
stft: STFTConfig = Field(default_factory=STFTConfig)
|
stft: STFTConfig = Field(default_factory=STFTConfig)
|
||||||
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
|
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
|
||||||
size: ResizeConfig = Field(default_factory=ResizeConfig)
|
size: ResizeConfig = Field(default_factory=ResizeConfig)
|
||||||
transforms: List[SpectrogramTransform] = Field(
|
transforms: Sequence[SpectrogramTransform] = Field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
PcenConfig(),
|
PcenConfig(),
|
||||||
SpectralMeanSubstractionConfig(),
|
SpectralMeanSubstractionConfig(),
|
||||||
|
|||||||
@ -1,53 +1,26 @@
|
|||||||
"""Main entry point for the BatDetect2 Target Definition subsystem.
|
"""BatDetect2 Target Definition system."""
|
||||||
|
|
||||||
This package (`batdetect2.targets`) provides the tools and configurations
|
|
||||||
necessary to define precisely what the BatDetect2 model should learn to detect,
|
|
||||||
classify, and localize from audio data. It involves several conceptual steps,
|
|
||||||
managed through configuration files and culminating in an executable pipeline:
|
|
||||||
|
|
||||||
1. **Terms (`.terms`)**: Defining vocabulary for annotation tags.
|
|
||||||
2. **Filtering (`.filtering`)**: Selecting relevant sound event annotations.
|
|
||||||
3. **Transformation (`.transform`)**: Modifying tags (standardization,
|
|
||||||
derivation).
|
|
||||||
4. **ROI Mapping (`.roi`)**: Defining how annotation geometry (ROIs) maps to
|
|
||||||
target position and size representations, and back.
|
|
||||||
5. **Class Definition (`.classes`)**: Mapping tags to target class names
|
|
||||||
(encoding) and mapping predicted names back to tags (decoding).
|
|
||||||
|
|
||||||
This module exposes the key components for users to configure and utilize this
|
|
||||||
target definition pipeline, primarily through the `TargetConfig` data structure
|
|
||||||
and the `Targets` class (implementing `TargetProtocol`), which encapsulates the
|
|
||||||
configured processing steps. The main way to create a functional `Targets`
|
|
||||||
object is via the `build_targets` or `load_targets` functions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
from collections import Counter
|
||||||
from typing import Iterable, List, Optional, Tuple
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import Field
|
from pydantic import Field, field_validator
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.data.conditions import (
|
||||||
|
SoundEventCondition,
|
||||||
|
build_sound_event_condition,
|
||||||
|
)
|
||||||
from batdetect2.targets.classes import (
|
from batdetect2.targets.classes import (
|
||||||
ClassesConfig,
|
DEFAULT_CLASSES,
|
||||||
|
DEFAULT_GENERIC_CLASS,
|
||||||
SoundEventDecoder,
|
SoundEventDecoder,
|
||||||
SoundEventEncoder,
|
SoundEventEncoder,
|
||||||
TargetClass,
|
TargetClassConfig,
|
||||||
build_generic_class_tags,
|
|
||||||
build_sound_event_decoder,
|
build_sound_event_decoder,
|
||||||
build_sound_event_encoder,
|
build_sound_event_encoder,
|
||||||
get_class_names_from_config,
|
get_class_names_from_config,
|
||||||
load_classes_config,
|
|
||||||
load_decoder_from_config,
|
|
||||||
load_encoder_from_config,
|
|
||||||
)
|
|
||||||
from batdetect2.targets.filtering import (
|
|
||||||
FilterConfig,
|
|
||||||
FilterRule,
|
|
||||||
SoundEventFilter,
|
|
||||||
build_sound_event_filter,
|
|
||||||
load_filter_config,
|
|
||||||
load_filter_from_config,
|
|
||||||
)
|
)
|
||||||
from batdetect2.targets.rois import (
|
from batdetect2.targets.rois import (
|
||||||
AnchorBBoxMapperConfig,
|
AnchorBBoxMapperConfig,
|
||||||
@ -55,114 +28,53 @@ from batdetect2.targets.rois import (
|
|||||||
ROITargetMapper,
|
ROITargetMapper,
|
||||||
build_roi_mapper,
|
build_roi_mapper,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.terms import (
|
from batdetect2.targets.terms import call_type, individual
|
||||||
TagInfo,
|
|
||||||
TermInfo,
|
|
||||||
TermRegistry,
|
|
||||||
call_type,
|
|
||||||
default_term_registry,
|
|
||||||
get_tag_from_info,
|
|
||||||
get_term_from_key,
|
|
||||||
individual,
|
|
||||||
register_term,
|
|
||||||
)
|
|
||||||
from batdetect2.targets.transform import (
|
|
||||||
DerivationRegistry,
|
|
||||||
DeriveTagRule,
|
|
||||||
MapValueRule,
|
|
||||||
ReplaceRule,
|
|
||||||
SoundEventTransformation,
|
|
||||||
TransformConfig,
|
|
||||||
build_transformation_from_config,
|
|
||||||
default_derivation_registry,
|
|
||||||
get_derivation,
|
|
||||||
load_transformation_config,
|
|
||||||
load_transformation_from_config,
|
|
||||||
register_derivation,
|
|
||||||
)
|
|
||||||
from batdetect2.typing.targets import Position, Size, TargetProtocol
|
from batdetect2.typing.targets import Position, Size, TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ClassesConfig",
|
|
||||||
"DEFAULT_TARGET_CONFIG",
|
"DEFAULT_TARGET_CONFIG",
|
||||||
"DeriveTagRule",
|
|
||||||
"FilterConfig",
|
|
||||||
"FilterRule",
|
|
||||||
"MapValueRule",
|
|
||||||
"AnchorBBoxMapperConfig",
|
"AnchorBBoxMapperConfig",
|
||||||
"ROITargetMapper",
|
"ROITargetMapper",
|
||||||
"ReplaceRule",
|
|
||||||
"SoundEventDecoder",
|
"SoundEventDecoder",
|
||||||
"SoundEventEncoder",
|
"SoundEventEncoder",
|
||||||
"SoundEventFilter",
|
"TargetClassConfig",
|
||||||
"SoundEventTransformation",
|
|
||||||
"TagInfo",
|
|
||||||
"TargetClass",
|
|
||||||
"TargetConfig",
|
"TargetConfig",
|
||||||
"Targets",
|
"Targets",
|
||||||
"TermInfo",
|
|
||||||
"TransformConfig",
|
|
||||||
"build_generic_class_tags",
|
|
||||||
"build_roi_mapper",
|
"build_roi_mapper",
|
||||||
"build_sound_event_decoder",
|
"build_sound_event_decoder",
|
||||||
"build_sound_event_encoder",
|
"build_sound_event_encoder",
|
||||||
"build_sound_event_filter",
|
|
||||||
"build_transformation_from_config",
|
|
||||||
"call_type",
|
"call_type",
|
||||||
"get_class_names_from_config",
|
"get_class_names_from_config",
|
||||||
"get_derivation",
|
|
||||||
"get_tag_from_info",
|
|
||||||
"get_term_from_key",
|
|
||||||
"individual",
|
"individual",
|
||||||
"load_classes_config",
|
|
||||||
"load_decoder_from_config",
|
|
||||||
"load_encoder_from_config",
|
|
||||||
"load_filter_config",
|
|
||||||
"load_filter_from_config",
|
|
||||||
"load_target_config",
|
"load_target_config",
|
||||||
"load_transformation_config",
|
|
||||||
"load_transformation_from_config",
|
|
||||||
"register_derivation",
|
|
||||||
"register_term",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class TargetConfig(BaseConfig):
|
class TargetConfig(BaseConfig):
|
||||||
"""Unified configuration for the entire target definition pipeline.
|
detection_target: TargetClassConfig = Field(default=DEFAULT_GENERIC_CLASS)
|
||||||
|
|
||||||
This model aggregates the configurations for semantic processing (filtering,
|
classification_targets: List[TargetClassConfig] = Field(
|
||||||
transformation, class definition) and geometric processing (ROI mapping).
|
default_factory=lambda: DEFAULT_CLASSES
|
||||||
It serves as the primary input for building a complete `Targets` object
|
|
||||||
via `build_targets` or `load_targets`.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
filtering : FilterConfig, optional
|
|
||||||
Configuration for filtering sound event annotations based on tags.
|
|
||||||
If None or omitted, no filtering is applied.
|
|
||||||
transforms : TransformConfig, optional
|
|
||||||
Configuration for transforming annotation tags
|
|
||||||
(mapping, derivation, etc.). If None or omitted, no tag transformations
|
|
||||||
are applied.
|
|
||||||
classes : ClassesConfig
|
|
||||||
Configuration defining the specific target classes, their tag matching
|
|
||||||
rules for encoding, their representative tags for decoding
|
|
||||||
(`output_tags`), and the definition of the generic class tags.
|
|
||||||
This section is mandatory.
|
|
||||||
roi : ROIConfig, optional
|
|
||||||
Configuration defining how geometric ROIs (e.g., bounding boxes) are
|
|
||||||
mapped to target representations (reference point, scaled size).
|
|
||||||
Controls `position`, `time_scale`, `frequency_scale`. If None or
|
|
||||||
omitted, default ROI mapping settings are used.
|
|
||||||
"""
|
|
||||||
|
|
||||||
filtering: FilterConfig = Field(default_factory=FilterConfig)
|
|
||||||
transforms: TransformConfig = Field(default_factory=TransformConfig)
|
|
||||||
classes: ClassesConfig = Field(
|
|
||||||
default_factory=lambda: DEFAULT_CLASSES_CONFIG
|
|
||||||
)
|
)
|
||||||
|
|
||||||
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
|
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
|
||||||
|
|
||||||
|
@field_validator("classification_targets")
|
||||||
|
def check_unique_class_names(cls, v: List[TargetClassConfig]):
|
||||||
|
"""Ensure all defined class names are unique."""
|
||||||
|
names = [c.name for c in v]
|
||||||
|
|
||||||
|
if len(names) != len(set(names)):
|
||||||
|
name_counts = Counter(names)
|
||||||
|
duplicates = [
|
||||||
|
name for name, count in name_counts.items() if count > 1
|
||||||
|
]
|
||||||
|
raise ValueError(
|
||||||
|
"Class names must be unique. Found duplicates: "
|
||||||
|
f"{', '.join(duplicates)}"
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
def load_target_config(
|
def load_target_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
@ -238,8 +150,7 @@ class Targets(TargetProtocol):
|
|||||||
roi_mapper: ROITargetMapper,
|
roi_mapper: ROITargetMapper,
|
||||||
class_names: list[str],
|
class_names: list[str],
|
||||||
generic_class_tags: List[data.Tag],
|
generic_class_tags: List[data.Tag],
|
||||||
filter_fn: Optional[SoundEventFilter] = None,
|
filter_fn: Optional[SoundEventCondition] = None,
|
||||||
transform_fn: Optional[SoundEventTransformation] = None,
|
|
||||||
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
|
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
|
||||||
):
|
):
|
||||||
"""Initialize the Targets object.
|
"""Initialize the Targets object.
|
||||||
@ -272,7 +183,6 @@ class Targets(TargetProtocol):
|
|||||||
self._filter_fn = filter_fn
|
self._filter_fn = filter_fn
|
||||||
self._encode_fn = encode_fn
|
self._encode_fn = encode_fn
|
||||||
self._decode_fn = decode_fn
|
self._decode_fn = decode_fn
|
||||||
self._transform_fn = transform_fn
|
|
||||||
self._roi_mapper_overrides = roi_mapper_overrides or {}
|
self._roi_mapper_overrides = roi_mapper_overrides or {}
|
||||||
|
|
||||||
for class_name in self._roi_mapper_overrides:
|
for class_name in self._roi_mapper_overrides:
|
||||||
@ -344,27 +254,6 @@ class Targets(TargetProtocol):
|
|||||||
"""
|
"""
|
||||||
return self._decode_fn(class_label)
|
return self._decode_fn(class_label)
|
||||||
|
|
||||||
def transform(
|
|
||||||
self, sound_event: data.SoundEventAnnotation
|
|
||||||
) -> data.SoundEventAnnotation:
|
|
||||||
"""Apply the configured tag transformations to an annotation.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event : data.SoundEventAnnotation
|
|
||||||
The annotation whose tags should be transformed.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
data.SoundEventAnnotation
|
|
||||||
A new annotation object with the transformed tags. If no
|
|
||||||
transformations were configured, the original annotation object is
|
|
||||||
returned.
|
|
||||||
"""
|
|
||||||
if self._transform_fn:
|
|
||||||
return self._transform_fn(sound_event)
|
|
||||||
return sound_event
|
|
||||||
|
|
||||||
def encode_roi(
|
def encode_roi(
|
||||||
self, sound_event: data.SoundEventAnnotation
|
self, sound_event: data.SoundEventAnnotation
|
||||||
) -> tuple[Position, Size]:
|
) -> tuple[Position, Size]:
|
||||||
@ -430,113 +319,14 @@ class Targets(TargetProtocol):
|
|||||||
return self._roi_mapper.decode(position, size)
|
return self._roi_mapper.decode(position, size)
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CLASSES = [
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Myotis mystacinus")],
|
|
||||||
name="myomys",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Myotis alcathoe")],
|
|
||||||
name="myoalc",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Eptesicus serotinus")],
|
|
||||||
name="eptser",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Pipistrellus nathusii")],
|
|
||||||
name="pipnat",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Barbastellus barbastellus")],
|
|
||||||
name="barbar",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Myotis nattereri")],
|
|
||||||
name="myonat",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Myotis daubentonii")],
|
|
||||||
name="myodau",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Myotis brandtii")],
|
|
||||||
name="myobra",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Pipistrellus pipistrellus")],
|
|
||||||
name="pippip",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Myotis bechsteinii")],
|
|
||||||
name="myobec",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Pipistrellus pygmaeus")],
|
|
||||||
name="pippyg",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Rhinolophus hipposideros")],
|
|
||||||
name="rhihip",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Nyctalus leisleri")],
|
|
||||||
name="nyclei",
|
|
||||||
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
|
|
||||||
name="rhifer",
|
|
||||||
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Plecotus auritus")],
|
|
||||||
name="pleaur",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Nyctalus noctula")],
|
|
||||||
name="nycnoc",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Plecotus austriacus")],
|
|
||||||
name="pleaus",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CLASSES_CONFIG: ClassesConfig = ClassesConfig(
|
|
||||||
classes=DEFAULT_CLASSES,
|
|
||||||
generic_class=[TagInfo(value="Bat")],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
||||||
filtering=FilterConfig(
|
classification_targets=DEFAULT_CLASSES,
|
||||||
rules=[
|
detection_target=DEFAULT_GENERIC_CLASS,
|
||||||
FilterRule(
|
|
||||||
match_type="all",
|
|
||||||
tags=[TagInfo(key="event", value="Echolocation")],
|
|
||||||
),
|
|
||||||
FilterRule(
|
|
||||||
match_type="exclude",
|
|
||||||
tags=[
|
|
||||||
TagInfo(key="event", value="Feeding"),
|
|
||||||
TagInfo(key="event", value="Unknown"),
|
|
||||||
TagInfo(key="event", value="Not Bat"),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
classes=DEFAULT_CLASSES_CONFIG,
|
|
||||||
roi=AnchorBBoxMapperConfig(),
|
roi=AnchorBBoxMapperConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_targets(
|
def build_targets(config: Optional[TargetConfig] = None) -> Targets:
|
||||||
config: Optional[TargetConfig] = None,
|
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
derivation_registry: DerivationRegistry = default_derivation_registry,
|
|
||||||
) -> Targets:
|
|
||||||
"""Build a Targets object from a loaded TargetConfig.
|
"""Build a Targets object from a loaded TargetConfig.
|
||||||
|
|
||||||
This factory function takes the unified `TargetConfig` and constructs all
|
This factory function takes the unified `TargetConfig` and constructs all
|
||||||
@ -550,13 +340,6 @@ def build_targets(
|
|||||||
----------
|
----------
|
||||||
config : TargetConfig
|
config : TargetConfig
|
||||||
The loaded and validated unified target configuration object.
|
The loaded and validated unified target configuration object.
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance to use for resolving term keys. Defaults
|
|
||||||
to the global `batdetect2.targets.terms.term_registry`.
|
|
||||||
derivation_registry : DerivationRegistry, optional
|
|
||||||
The DerivationRegistry instance to use for resolving derivation
|
|
||||||
function names. Defaults to the global
|
|
||||||
`batdetect2.targets.transform.derivation_registry`.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -577,40 +360,18 @@ def build_targets(
|
|||||||
lambda: config.to_yaml_string(),
|
lambda: config.to_yaml_string(),
|
||||||
)
|
)
|
||||||
|
|
||||||
filter_fn = (
|
filter_fn = build_sound_event_condition(config.detection_target.match_if)
|
||||||
build_sound_event_filter(
|
encode_fn = build_sound_event_encoder(config.classification_targets)
|
||||||
config.filtering,
|
decode_fn = build_sound_event_decoder(config.classification_targets)
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
if config.filtering
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
encode_fn = build_sound_event_encoder(
|
|
||||||
config.classes,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
decode_fn = build_sound_event_decoder(
|
|
||||||
config.classes,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
transform_fn = (
|
|
||||||
build_transformation_from_config(
|
|
||||||
config.transforms,
|
|
||||||
term_registry=term_registry,
|
|
||||||
derivation_registry=derivation_registry,
|
|
||||||
)
|
|
||||||
if config.transforms
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
roi_mapper = build_roi_mapper(config.roi)
|
roi_mapper = build_roi_mapper(config.roi)
|
||||||
class_names = get_class_names_from_config(config.classes)
|
class_names = get_class_names_from_config(config.classification_targets)
|
||||||
generic_class_tags = build_generic_class_tags(
|
|
||||||
config.classes,
|
generic_class_tags = config.detection_target.assign_tags
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
roi_overrides = {
|
roi_overrides = {
|
||||||
class_config.name: build_roi_mapper(class_config.roi)
|
class_config.name: build_roi_mapper(class_config.roi)
|
||||||
for class_config in config.classes.classes
|
for class_config in config.classification_targets
|
||||||
if class_config.roi is not None
|
if class_config.roi is not None
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -621,7 +382,6 @@ def build_targets(
|
|||||||
class_names=class_names,
|
class_names=class_names,
|
||||||
roi_mapper=roi_mapper,
|
roi_mapper=roi_mapper,
|
||||||
generic_class_tags=generic_class_tags,
|
generic_class_tags=generic_class_tags,
|
||||||
transform_fn=transform_fn,
|
|
||||||
roi_mapper_overrides=roi_overrides,
|
roi_mapper_overrides=roi_overrides,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -629,8 +389,6 @@ def build_targets(
|
|||||||
def load_targets(
|
def load_targets(
|
||||||
config_path: data.PathLike,
|
config_path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
derivation_registry: DerivationRegistry = default_derivation_registry,
|
|
||||||
) -> Targets:
|
) -> Targets:
|
||||||
"""Load a Targets object directly from a configuration file.
|
"""Load a Targets object directly from a configuration file.
|
||||||
|
|
||||||
@ -645,11 +403,6 @@ def load_targets(
|
|||||||
field : str, optional
|
field : str, optional
|
||||||
Dot-separated path to a nested section within the file containing
|
Dot-separated path to a nested section within the file containing
|
||||||
the target configuration. If None, the entire file content is used.
|
the target configuration. If None, the entire file content is used.
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance to use. Defaults to the global default.
|
|
||||||
derivation_registry : DerivationRegistry, optional
|
|
||||||
The DerivationRegistry instance to use. Defaults to the global
|
|
||||||
default.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -670,11 +423,7 @@ def load_targets(
|
|||||||
config_path,
|
config_path,
|
||||||
field=field,
|
field=field,
|
||||||
)
|
)
|
||||||
return build_targets(
|
return build_targets(config)
|
||||||
config,
|
|
||||||
term_registry=term_registry,
|
|
||||||
derivation_registry=derivation_registry,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def iterate_encoded_sound_events(
|
def iterate_encoded_sound_events(
|
||||||
@ -690,8 +439,6 @@ def iterate_encoded_sound_events(
|
|||||||
if geometry is None:
|
if geometry is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
sound_event = targets.transform(sound_event)
|
|
||||||
|
|
||||||
class_name = targets.encode_class(sound_event)
|
class_name = targets.encode_class(sound_event)
|
||||||
position, size = targets.encode_roi(sound_event)
|
position, size = targets.encode_roi(sound_event)
|
||||||
|
|
||||||
|
|||||||
@ -1,253 +1,172 @@
|
|||||||
from collections import Counter
|
from typing import Dict, List, Optional
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, Dict, List, Literal, Optional, Set, Tuple
|
|
||||||
|
|
||||||
from pydantic import Field, field_validator
|
from pydantic import Field, PrivateAttr, computed_field, model_validator
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.targets.rois import ROIMapperConfig
|
from batdetect2.data.conditions import (
|
||||||
from batdetect2.targets.terms import (
|
AllOfConfig,
|
||||||
GENERIC_CLASS_KEY,
|
HasAllTagsConfig,
|
||||||
TagInfo,
|
HasAnyTagConfig,
|
||||||
TermRegistry,
|
HasTagConfig,
|
||||||
default_term_registry,
|
NotConfig,
|
||||||
get_tag_from_info,
|
SoundEventCondition,
|
||||||
|
SoundEventConditionConfig,
|
||||||
|
build_sound_event_condition,
|
||||||
)
|
)
|
||||||
|
from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig
|
||||||
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
|
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DEFAULT_SPECIES_LIST",
|
|
||||||
"build_generic_class_tags",
|
|
||||||
"build_sound_event_decoder",
|
"build_sound_event_decoder",
|
||||||
"build_sound_event_encoder",
|
"build_sound_event_encoder",
|
||||||
"get_class_names_from_config",
|
"get_class_names_from_config",
|
||||||
"load_classes_config",
|
|
||||||
"load_decoder_from_config",
|
|
||||||
"load_encoder_from_config",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_SPECIES_LIST = [
|
class TargetClassConfig(BaseConfig):
|
||||||
"Barbastella barbastellus",
|
"""Defines a target class of sound events."""
|
||||||
"Eptesicus serotinus",
|
|
||||||
"Myotis alcathoe",
|
|
||||||
"Myotis bechsteinii",
|
|
||||||
"Myotis brandtii",
|
|
||||||
"Myotis daubentonii",
|
|
||||||
"Myotis mystacinus",
|
|
||||||
"Myotis nattereri",
|
|
||||||
"Nyctalus leisleri",
|
|
||||||
"Nyctalus noctula",
|
|
||||||
"Pipistrellus nathusii",
|
|
||||||
"Pipistrellus pipistrellus",
|
|
||||||
"Pipistrellus pygmaeus",
|
|
||||||
"Plecotus auritus",
|
|
||||||
"Plecotus austriacus",
|
|
||||||
"Rhinolophus ferrumequinum",
|
|
||||||
"Rhinolophus hipposideros",
|
|
||||||
]
|
|
||||||
"""A default list of common bat species names found in the UK."""
|
|
||||||
|
|
||||||
|
|
||||||
class TargetClass(BaseConfig):
|
|
||||||
"""Defines criteria for encoding annotations and decoding predictions.
|
|
||||||
|
|
||||||
Each instance represents one potential output class for the classification
|
|
||||||
model. It specifies:
|
|
||||||
1. A unique `name` for the class.
|
|
||||||
2. The tag conditions (`tags` and `match_type`) an annotation must meet to
|
|
||||||
be assigned this class name during training data preparation (encoding).
|
|
||||||
3. An optional, alternative set of tags (`output_tags`) to be used when
|
|
||||||
converting a model's prediction of this class name back into annotation
|
|
||||||
tags (decoding).
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
name : str
|
|
||||||
The unique name assigned to this target class (e.g., 'pippip',
|
|
||||||
'myodau', 'noise'). This name is used as the label during model
|
|
||||||
training and is the expected output from the model's prediction.
|
|
||||||
Should be unique across all TargetClass definitions in a configuration.
|
|
||||||
tags : List[TagInfo]
|
|
||||||
A list of one or more tags (defined using `TagInfo`) used to identify
|
|
||||||
if an existing annotation belongs to this class during encoding (data
|
|
||||||
preparation for training). The `match_type` attribute determines how
|
|
||||||
these tags are evaluated.
|
|
||||||
match_type : Literal["all", "any"], default="all"
|
|
||||||
Determines how the `tags` list is evaluated during encoding:
|
|
||||||
- "all": The annotation must have *all* the tags listed to match.
|
|
||||||
- "any": The annotation must have *at least one* of the tags listed
|
|
||||||
to match.
|
|
||||||
output_tags: Optional[List[TagInfo]], default=None
|
|
||||||
An optional list of tags (defined using `TagInfo`) to be assigned to a
|
|
||||||
new annotation when the model predicts this class `name`. If `None`
|
|
||||||
(default), the tags listed in the `tags` field will be used for
|
|
||||||
decoding. If provided, this list overrides the `tags` field for the
|
|
||||||
purpose of decoding predictions back into meaningful annotation tags.
|
|
||||||
This allows, for example, training on broader categories but decoding
|
|
||||||
to more specific representative tags.
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
tags: List[TagInfo] = Field(min_length=1)
|
|
||||||
match_type: Literal["all", "any"] = Field(default="all")
|
|
||||||
output_tags: Optional[List[TagInfo]] = None
|
|
||||||
roi: Optional[ROIMapperConfig] = None
|
|
||||||
|
|
||||||
|
condition_input: Optional[SoundEventConditionConfig] = Field(
|
||||||
def _get_default_classes() -> List[TargetClass]:
|
alias="match_if",
|
||||||
"""Generate a list of default target classes.
|
default=None,
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[TargetClass]
|
|
||||||
A list of TargetClass objects, one for each species in
|
|
||||||
DEFAULT_SPECIES_LIST. The class names are simplified versions of the
|
|
||||||
species names.
|
|
||||||
"""
|
|
||||||
return [
|
|
||||||
TargetClass(
|
|
||||||
name=_get_default_class_name(value),
|
|
||||||
tags=[TagInfo(key=GENERIC_CLASS_KEY, value=value)],
|
|
||||||
)
|
|
||||||
for value in DEFAULT_SPECIES_LIST
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _get_default_class_name(species: str) -> str:
|
|
||||||
"""Generate a default class name from a species name.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
species : str
|
|
||||||
The species name (e.g., "Myotis daubentonii").
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
str
|
|
||||||
A simplified class name (e.g., "myodau").
|
|
||||||
The genus and species names are converted to lowercase,
|
|
||||||
the first three letters of each are taken, and concatenated.
|
|
||||||
"""
|
|
||||||
genus, species = species.strip().split(" ")
|
|
||||||
return f"{genus.lower()[:3]}{species.lower()[:3]}"
|
|
||||||
|
|
||||||
|
|
||||||
def _get_default_generic_class() -> List[TagInfo]:
|
|
||||||
"""Generate the default list of TagInfo objects for the generic class.
|
|
||||||
|
|
||||||
Provides a default set of tags used to represent the generic "Bat" category
|
|
||||||
when decoding predictions that didn't match a specific class.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[TagInfo]
|
|
||||||
A list containing default TagInfo objects, typically representing
|
|
||||||
`call_type: Echolocation` and `order: Chiroptera`.
|
|
||||||
"""
|
|
||||||
return [
|
|
||||||
TagInfo(key="call_type", value="Echolocation"),
|
|
||||||
TagInfo(key="order", value="Chiroptera"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class ClassesConfig(BaseConfig):
|
|
||||||
"""Configuration defining target classes and the generic fallback category.
|
|
||||||
|
|
||||||
Holds the ordered list of specific target class definitions (`TargetClass`)
|
|
||||||
and defines the tags representing the generic category for sounds that pass
|
|
||||||
filtering but do not match any specific class.
|
|
||||||
|
|
||||||
The order of `TargetClass` objects in the `classes` list defines the
|
|
||||||
priority for classification during encoding. The system checks annotations
|
|
||||||
against these definitions sequentially and assigns the name of the *first*
|
|
||||||
matching class.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
classes : List[TargetClass]
|
|
||||||
An ordered list of specific target class definitions. The order
|
|
||||||
determines matching priority (first match wins). Defaults to a
|
|
||||||
standard set of classes via `get_default_classes`.
|
|
||||||
generic_class : List[TagInfo]
|
|
||||||
A list of tags defining the "generic" or "unclassified but relevant"
|
|
||||||
category (e.g., representing a generic 'Bat' call that wasn't
|
|
||||||
assigned to a specific species). These tags are typically assigned
|
|
||||||
during decoding when a sound event was detected and passed filtering
|
|
||||||
but did not match any specific class rule defined in the `classes` list.
|
|
||||||
Defaults to a standard set of tags via `get_default_generic_class`.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If validation fails (e.g., non-unique class names in the `classes`
|
|
||||||
list).
|
|
||||||
|
|
||||||
Notes
|
|
||||||
-----
|
|
||||||
- It is crucial that the `name` attribute of each `TargetClass` in the
|
|
||||||
`classes` list is unique. This configuration includes a validator to
|
|
||||||
enforce this uniqueness.
|
|
||||||
- The `generic_class` tags provide a baseline identity for relevant sounds
|
|
||||||
that don't fit into more specific defined categories.
|
|
||||||
"""
|
|
||||||
|
|
||||||
classes: List[TargetClass] = Field(default_factory=_get_default_classes)
|
|
||||||
|
|
||||||
generic_class: List[TagInfo] = Field(
|
|
||||||
default_factory=_get_default_generic_class
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@field_validator("classes")
|
tags: Optional[List[data.Tag]] = Field(default=None, exclude=True)
|
||||||
def check_unique_class_names(cls, v: List[TargetClass]):
|
|
||||||
"""Ensure all defined class names are unique."""
|
|
||||||
names = [c.name for c in v]
|
|
||||||
|
|
||||||
if len(names) != len(set(names)):
|
assign_tags: List[data.Tag] = Field(default_factory=list)
|
||||||
name_counts = Counter(names)
|
|
||||||
duplicates = [
|
roi: Optional[ROIMapperConfig] = None
|
||||||
name for name, count in name_counts.items() if count > 1
|
|
||||||
]
|
_match_if: SoundEventConditionConfig = PrivateAttr()
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _process_tags(self) -> "TargetClassConfig":
|
||||||
|
if self.tags and self.condition_input:
|
||||||
|
raise ValueError("Use either 'tags' or 'match_if', not both.")
|
||||||
|
|
||||||
|
if self.condition_input is not None:
|
||||||
|
self._match_if = self.condition_input
|
||||||
|
return self
|
||||||
|
|
||||||
|
if self.tags is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Class names must be unique. Found duplicates: "
|
f"Class '{self.name}' must have a 'tags' or 'match_if' rule."
|
||||||
f"{', '.join(duplicates)}"
|
|
||||||
)
|
)
|
||||||
return v
|
|
||||||
|
self._match_if = HasAllTagsConfig(tags=self.tags)
|
||||||
|
|
||||||
|
if not self.assign_tags:
|
||||||
|
self.assign_tags = self.tags
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def match_if(self) -> SoundEventConditionConfig:
|
||||||
|
return self._match_if
|
||||||
|
|
||||||
|
|
||||||
def is_target_class(
|
DEFAULT_GENERIC_CLASS = TargetClassConfig(
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
name="bat",
|
||||||
tags: Set[data.Tag],
|
match_if=AllOfConfig(
|
||||||
match_all: bool = True,
|
conditions=[
|
||||||
) -> bool:
|
HasTagConfig(tag=data.Tag(key="event", value="Echolocation")),
|
||||||
"""Check if a sound event annotation matches a set of required tags.
|
NotConfig(
|
||||||
|
condition=HasAnyTagConfig(
|
||||||
Parameters
|
tags=[
|
||||||
----------
|
data.Tag(key="event", value="Feeding"),
|
||||||
sound_event_annotation : data.SoundEventAnnotation
|
data.Tag(key="event", value="Unknown"),
|
||||||
The annotation to check.
|
data.Tag(key="event", value="Not Bat"),
|
||||||
required_tags : Set[data.Tag]
|
]
|
||||||
A set of `soundevent.data.Tag` objects that define the class criteria.
|
)
|
||||||
match_all : bool, default=True
|
),
|
||||||
If True, checks if *all* `required_tags` are present in the
|
]
|
||||||
annotation's tags (subset check). If False, checks if *at least one*
|
),
|
||||||
of the `required_tags` is present (intersection check).
|
assign_tags=[
|
||||||
|
data.Tag(key="call_type", value="Echolocation"),
|
||||||
Returns
|
data.Tag(key="order", value="Chiroptera"),
|
||||||
-------
|
],
|
||||||
bool
|
)
|
||||||
True if the annotation meets the tag criteria, False otherwise.
|
|
||||||
"""
|
|
||||||
annotation_tags = set(sound_event_annotation.tags)
|
|
||||||
|
|
||||||
if match_all:
|
|
||||||
return tags <= annotation_tags
|
|
||||||
|
|
||||||
return bool(tags & annotation_tags)
|
|
||||||
|
|
||||||
|
|
||||||
def get_class_names_from_config(config: ClassesConfig) -> List[str]:
|
DEFAULT_CLASSES = [
|
||||||
|
TargetClassConfig(
|
||||||
|
name="myomys",
|
||||||
|
tags=[data.Tag(key="class", value="Myotis mystacinus")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="myoalc",
|
||||||
|
tags=[data.Tag(key="class", value="Myotis alcathoe")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="eptser",
|
||||||
|
tags=[data.Tag(key="class", value="Eptesicus serotinus")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="pipnat",
|
||||||
|
tags=[data.Tag(key="class", value="Pipistrellus nathusii")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="barbar",
|
||||||
|
tags=[data.Tag(key="class", value="Barbastellus barbastellus")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="myonat",
|
||||||
|
tags=[data.Tag(key="class", value="Myotis nattereri")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="myodau",
|
||||||
|
tags=[data.Tag(key="class", value="Myotis daubentonii")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="myobra",
|
||||||
|
tags=[data.Tag(key="class", value="Myotis brandtii")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="pippip",
|
||||||
|
tags=[data.Tag(key="class", value="Pipistrellus pipistrellus")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="myobec",
|
||||||
|
tags=[data.Tag(key="class", value="Myotis bechsteinii")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="pippyg",
|
||||||
|
tags=[data.Tag(key="class", value="Pipistrellus pygmaeus")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="rhihip",
|
||||||
|
tags=[data.Tag(key="class", value="Rhinolophus hipposideros")],
|
||||||
|
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="nyclei",
|
||||||
|
tags=[data.Tag(key="class", value="Nyctalus leisleri")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="rhifer",
|
||||||
|
tags=[data.Tag(key="class", value="Rhinolophus ferrumequinum")],
|
||||||
|
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="pleaur",
|
||||||
|
tags=[data.Tag(key="class", value="Plecotus auritus")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="nycnoc",
|
||||||
|
tags=[data.Tag(key="class", value="Nyctalus noctula")],
|
||||||
|
),
|
||||||
|
TargetClassConfig(
|
||||||
|
name="pleaus",
|
||||||
|
tags=[data.Tag(key="class", value="Plecotus austriacus")],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_class_names_from_config(configs: List[TargetClassConfig]) -> List[str]:
|
||||||
"""Extract the list of class names from a ClassesConfig object.
|
"""Extract the list of class names from a ClassesConfig object.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -260,340 +179,60 @@ def get_class_names_from_config(config: ClassesConfig) -> List[str]:
|
|||||||
List[str]
|
List[str]
|
||||||
An ordered list of unique class names defined in the configuration.
|
An ordered list of unique class names defined in the configuration.
|
||||||
"""
|
"""
|
||||||
return [class_info.name for class_info in config.classes]
|
return [class_info.name for class_info in configs]
|
||||||
|
|
||||||
|
|
||||||
def _encode_with_multiple_classifiers(
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
classifiers: List[Tuple[str, Callable[[data.SoundEventAnnotation], bool]]],
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""Encode an annotation by checking against a list of classifiers.
|
|
||||||
|
|
||||||
Internal helper function used by the `SoundEventEncoder`. It iterates
|
|
||||||
through the provided list of (class_name, classifier_function) pairs.
|
|
||||||
Returns the name associated with the first classifier function that
|
|
||||||
returns True for the given annotation.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event_annotation : data.SoundEventAnnotation
|
|
||||||
The annotation to encode.
|
|
||||||
classifiers : List[Tuple[str, Callable[[data.SoundEventAnnotation], bool]]]
|
|
||||||
An ordered list where each tuple contains a class name and a function
|
|
||||||
that returns True if the annotation matches that class. The order
|
|
||||||
determines priority.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
str or None
|
|
||||||
The name of the first matching class, or None if no classifier matches.
|
|
||||||
"""
|
|
||||||
for class_name, classifier in classifiers:
|
|
||||||
if classifier(sound_event_annotation):
|
|
||||||
return class_name
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def build_sound_event_encoder(
|
def build_sound_event_encoder(
|
||||||
config: ClassesConfig,
|
configs: List[TargetClassConfig],
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
) -> SoundEventEncoder:
|
) -> SoundEventEncoder:
|
||||||
"""Build a sound event encoder function from the classes configuration.
|
"""Build a sound event encoder function from the classes configuration."""
|
||||||
|
conditions = {
|
||||||
|
class_config.name: build_sound_event_condition(class_config.match_if)
|
||||||
|
for class_config in configs
|
||||||
|
}
|
||||||
|
|
||||||
The returned encoder function iterates through the class definitions in the
|
return SoundEventClassifier(conditions)
|
||||||
order specified in the config. It assigns an annotation the name of the
|
|
||||||
first class definition it matches.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
config : ClassesConfig
|
|
||||||
The loaded and validated classes configuration object.
|
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance used to look up term keys specified in the
|
|
||||||
`TagInfo` objects within the configuration. Defaults to the global
|
|
||||||
`batdetect2.targets.terms.registry`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
SoundEventEncoder
|
|
||||||
A callable function that takes a `SoundEventAnnotation` and returns
|
|
||||||
an optional string representing the matched class name, or None if no
|
|
||||||
class matches.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If a term key specified in the configuration is not found in the
|
|
||||||
provided `term_registry`.
|
|
||||||
"""
|
|
||||||
binary_classifiers = [
|
|
||||||
(
|
|
||||||
class_info.name,
|
|
||||||
partial(
|
|
||||||
is_target_class,
|
|
||||||
tags={
|
|
||||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
|
||||||
for tag_info in class_info.tags
|
|
||||||
},
|
|
||||||
match_all=class_info.match_type == "all",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
for class_info in config.classes
|
|
||||||
]
|
|
||||||
|
|
||||||
return partial(
|
|
||||||
_encode_with_multiple_classifiers,
|
|
||||||
classifiers=binary_classifiers,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _decode_class(
|
class SoundEventClassifier:
|
||||||
name: str,
|
def __init__(self, mapping: Dict[str, SoundEventCondition]):
|
||||||
mapping: Dict[str, List[data.Tag]],
|
self.mapping = mapping
|
||||||
raise_on_error: bool = True,
|
|
||||||
) -> List[data.Tag]:
|
|
||||||
"""Decode a class name into a list of representative tags using a mapping.
|
|
||||||
|
|
||||||
Internal helper function used by the `SoundEventDecoder`. Looks up the
|
def __call__(
|
||||||
provided class `name` in the `mapping` dictionary.
|
self, sound_event_annotation: data.SoundEventAnnotation
|
||||||
|
) -> Optional[str]:
|
||||||
Parameters
|
for name, condition in self.mapping.items():
|
||||||
----------
|
if condition(sound_event_annotation):
|
||||||
name : str
|
return name
|
||||||
The class name to decode.
|
|
||||||
mapping : Dict[str, List[data.Tag]]
|
|
||||||
A dictionary mapping class names to lists of `soundevent.data.Tag`
|
|
||||||
objects.
|
|
||||||
raise_on_error : bool, default=True
|
|
||||||
If True, raises a ValueError if the `name` is not found in the
|
|
||||||
`mapping`. If False, returns an empty list if the `name` is not found.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[data.Tag]
|
|
||||||
The list of tags associated with the class name, or an empty list if
|
|
||||||
not found and `raise_on_error` is False.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If `name` is not found in `mapping` and `raise_on_error` is True.
|
|
||||||
"""
|
|
||||||
if name not in mapping and raise_on_error:
|
|
||||||
raise ValueError(f"Class {name} not found in mapping.")
|
|
||||||
|
|
||||||
if name not in mapping:
|
|
||||||
return []
|
|
||||||
|
|
||||||
return mapping[name]
|
|
||||||
|
|
||||||
|
|
||||||
def build_sound_event_decoder(
|
def build_sound_event_decoder(
|
||||||
config: ClassesConfig,
|
configs: List[TargetClassConfig],
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
raise_on_unmapped: bool = False,
|
raise_on_unmapped: bool = False,
|
||||||
) -> SoundEventDecoder:
|
) -> SoundEventDecoder:
|
||||||
"""Build a sound event decoder function from the classes configuration.
|
"""Build a sound event decoder function from the classes configuration."""
|
||||||
|
mapping = {
|
||||||
Creates a callable `SoundEventDecoder` that maps a class name string
|
class_config.name: class_config.assign_tags for class_config in configs
|
||||||
back to a list of representative `soundevent.data.Tag` objects based on
|
}
|
||||||
the `ClassesConfig`. It uses the `output_tags` field if provided in a
|
return TagDecoder(mapping, raise_on_unknown=raise_on_unmapped)
|
||||||
`TargetClass`, otherwise falls back to the `tags` field.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
config : ClassesConfig
|
|
||||||
The loaded and validated classes configuration object.
|
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance used to look up term keys. Defaults to the
|
|
||||||
global `batdetect2.targets.terms.registry`.
|
|
||||||
raise_on_unmapped : bool, default=False
|
|
||||||
If True, the returned decoder function will raise a ValueError if asked
|
|
||||||
to decode a class name that is not in the configuration. If False, it
|
|
||||||
will return an empty list for unmapped names.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
SoundEventDecoder
|
|
||||||
A callable function that takes a class name string and returns a list
|
|
||||||
of `soundevent.data.Tag` objects.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If a term key specified in the configuration (`output_tags`, `tags`, or
|
|
||||||
`generic_class`) is not found in the provided `term_registry`.
|
|
||||||
"""
|
|
||||||
mapping = {}
|
|
||||||
for class_info in config.classes:
|
|
||||||
tags_to_use = (
|
|
||||||
class_info.output_tags
|
|
||||||
if class_info.output_tags is not None
|
|
||||||
else class_info.tags
|
|
||||||
)
|
|
||||||
mapping[class_info.name] = [
|
|
||||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
|
||||||
for tag_info in tags_to_use
|
|
||||||
]
|
|
||||||
|
|
||||||
return partial(
|
|
||||||
_decode_class,
|
|
||||||
mapping=mapping,
|
|
||||||
raise_on_error=raise_on_unmapped,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_generic_class_tags(
|
class TagDecoder:
|
||||||
config: ClassesConfig,
|
def __init__(
|
||||||
term_registry: TermRegistry = default_term_registry,
|
self,
|
||||||
) -> List[data.Tag]:
|
mapping: Dict[str, List[data.Tag]],
|
||||||
"""Extract and build the list of tags for the generic class from config.
|
raise_on_unknown: bool = True,
|
||||||
|
):
|
||||||
|
self.mapping = mapping
|
||||||
|
self.raise_on_unknown = raise_on_unknown
|
||||||
|
|
||||||
Converts the list of `TagInfo` objects defined in `config.generic_class`
|
def __call__(self, class_name: str) -> List[data.Tag]:
|
||||||
into a list of `soundevent.data.Tag` objects using the term registry.
|
tags = self.mapping.get(class_name)
|
||||||
|
|
||||||
Parameters
|
if tags is None:
|
||||||
----------
|
if self.raise_on_unknown:
|
||||||
config : ClassesConfig
|
raise ValueError("Invalid class name")
|
||||||
The loaded classes configuration object.
|
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance for term lookups. Defaults to the global
|
|
||||||
`batdetect2.targets.terms.registry`.
|
|
||||||
|
|
||||||
Returns
|
tags = []
|
||||||
-------
|
|
||||||
List[data.Tag]
|
|
||||||
The list of fully constructed tags representing the generic class.
|
|
||||||
|
|
||||||
Raises
|
return tags
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If a term key specified in `config.generic_class` is not found in the
|
|
||||||
provided `term_registry`.
|
|
||||||
"""
|
|
||||||
return [
|
|
||||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
|
||||||
for tag_info in config.generic_class
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def load_classes_config(
|
|
||||||
path: data.PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
) -> ClassesConfig:
|
|
||||||
"""Load the target classes configuration from a file.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
path : data.PathLike
|
|
||||||
Path to the configuration file (YAML).
|
|
||||||
field : str, optional
|
|
||||||
If the classes configuration is nested under a specific key in the
|
|
||||||
file, specify the key here. Defaults to None.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
ClassesConfig
|
|
||||||
The loaded and validated classes configuration object.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
FileNotFoundError
|
|
||||||
If the config file path does not exist.
|
|
||||||
pydantic.ValidationError
|
|
||||||
If the config file structure does not match the ClassesConfig schema
|
|
||||||
or if class names are not unique.
|
|
||||||
"""
|
|
||||||
return load_config(path, schema=ClassesConfig, field=field)
|
|
||||||
|
|
||||||
|
|
||||||
def load_encoder_from_config(
|
|
||||||
path: data.PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
) -> SoundEventEncoder:
|
|
||||||
"""Load a class encoder function directly from a configuration file.
|
|
||||||
|
|
||||||
This is a convenience function that combines loading the `ClassesConfig`
|
|
||||||
from a file and building the final `SoundEventEncoder` function.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
path : data.PathLike
|
|
||||||
Path to the configuration file (e.g., YAML).
|
|
||||||
field : str, optional
|
|
||||||
If the classes configuration is nested under a specific key in the
|
|
||||||
file, specify the key here. Defaults to None.
|
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance used for term lookups. Defaults to the
|
|
||||||
global `batdetect2.targets.terms.registry`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
SoundEventEncoder
|
|
||||||
The final encoder function ready to classify annotations.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
FileNotFoundError
|
|
||||||
If the config file path does not exist.
|
|
||||||
pydantic.ValidationError
|
|
||||||
If the config file structure does not match the ClassesConfig schema
|
|
||||||
or if class names are not unique.
|
|
||||||
KeyError
|
|
||||||
If a term key specified in the configuration is not found in the
|
|
||||||
provided `term_registry` during the build process.
|
|
||||||
"""
|
|
||||||
config = load_classes_config(path, field=field)
|
|
||||||
return build_sound_event_encoder(config, term_registry=term_registry)
|
|
||||||
|
|
||||||
|
|
||||||
def load_decoder_from_config(
|
|
||||||
path: data.PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
raise_on_unmapped: bool = False,
|
|
||||||
) -> SoundEventDecoder:
|
|
||||||
"""Load a class decoder function directly from a configuration file.
|
|
||||||
|
|
||||||
This is a convenience function that combines loading the `ClassesConfig`
|
|
||||||
from a file and building the final `SoundEventDecoder` function.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
path : data.PathLike
|
|
||||||
Path to the configuration file (e.g., YAML).
|
|
||||||
field : str, optional
|
|
||||||
If the classes configuration is nested under a specific key in the
|
|
||||||
file, specify the key here. Defaults to None.
|
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance used for term lookups. Defaults to the
|
|
||||||
global `batdetect2.targets.terms.registry`.
|
|
||||||
raise_on_unmapped : bool, default=False
|
|
||||||
If True, the returned decoder function will raise a ValueError if asked
|
|
||||||
to decode a class name that is not in the configuration. If False, it
|
|
||||||
will return an empty list for unmapped names.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
SoundEventDecoder
|
|
||||||
The final decoder function ready to convert class names back into tags.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
FileNotFoundError
|
|
||||||
If the config file path does not exist.
|
|
||||||
pydantic.ValidationError
|
|
||||||
If the config file structure does not match the ClassesConfig schema
|
|
||||||
or if class names are not unique.
|
|
||||||
KeyError
|
|
||||||
If a term key specified in the configuration is not found in the
|
|
||||||
provided `term_registry` during the build process.
|
|
||||||
"""
|
|
||||||
config = load_classes_config(path, field=field)
|
|
||||||
return build_sound_event_decoder(
|
|
||||||
config,
|
|
||||||
term_registry=term_registry,
|
|
||||||
raise_on_unmapped=raise_on_unmapped,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,307 +0,0 @@
|
|||||||
import logging
|
|
||||||
from functools import partial
|
|
||||||
from typing import List, Literal, Optional, Set
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
|
||||||
from batdetect2.targets.terms import (
|
|
||||||
TagInfo,
|
|
||||||
TermRegistry,
|
|
||||||
default_term_registry,
|
|
||||||
get_tag_from_info,
|
|
||||||
)
|
|
||||||
from batdetect2.typing.targets import SoundEventFilter
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"FilterConfig",
|
|
||||||
"FilterRule",
|
|
||||||
"build_sound_event_filter",
|
|
||||||
"build_filter_from_rule",
|
|
||||||
"load_filter_config",
|
|
||||||
"load_filter_from_config",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class FilterRule(BaseConfig):
|
|
||||||
"""Defines a single rule for filtering sound event annotations.
|
|
||||||
|
|
||||||
Based on the `match_type`, this rule checks if the tags associated with a
|
|
||||||
sound event annotation meet certain criteria relative to the `tags` list
|
|
||||||
defined in this rule.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
match_type : Literal["any", "all", "exclude", "equal"]
|
|
||||||
Determines how the `tags` list is used:
|
|
||||||
- "any": Pass if the annotation has at least one tag from the list.
|
|
||||||
- "all": Pass if the annotation has all tags from the list (it can
|
|
||||||
have others too).
|
|
||||||
- "exclude": Pass if the annotation has none of the tags from the list.
|
|
||||||
- "equal": Pass if the annotation's tags are exactly the same set as
|
|
||||||
provided in the list.
|
|
||||||
tags : List[TagInfo]
|
|
||||||
A list of tags (defined using TagInfo for configuration) that this
|
|
||||||
rule operates on.
|
|
||||||
"""
|
|
||||||
|
|
||||||
match_type: Literal["any", "all", "exclude", "equal"]
|
|
||||||
tags: List[TagInfo]
|
|
||||||
|
|
||||||
|
|
||||||
def has_any_tag(
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
tags: Set[data.Tag],
|
|
||||||
) -> bool:
|
|
||||||
"""Check if the annotation has at least one of the specified tags.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event_annotation : data.SoundEventAnnotation
|
|
||||||
The annotation to check.
|
|
||||||
tags : Set[data.Tag]
|
|
||||||
The set of tags to look for.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
bool
|
|
||||||
True if the annotation has one or more tags from the specified set,
|
|
||||||
False otherwise.
|
|
||||||
"""
|
|
||||||
sound_event_tags = set(sound_event_annotation.tags)
|
|
||||||
return bool(tags & sound_event_tags)
|
|
||||||
|
|
||||||
|
|
||||||
def contains_tags(
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
tags: Set[data.Tag],
|
|
||||||
) -> bool:
|
|
||||||
"""Check if the annotation contains all of the specified tags.
|
|
||||||
|
|
||||||
The annotation may have additional tags beyond those specified.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event_annotation : data.SoundEventAnnotation
|
|
||||||
The annotation to check.
|
|
||||||
tags : Set[data.Tag]
|
|
||||||
The set of tags that must all be present in the annotation.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
bool
|
|
||||||
True if the annotation's tags are a superset of the specified tags,
|
|
||||||
False otherwise.
|
|
||||||
"""
|
|
||||||
sound_event_tags = set(sound_event_annotation.tags)
|
|
||||||
return tags <= sound_event_tags
|
|
||||||
|
|
||||||
|
|
||||||
def does_not_have_tags(
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
tags: Set[data.Tag],
|
|
||||||
):
|
|
||||||
"""Check if the annotation has none of the specified tags.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event_annotation : data.SoundEventAnnotation
|
|
||||||
The annotation to check.
|
|
||||||
tags : Set[data.Tag]
|
|
||||||
The set of tags that must *not* be present in the annotation.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
bool
|
|
||||||
True if the annotation has zero tags in common with the specified set,
|
|
||||||
False otherwise.
|
|
||||||
"""
|
|
||||||
return not has_any_tag(sound_event_annotation, tags)
|
|
||||||
|
|
||||||
|
|
||||||
def equal_tags(
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
tags: Set[data.Tag],
|
|
||||||
) -> bool:
|
|
||||||
"""Check if the annotation's tags are exactly equal to the specified set.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event_annotation : data.SoundEventAnnotation
|
|
||||||
The annotation to check.
|
|
||||||
tags : Set[data.Tag]
|
|
||||||
The exact set of tags the annotation must have.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
bool
|
|
||||||
True if the annotation's tags set is identical to the specified set,
|
|
||||||
False otherwise.
|
|
||||||
"""
|
|
||||||
sound_event_tags = set(sound_event_annotation.tags)
|
|
||||||
return tags == sound_event_tags
|
|
||||||
|
|
||||||
|
|
||||||
def build_filter_from_rule(
|
|
||||||
rule: FilterRule,
|
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
) -> SoundEventFilter:
|
|
||||||
"""Creates a callable filter function from a single FilterRule.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
rule : FilterRule
|
|
||||||
The filter rule configuration object.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
SoundEventFilter
|
|
||||||
A function that takes a SoundEventAnnotation and returns True if it
|
|
||||||
passes the rule, False otherwise.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If the rule contains an invalid `match_type`.
|
|
||||||
"""
|
|
||||||
tag_set = {
|
|
||||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
|
||||||
for tag_info in rule.tags
|
|
||||||
}
|
|
||||||
|
|
||||||
if rule.match_type == "any":
|
|
||||||
return partial(has_any_tag, tags=tag_set)
|
|
||||||
|
|
||||||
if rule.match_type == "all":
|
|
||||||
return partial(contains_tags, tags=tag_set)
|
|
||||||
|
|
||||||
if rule.match_type == "exclude":
|
|
||||||
return partial(does_not_have_tags, tags=tag_set)
|
|
||||||
|
|
||||||
if rule.match_type == "equal":
|
|
||||||
return partial(equal_tags, tags=tag_set)
|
|
||||||
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid match type {rule.match_type}. Valid types "
|
|
||||||
"are: 'any', 'all', 'exclude' and 'equal'"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _passes_all_filters(
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
filters: List[SoundEventFilter],
|
|
||||||
) -> bool:
|
|
||||||
"""Check if the annotation passes all provided filters.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event_annotation : data.SoundEventAnnotation
|
|
||||||
The annotation to check.
|
|
||||||
filters : List[SoundEventFilter]
|
|
||||||
A list of filter functions to apply.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
bool
|
|
||||||
True if the annotation passes all filters, False otherwise.
|
|
||||||
"""
|
|
||||||
for filter_fn in filters:
|
|
||||||
if not filter_fn(sound_event_annotation):
|
|
||||||
logging.debug(
|
|
||||||
f"Sound event annotation {sound_event_annotation.uuid} "
|
|
||||||
f"excluded due to rule {filter_fn}",
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class FilterConfig(BaseConfig):
|
|
||||||
"""Configuration model for defining a list of filter rules.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
rules : List[FilterRule]
|
|
||||||
A list of FilterRule objects. An annotation must pass all rules in
|
|
||||||
this list to be considered valid by the filter built from this config.
|
|
||||||
"""
|
|
||||||
|
|
||||||
rules: List[FilterRule] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
def build_sound_event_filter(
|
|
||||||
config: FilterConfig,
|
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
) -> SoundEventFilter:
|
|
||||||
"""Builds a merged filter function from a FilterConfig object.
|
|
||||||
|
|
||||||
Creates individual filter functions for each rule in the configuration
|
|
||||||
and merges them using AND logic.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
config : FilterConfig
|
|
||||||
The configuration object containing the list of filter rules.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
SoundEventFilter
|
|
||||||
A single callable filter function that applies all defined rules.
|
|
||||||
"""
|
|
||||||
filters = [
|
|
||||||
build_filter_from_rule(rule, term_registry=term_registry)
|
|
||||||
for rule in config.rules
|
|
||||||
]
|
|
||||||
return partial(_passes_all_filters, filters=filters)
|
|
||||||
|
|
||||||
|
|
||||||
def load_filter_config(
|
|
||||||
path: data.PathLike, field: Optional[str] = None
|
|
||||||
) -> FilterConfig:
|
|
||||||
"""Loads the filter configuration from a file.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
path : data.PathLike
|
|
||||||
Path to the configuration file (YAML).
|
|
||||||
field : Optional[str], optional
|
|
||||||
If the filter configuration is nested under a specific key in the
|
|
||||||
file, specify the key here. Defaults to None.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
FilterConfig
|
|
||||||
The loaded and validated filter configuration object.
|
|
||||||
"""
|
|
||||||
return load_config(path, schema=FilterConfig, field=field)
|
|
||||||
|
|
||||||
|
|
||||||
def load_filter_from_config(
|
|
||||||
path: data.PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
) -> SoundEventFilter:
|
|
||||||
"""Loads filter configuration from a file and builds the filter function.
|
|
||||||
|
|
||||||
This is a convenience function that combines loading the configuration
|
|
||||||
and building the final callable filter function.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
path : data.PathLike
|
|
||||||
Path to the configuration file (YAML).
|
|
||||||
field : Optional[str], optional
|
|
||||||
If the filter configuration is nested under a specific key in the
|
|
||||||
file, specify the key here. Defaults to None.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
SoundEventFilter
|
|
||||||
The final merged filter function ready to be used.
|
|
||||||
"""
|
|
||||||
config = load_filter_config(path=path, field=field)
|
|
||||||
return build_sound_event_filter(config, term_registry=term_registry)
|
|
||||||
@ -20,7 +20,7 @@ selecting and configuring the desired mapper. This module separates the
|
|||||||
*geometric* aspect of target definition from *semantic* classification.
|
*geometric* aspect of target definition from *semantic* classification.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, List, Literal, Optional, Protocol, Tuple, Union
|
from typing import Annotated, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -30,7 +30,7 @@ from batdetect2.configs import BaseConfig
|
|||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.preprocess.audio import build_audio_loader
|
from batdetect2.preprocess.audio import build_audio_loader
|
||||||
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
||||||
from batdetect2.typing.targets import Position, Size
|
from batdetect2.typing.targets import Position, ROITargetMapper, Size
|
||||||
from batdetect2.utils.arrays import spec_to_xarray
|
from batdetect2.utils.arrays import spec_to_xarray
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -83,73 +83,6 @@ DEFAULT_ANCHOR = "bottom-left"
|
|||||||
"""Default reference position within the geometry ('bottom-left' corner)."""
|
"""Default reference position within the geometry ('bottom-left' corner)."""
|
||||||
|
|
||||||
|
|
||||||
class ROITargetMapper(Protocol):
|
|
||||||
"""Protocol defining the interface for ROI-to-target mapping.
|
|
||||||
|
|
||||||
Specifies the `encode` and `decode` methods required for converting a
|
|
||||||
`soundevent.data.SoundEvent` into a target representation (a reference
|
|
||||||
position and a size vector) and for recovering an approximate ROI from that
|
|
||||||
representation.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
dimension_names : List[str]
|
|
||||||
A list containing the names of the dimensions in the `Size` array
|
|
||||||
returned by `encode` and expected by `decode`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
dimension_names: List[str]
|
|
||||||
|
|
||||||
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
|
|
||||||
"""Encode a SoundEvent's geometry into a position and size.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event : data.SoundEvent
|
|
||||||
The input sound event, which must have a geometry attribute.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Tuple[Position, Size]
|
|
||||||
A tuple containing:
|
|
||||||
- The reference position as (time, frequency) coordinates.
|
|
||||||
- A NumPy array with the calculated size dimensions.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If the sound event does not have a geometry.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def decode(self, position: Position, size: Size) -> data.Geometry:
|
|
||||||
"""Decode a position and size back into a geometric ROI.
|
|
||||||
|
|
||||||
Performs the inverse mapping: takes a reference position and size
|
|
||||||
dimensions and reconstructs a geometric representation.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
position : Position
|
|
||||||
The reference position (time, frequency).
|
|
||||||
size : Size
|
|
||||||
NumPy array containing the size dimensions, matching the order
|
|
||||||
and meaning specified by `dimension_names`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
soundevent.data.Geometry
|
|
||||||
The reconstructed geometry, typically a `BoundingBox`.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If the `size` array has an unexpected shape or if reconstruction
|
|
||||||
fails.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class AnchorBBoxMapperConfig(BaseConfig):
|
class AnchorBBoxMapperConfig(BaseConfig):
|
||||||
"""Configuration for `AnchorBBoxMapper`.
|
"""Configuration for `AnchorBBoxMapper`.
|
||||||
|
|
||||||
@ -475,7 +408,10 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
|
|||||||
|
|
||||||
|
|
||||||
ROIMapperConfig = Annotated[
|
ROIMapperConfig = Annotated[
|
||||||
Union[AnchorBBoxMapperConfig, PeakEnergyBBoxMapperConfig],
|
Union[
|
||||||
|
AnchorBBoxMapperConfig,
|
||||||
|
PeakEnergyBBoxMapperConfig,
|
||||||
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""A discriminated union of all supported ROI mapper configurations.
|
"""A discriminated union of all supported ROI mapper configurations.
|
||||||
@ -553,7 +489,7 @@ def _build_bounding_box(
|
|||||||
) -> data.BoundingBox:
|
) -> data.BoundingBox:
|
||||||
"""Construct a BoundingBox from a reference point, size, and position type.
|
"""Construct a BoundingBox from a reference point, size, and position type.
|
||||||
|
|
||||||
Internal helper for `BBoxEncoder.recover_roi`. Calculates the box
|
Internal helper for `BBoxEncoder.decode`. Calculates the box
|
||||||
coordinates [start_time, low_freq, end_time, high_freq] based on where
|
coordinates [start_time, low_freq, end_time, high_freq] based on where
|
||||||
the input `pos` (time, freq) is located relative to the box (e.g.,
|
the input `pos` (time, freq) is located relative to the box (e.g.,
|
||||||
center, corner).
|
center, corner).
|
||||||
|
|||||||
@ -1,34 +1,11 @@
|
|||||||
"""Manages the vocabulary (Terms and Tags) for defining training targets.
|
"""Manages the vocabulary for defining training targets."""
|
||||||
|
|
||||||
This module provides the necessary tools to declare, register, and manage the
|
|
||||||
set of `soundevent.data.Term` objects used throughout the `batdetect2.targets`
|
|
||||||
sub-package. It establishes a consistent vocabulary for filtering,
|
|
||||||
transforming, and classifying sound events based on their annotations (Tags).
|
|
||||||
|
|
||||||
The core component is the `TermRegistry`, which maps unique string keys
|
|
||||||
(aliases) to specific `Term` definitions. This allows users to refer to complex
|
|
||||||
terms using simple, consistent keys in configuration files and code.
|
|
||||||
|
|
||||||
Terms can be pre-defined, loaded from the `soundevent.terms` library, defined
|
|
||||||
programmatically, or loaded from external configuration files (e.g., YAML).
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Mapping
|
|
||||||
from inspect import getmembers
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from soundevent import data, terms
|
from soundevent import data, terms
|
||||||
|
|
||||||
from batdetect2.configs import load_config
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"call_type",
|
"call_type",
|
||||||
"individual",
|
"individual",
|
||||||
"data_source",
|
"data_source",
|
||||||
"get_tag_from_info",
|
|
||||||
"TermInfo",
|
|
||||||
"TagInfo",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# The default key used to reference the 'generic_class' term.
|
# The default key used to reference the 'generic_class' term.
|
||||||
@ -96,430 +73,3 @@ terms.register_term_set(
|
|||||||
),
|
),
|
||||||
override_existing=True,
|
override_existing=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TermRegistry(Mapping[str, data.Term]):
|
|
||||||
"""Manages a registry mapping unique keys to Term definitions.
|
|
||||||
|
|
||||||
This class acts as the central repository for the vocabulary of terms
|
|
||||||
used within the target definition process. It allows registering terms
|
|
||||||
with simple string keys and retrieving them consistently.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, terms: Optional[Dict[str, data.Term]] = None):
|
|
||||||
"""Initializes the TermRegistry.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
terms : dict[str, soundevent.data.Term], optional
|
|
||||||
An optional dictionary of initial key-to-Term mappings
|
|
||||||
to populate the registry with. Defaults to an empty registry.
|
|
||||||
"""
|
|
||||||
self._terms: Dict[str, data.Term] = terms or {}
|
|
||||||
|
|
||||||
def __getitem__(self, key: str) -> data.Term:
|
|
||||||
return self._terms[key]
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self._terms)
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
return iter(self._terms)
|
|
||||||
|
|
||||||
def add_term(self, key: str, term: data.Term) -> None:
|
|
||||||
"""Adds a Term object to the registry with the specified key.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
key : str
|
|
||||||
The unique string key to associate with the term.
|
|
||||||
term : soundevent.data.Term
|
|
||||||
The soundevent.data.Term object to register.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If a term with the provided key already exists in the
|
|
||||||
registry.
|
|
||||||
"""
|
|
||||||
if key in self._terms:
|
|
||||||
raise KeyError("A term with the provided key already exists.")
|
|
||||||
|
|
||||||
self._terms[key] = term
|
|
||||||
|
|
||||||
def get_term(self, key: str) -> data.Term:
|
|
||||||
"""Retrieves a registered term by its unique key.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
key : str
|
|
||||||
The unique string key of the term to retrieve.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
soundevent.data.Term
|
|
||||||
The corresponding soundevent.data.Term object.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If no term with the specified key is found, with a
|
|
||||||
helpful message suggesting listing available keys.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return self._terms[key]
|
|
||||||
except KeyError as err:
|
|
||||||
raise KeyError(
|
|
||||||
"No term found for key "
|
|
||||||
f"'{key}'. Ensure it is registered or loaded. "
|
|
||||||
f"Available keys: {', '.join(self.get_keys())}"
|
|
||||||
) from err
|
|
||||||
|
|
||||||
def add_custom_term(
|
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
name: Optional[str] = None,
|
|
||||||
uri: Optional[str] = None,
|
|
||||||
label: Optional[str] = None,
|
|
||||||
definition: Optional[str] = None,
|
|
||||||
) -> data.Term:
|
|
||||||
"""Creates a new Term from attributes and adds it to the registry.
|
|
||||||
|
|
||||||
This is useful for defining terms directly in code or when loading
|
|
||||||
from configuration files where only attributes are provided.
|
|
||||||
|
|
||||||
If optional fields (`name`, `label`, `definition`) are not provided,
|
|
||||||
reasonable defaults are used (`key` for name/label, "Unknown" for
|
|
||||||
definition).
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
key : str
|
|
||||||
The unique string key for the new term.
|
|
||||||
name : str, optional
|
|
||||||
The name for the new term (defaults to `key`).
|
|
||||||
uri : str, optional
|
|
||||||
The URI for the new term (optional).
|
|
||||||
label : str, optional
|
|
||||||
The display label for the new term (defaults to `key`).
|
|
||||||
definition : str, optional
|
|
||||||
The definition for the new term (defaults to "Unknown").
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
soundevent.data.Term
|
|
||||||
The newly created and registered soundevent.data.Term object.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If a term with the provided key already exists.
|
|
||||||
"""
|
|
||||||
term = data.Term(
|
|
||||||
name=name or key,
|
|
||||||
label=label or key,
|
|
||||||
uri=uri,
|
|
||||||
definition=definition or "Unknown",
|
|
||||||
)
|
|
||||||
self.add_term(key, term)
|
|
||||||
return term
|
|
||||||
|
|
||||||
def get_keys(self) -> List[str]:
|
|
||||||
"""Returns a list of all keys currently registered.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
list[str]
|
|
||||||
A list of strings representing the keys of all registered terms.
|
|
||||||
"""
|
|
||||||
return list(self._terms.keys())
|
|
||||||
|
|
||||||
def get_terms(self) -> List[data.Term]:
|
|
||||||
"""Returns a list of all registered terms.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
list[soundevent.data.Term]
|
|
||||||
A list containing all registered Term objects.
|
|
||||||
"""
|
|
||||||
return list(self._terms.values())
|
|
||||||
|
|
||||||
def remove_key(self, key: str) -> None:
|
|
||||||
del self._terms[key]
|
|
||||||
|
|
||||||
|
|
||||||
default_term_registry = TermRegistry(
|
|
||||||
terms=dict(
|
|
||||||
[
|
|
||||||
*getmembers(terms, lambda x: isinstance(x, data.Term)),
|
|
||||||
("event", call_type),
|
|
||||||
("species", terms.scientific_name),
|
|
||||||
("individual", individual),
|
|
||||||
("data_source", data_source),
|
|
||||||
(GENERIC_CLASS_KEY, generic_class),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
"""The default, globally accessible TermRegistry instance.
|
|
||||||
|
|
||||||
It is pre-populated with standard terms from `soundevent.terms` and common
|
|
||||||
terms defined in this module (`call_type`, `individual`, `generic_class`).
|
|
||||||
Functions in this module use this registry by default unless another instance
|
|
||||||
is explicitly passed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_term_from_key(
|
|
||||||
key: str,
|
|
||||||
term_registry: Optional[TermRegistry] = None,
|
|
||||||
) -> data.Term:
|
|
||||||
"""Convenience function to retrieve a term by key from a registry.
|
|
||||||
|
|
||||||
Uses the global default registry unless a specific `term_registry`
|
|
||||||
instance is provided.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
key : str
|
|
||||||
The unique key of the term to retrieve.
|
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance to search in. Defaults to the global
|
|
||||||
`registry`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
soundevent.data.Term
|
|
||||||
The corresponding soundevent.data.Term object.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If the key is not found in the specified registry.
|
|
||||||
"""
|
|
||||||
term = terms.get_term(key)
|
|
||||||
|
|
||||||
if term:
|
|
||||||
return term
|
|
||||||
|
|
||||||
term_registry = term_registry or default_term_registry
|
|
||||||
return term_registry.get_term(key)
|
|
||||||
|
|
||||||
|
|
||||||
def get_term_keys(
|
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
) -> List[str]:
|
|
||||||
"""Convenience function to get all registered keys from a registry.
|
|
||||||
|
|
||||||
Uses the global default registry unless a specific `term_registry`
|
|
||||||
instance is provided.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance to query. Defaults to the global `registry`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
list[str]
|
|
||||||
A list of strings representing the keys of all registered terms.
|
|
||||||
"""
|
|
||||||
return term_registry.get_keys()
|
|
||||||
|
|
||||||
|
|
||||||
def get_terms(
|
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
) -> List[data.Term]:
|
|
||||||
"""Convenience function to get all registered terms from a registry.
|
|
||||||
|
|
||||||
Uses the global default registry unless a specific `term_registry`
|
|
||||||
instance is provided.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance to query. Defaults to the global `registry`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
list[soundevent.data.Term]
|
|
||||||
A list containing all registered Term objects.
|
|
||||||
"""
|
|
||||||
return term_registry.get_terms()
|
|
||||||
|
|
||||||
|
|
||||||
class TagInfo(BaseModel):
|
|
||||||
"""Represents information needed to define a specific Tag.
|
|
||||||
|
|
||||||
This model is typically used in configuration files (e.g., YAML) to
|
|
||||||
specify tags used for filtering, target class definition, or associating
|
|
||||||
tags with output classes. It links a tag value to a term definition
|
|
||||||
via the term's registry key.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
value : str
|
|
||||||
The value of the tag (e.g., "Myotis myotis", "Echolocation").
|
|
||||||
key : str, default="class"
|
|
||||||
The key (alias) of the term associated with this tag, as
|
|
||||||
registered in the TermRegistry. Defaults to "class", implying
|
|
||||||
it represents a classification target label by default.
|
|
||||||
"""
|
|
||||||
|
|
||||||
value: str
|
|
||||||
key: str = GENERIC_CLASS_KEY
|
|
||||||
|
|
||||||
|
|
||||||
def get_tag_from_info(
|
|
||||||
tag_info: TagInfo,
|
|
||||||
term_registry: Optional[TermRegistry] = None,
|
|
||||||
) -> data.Tag:
|
|
||||||
"""Creates a soundevent.data.Tag object from TagInfo data.
|
|
||||||
|
|
||||||
Looks up the term using the key in the provided `tag_info` from the
|
|
||||||
specified registry and constructs a Tag object.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
tag_info : TagInfo
|
|
||||||
The TagInfo object containing the value and term key.
|
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance to use for term lookup. Defaults to the
|
|
||||||
global `registry`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
soundevent.data.Tag
|
|
||||||
A soundevent.data.Tag object corresponding to the input info.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If the term key specified in `tag_info.key` is not found
|
|
||||||
in the registry.
|
|
||||||
"""
|
|
||||||
term_registry = term_registry or default_term_registry
|
|
||||||
term = get_term_from_key(tag_info.key, term_registry=term_registry)
|
|
||||||
return data.Tag(term=term, value=tag_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
class TermInfo(BaseModel):
|
|
||||||
"""Represents the definition of a Term within a configuration file.
|
|
||||||
|
|
||||||
This model allows users to define custom terms directly in configuration
|
|
||||||
files (e.g., YAML) which can then be loaded into the TermRegistry.
|
|
||||||
It mirrors the parameters of `TermRegistry.add_custom_term`.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
key : str
|
|
||||||
The unique key (alias) that will be used to register and
|
|
||||||
reference this term.
|
|
||||||
label : str, optional
|
|
||||||
The optional display label for the term. Defaults to `key`
|
|
||||||
if not provided during registration.
|
|
||||||
name : str, optional
|
|
||||||
The optional formal name for the term. Defaults to `key`
|
|
||||||
if not provided during registration.
|
|
||||||
uri : str, optional
|
|
||||||
The optional URI identifying the term (e.g., from a standard
|
|
||||||
vocabulary).
|
|
||||||
definition : str, optional
|
|
||||||
The optional textual definition of the term. Defaults to
|
|
||||||
"Unknown" if not provided during registration.
|
|
||||||
"""
|
|
||||||
|
|
||||||
key: str
|
|
||||||
label: Optional[str] = None
|
|
||||||
name: Optional[str] = None
|
|
||||||
uri: Optional[str] = None
|
|
||||||
definition: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class TermConfig(BaseModel):
|
|
||||||
"""Pydantic schema for loading a list of term definitions from config.
|
|
||||||
|
|
||||||
This model typically corresponds to a section in a configuration file
|
|
||||||
(e.g., YAML) containing a list of term definitions to be registered.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
terms : list[TermInfo]
|
|
||||||
A list of TermInfo objects, each defining a term to be
|
|
||||||
registered. Defaults to an empty list.
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
Example YAML structure:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
terms:
|
|
||||||
- key: species
|
|
||||||
uri: dwc:scientificName
|
|
||||||
label: Scientific Name
|
|
||||||
- key: my_custom_term
|
|
||||||
name: My Custom Term
|
|
||||||
definition: Describes a specific project attribute.
|
|
||||||
# ... more TermInfo definitions
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
terms: List[TermInfo] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
def load_terms_from_config(
|
|
||||||
path: data.PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
term_registry: TermRegistry = default_term_registry,
|
|
||||||
) -> Dict[str, data.Term]:
|
|
||||||
"""Loads term definitions from a configuration file and registers them.
|
|
||||||
|
|
||||||
Parses a configuration file (e.g., YAML) using the TermConfig schema,
|
|
||||||
extracts the list of TermInfo definitions, and adds each one as a
|
|
||||||
custom term to the specified TermRegistry instance.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
path : data.PathLike
|
|
||||||
The path to the configuration file.
|
|
||||||
field : str, optional
|
|
||||||
Optional key indicating a specific section within the config
|
|
||||||
file where the 'terms' list is located. If None, expects the
|
|
||||||
list directly at the top level or within a structure matching
|
|
||||||
TermConfig schema.
|
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance to add the loaded terms to. Defaults to
|
|
||||||
the global `registry`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
dict[str, soundevent.data.Term]
|
|
||||||
A dictionary mapping the keys of the newly added terms to their
|
|
||||||
corresponding Term objects.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
FileNotFoundError
|
|
||||||
If the config file path does not exist.
|
|
||||||
pydantic.ValidationError
|
|
||||||
If the config file structure does not match the TermConfig schema.
|
|
||||||
KeyError
|
|
||||||
If a term key loaded from the config conflicts with a key
|
|
||||||
already present in the registry.
|
|
||||||
"""
|
|
||||||
data = load_config(path, schema=TermConfig, field=field)
|
|
||||||
return {
|
|
||||||
info.key: term_registry.add_custom_term(
|
|
||||||
info.key,
|
|
||||||
name=info.name,
|
|
||||||
uri=info.uri,
|
|
||||||
label=info.label,
|
|
||||||
definition=info.definition,
|
|
||||||
)
|
|
||||||
for info in data.terms
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def register_term(
|
|
||||||
key: str, term: data.Term, registry: TermRegistry = default_term_registry
|
|
||||||
) -> None:
|
|
||||||
registry.add_term(key, term)
|
|
||||||
|
|||||||
@ -1,708 +0,0 @@
|
|||||||
import importlib
|
|
||||||
from functools import partial
|
|
||||||
from typing import (
|
|
||||||
Annotated,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Literal,
|
|
||||||
Mapping,
|
|
||||||
Optional,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
|
||||||
from batdetect2.targets.terms import (
|
|
||||||
TagInfo,
|
|
||||||
TermRegistry,
|
|
||||||
get_tag_from_info,
|
|
||||||
get_term_from_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"DerivationRegistry",
|
|
||||||
"DeriveTagRule",
|
|
||||||
"MapValueRule",
|
|
||||||
"ReplaceRule",
|
|
||||||
"SoundEventTransformation",
|
|
||||||
"TransformConfig",
|
|
||||||
"build_transform_from_rule",
|
|
||||||
"build_transformation_from_config",
|
|
||||||
"default_derivation_registry",
|
|
||||||
"get_derivation",
|
|
||||||
"load_transformation_config",
|
|
||||||
"load_transformation_from_config",
|
|
||||||
"register_derivation",
|
|
||||||
]
|
|
||||||
|
|
||||||
SoundEventTransformation = Callable[
|
|
||||||
[data.SoundEventAnnotation], data.SoundEventAnnotation
|
|
||||||
]
|
|
||||||
"""Type alias for a sound event transformation function.
|
|
||||||
|
|
||||||
A function that accepts a sound event annotation object and returns a
|
|
||||||
(potentially) modified sound event annotation object. Transformations
|
|
||||||
should generally return a copy of the annotation rather than modifying
|
|
||||||
it in place.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
Derivation = Callable[[str], str]
|
|
||||||
"""Type alias for a derivation function.
|
|
||||||
|
|
||||||
A function that accepts a single string (typically a tag value) and returns
|
|
||||||
a new string (the derived value).
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class MapValueRule(BaseConfig):
|
|
||||||
"""Configuration for mapping specific values of a source term.
|
|
||||||
|
|
||||||
This rule replaces tags matching a specific term and one of the
|
|
||||||
original values with a new tag (potentially having a different term)
|
|
||||||
containing the corresponding replacement value. Useful for standardizing
|
|
||||||
or grouping tag values.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
rule_type : Literal["map_value"]
|
|
||||||
Discriminator field identifying this rule type.
|
|
||||||
source_term_key : str
|
|
||||||
The key (registered in `TermRegistry`) of the term whose tags' values
|
|
||||||
should be checked against the `value_mapping`.
|
|
||||||
value_mapping : Dict[str, str]
|
|
||||||
A dictionary mapping original string values to replacement string
|
|
||||||
values. Only tags whose value is a key in this dictionary will be
|
|
||||||
affected.
|
|
||||||
target_term_key : str, optional
|
|
||||||
The key (registered in `TermRegistry`) for the term of the *output*
|
|
||||||
tag. If None (default), the output tag uses the same term as the
|
|
||||||
source (`source_term_key`). If provided, the term of the affected
|
|
||||||
tag is changed to this target term upon replacement.
|
|
||||||
"""
|
|
||||||
|
|
||||||
rule_type: Literal["map_value"] = "map_value"
|
|
||||||
source_term_key: str
|
|
||||||
value_mapping: Dict[str, str]
|
|
||||||
target_term_key: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
def map_value_transform(
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
source_term: data.Term,
|
|
||||||
target_term: data.Term,
|
|
||||||
mapping: Dict[str, str],
|
|
||||||
) -> data.SoundEventAnnotation:
|
|
||||||
"""Apply a value mapping transformation to an annotation's tags.
|
|
||||||
|
|
||||||
Iterates through the annotation's tags. If a tag matches the `source_term`
|
|
||||||
and its value is found in the `mapping`, it is replaced by a new tag with
|
|
||||||
the `target_term` and the mapped value. Other tags are kept unchanged.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event_annotation : data.SoundEventAnnotation
|
|
||||||
The annotation to transform.
|
|
||||||
source_term : data.Term
|
|
||||||
The term of tags whose values should be mapped.
|
|
||||||
target_term : data.Term
|
|
||||||
The term to use for the newly created tags after mapping.
|
|
||||||
mapping : Dict[str, str]
|
|
||||||
The dictionary mapping original values to new values.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
data.SoundEventAnnotation
|
|
||||||
A new annotation object with the transformed tags.
|
|
||||||
"""
|
|
||||||
tags = []
|
|
||||||
|
|
||||||
for tag in sound_event_annotation.tags:
|
|
||||||
if tag.term != source_term or tag.value not in mapping:
|
|
||||||
tags.append(tag)
|
|
||||||
continue
|
|
||||||
|
|
||||||
new_value = mapping[tag.value]
|
|
||||||
tags.append(data.Tag(term=target_term, value=new_value))
|
|
||||||
|
|
||||||
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
|
||||||
|
|
||||||
|
|
||||||
class DeriveTagRule(BaseConfig):
|
|
||||||
"""Configuration for deriving a new tag from an existing tag's value.
|
|
||||||
|
|
||||||
This rule applies a specified function (`derivation_function`) to the
|
|
||||||
value of tags matching the `source_term_key`. It then adds a new tag
|
|
||||||
with the `target_term_key` and the derived value.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
rule_type : Literal["derive_tag"]
|
|
||||||
Discriminator field identifying this rule type.
|
|
||||||
source_term_key : str
|
|
||||||
The key (registered in `TermRegistry`) of the term whose tag values
|
|
||||||
will be used as input to the derivation function.
|
|
||||||
derivation_function : str
|
|
||||||
The name/key identifying the derivation function to use. This can be
|
|
||||||
a key registered in the `DerivationRegistry` or, if
|
|
||||||
`import_derivation` is True, a full Python path like
|
|
||||||
`'my_module.my_submodule.my_function'`.
|
|
||||||
target_term_key : str, optional
|
|
||||||
The key (registered in `TermRegistry`) for the term of the new tag
|
|
||||||
that will be created with the derived value. If None (default), the
|
|
||||||
derived tag uses the same term as the source (`source_term_key`),
|
|
||||||
effectively performing an in-place value transformation.
|
|
||||||
import_derivation : bool, default=False
|
|
||||||
If True, treat `derivation_function` as a Python import path and
|
|
||||||
attempt to dynamically import it if not found in the registry.
|
|
||||||
Requires the function to be accessible in the Python environment.
|
|
||||||
keep_source : bool, default=True
|
|
||||||
If True, the original source tag (whose value was used for derivation)
|
|
||||||
is kept in the annotation's tag list alongside the newly derived tag.
|
|
||||||
If False, the original source tag is removed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
rule_type: Literal["derive_tag"] = "derive_tag"
|
|
||||||
source_term_key: str
|
|
||||||
derivation_function: str
|
|
||||||
target_term_key: Optional[str] = None
|
|
||||||
import_derivation: bool = False
|
|
||||||
keep_source: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
def derivation_tag_transform(
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
source_term: data.Term,
|
|
||||||
target_term: data.Term,
|
|
||||||
derivation: Derivation,
|
|
||||||
keep_source: bool = True,
|
|
||||||
) -> data.SoundEventAnnotation:
|
|
||||||
"""Apply a derivation transformation to an annotation's tags.
|
|
||||||
|
|
||||||
Iterates through the annotation's tags. For each tag matching the
|
|
||||||
`source_term`, its value is passed to the `derivation` function.
|
|
||||||
A new tag is created with the `target_term` and the derived value,
|
|
||||||
and added to the output tag list. The original source tag is kept
|
|
||||||
or discarded based on `keep_source`. Other tags are kept unchanged.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event_annotation : data.SoundEventAnnotation
|
|
||||||
The annotation to transform.
|
|
||||||
source_term : data.Term
|
|
||||||
The term of tags whose values serve as input for the derivation.
|
|
||||||
target_term : data.Term
|
|
||||||
The term to use for the newly created derived tags.
|
|
||||||
derivation : Derivation
|
|
||||||
The function to apply to the source tag's value.
|
|
||||||
keep_source : bool, default=True
|
|
||||||
Whether to keep the original source tag in the output.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
data.SoundEventAnnotation
|
|
||||||
A new annotation object with the transformed tags (including derived
|
|
||||||
ones).
|
|
||||||
"""
|
|
||||||
tags = []
|
|
||||||
|
|
||||||
for tag in sound_event_annotation.tags:
|
|
||||||
if tag.term != source_term:
|
|
||||||
tags.append(tag)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if keep_source:
|
|
||||||
tags.append(tag)
|
|
||||||
|
|
||||||
new_value = derivation(tag.value)
|
|
||||||
tags.append(data.Tag(term=target_term, value=new_value))
|
|
||||||
|
|
||||||
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
|
||||||
|
|
||||||
|
|
||||||
class ReplaceRule(BaseConfig):
|
|
||||||
"""Configuration for exactly replacing one specific tag with another.
|
|
||||||
|
|
||||||
This rule looks for an exact match of the `original` tag (both term and
|
|
||||||
value) and replaces it with the specified `replacement` tag.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
rule_type : Literal["replace"]
|
|
||||||
Discriminator field identifying this rule type.
|
|
||||||
original : TagInfo
|
|
||||||
The exact tag to search for, defined using its value and term key.
|
|
||||||
replacement : TagInfo
|
|
||||||
The tag to substitute in place of the original tag, defined using
|
|
||||||
its value and term key.
|
|
||||||
"""
|
|
||||||
|
|
||||||
rule_type: Literal["replace"] = "replace"
|
|
||||||
original: TagInfo
|
|
||||||
replacement: TagInfo
|
|
||||||
|
|
||||||
|
|
||||||
def replace_tag_transform(
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
source: data.Tag,
|
|
||||||
target: data.Tag,
|
|
||||||
) -> data.SoundEventAnnotation:
|
|
||||||
"""Apply an exact tag replacement transformation.
|
|
||||||
|
|
||||||
Iterates through the annotation's tags. If a tag exactly matches the
|
|
||||||
`source` tag, it is replaced by the `target` tag. Other tags are kept
|
|
||||||
unchanged.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event_annotation : data.SoundEventAnnotation
|
|
||||||
The annotation to transform.
|
|
||||||
source : data.Tag
|
|
||||||
The exact tag to find and replace.
|
|
||||||
target : data.Tag
|
|
||||||
The tag to replace the source tag with.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
data.SoundEventAnnotation
|
|
||||||
A new annotation object with the replaced tag (if found).
|
|
||||||
"""
|
|
||||||
tags = []
|
|
||||||
|
|
||||||
for tag in sound_event_annotation.tags:
|
|
||||||
if tag == source:
|
|
||||||
tags.append(target)
|
|
||||||
else:
|
|
||||||
tags.append(tag)
|
|
||||||
|
|
||||||
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
|
||||||
|
|
||||||
|
|
||||||
class TransformConfig(BaseConfig):
|
|
||||||
"""Configuration model for defining a sequence of transformation rules.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
rules : List[Union[ReplaceRule, MapValueRule, DeriveTagRule]]
|
|
||||||
A list of transformation rules to apply. The rules are applied
|
|
||||||
sequentially in the order they appear in the list. The output of
|
|
||||||
one rule becomes the input for the next. The `rule_type` field
|
|
||||||
discriminates between the different rule models.
|
|
||||||
"""
|
|
||||||
|
|
||||||
rules: List[
|
|
||||||
Annotated[
|
|
||||||
Union[ReplaceRule, MapValueRule, DeriveTagRule],
|
|
||||||
Field(discriminator="rule_type"),
|
|
||||||
]
|
|
||||||
] = Field(
|
|
||||||
default_factory=list,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DerivationRegistry(Mapping[str, Derivation]):
|
|
||||||
"""A registry for managing named derivation functions.
|
|
||||||
|
|
||||||
Derivation functions are callables that take a string value and return
|
|
||||||
a transformed string value, used by `DeriveTagRule`. This registry
|
|
||||||
allows functions to be registered with a key and retrieved later.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""Initialize an empty DerivationRegistry."""
|
|
||||||
self._derivations: Dict[str, Derivation] = {}
|
|
||||||
|
|
||||||
def __getitem__(self, key: str) -> Derivation:
|
|
||||||
"""Retrieve a derivation function by key."""
|
|
||||||
return self._derivations[key]
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
"""Return the number of registered derivation functions."""
|
|
||||||
return len(self._derivations)
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
"""Return an iterator over the keys of registered functions."""
|
|
||||||
return iter(self._derivations)
|
|
||||||
|
|
||||||
def register(self, key: str, derivation: Derivation) -> None:
|
|
||||||
"""Register a derivation function with a unique key.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
key : str
|
|
||||||
The unique key to associate with the derivation function.
|
|
||||||
derivation : Derivation
|
|
||||||
The callable derivation function (takes str, returns str).
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If a derivation function with the same key is already registered.
|
|
||||||
"""
|
|
||||||
if key in self._derivations:
|
|
||||||
raise KeyError(
|
|
||||||
f"A derivation with the provided key {key} already exists"
|
|
||||||
)
|
|
||||||
|
|
||||||
self._derivations[key] = derivation
|
|
||||||
|
|
||||||
def get_derivation(self, key: str) -> Derivation:
|
|
||||||
"""Retrieve a derivation function by its registered key.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
key : str
|
|
||||||
The key of the derivation function to retrieve.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Derivation
|
|
||||||
The requested derivation function.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If no derivation function with the specified key is registered.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return self._derivations[key]
|
|
||||||
except KeyError as err:
|
|
||||||
raise KeyError(
|
|
||||||
f"No derivation with key {key} is registered."
|
|
||||||
) from err
|
|
||||||
|
|
||||||
def get_keys(self) -> List[str]:
|
|
||||||
"""Get a list of all registered derivation function keys.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[str]
|
|
||||||
The keys of all registered functions.
|
|
||||||
"""
|
|
||||||
return list(self._derivations.keys())
|
|
||||||
|
|
||||||
def get_derivations(self) -> List[Derivation]:
|
|
||||||
"""Get a list of all registered derivation functions.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[Derivation]
|
|
||||||
The registered derivation function objects.
|
|
||||||
"""
|
|
||||||
return list(self._derivations.values())
|
|
||||||
|
|
||||||
|
|
||||||
default_derivation_registry = DerivationRegistry()
|
|
||||||
"""Global instance of the DerivationRegistry.
|
|
||||||
|
|
||||||
Register custom derivation functions here to make them available by key
|
|
||||||
in `DeriveTagRule` configuration.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_derivation(
|
|
||||||
key: str,
|
|
||||||
import_derivation: bool = False,
|
|
||||||
registry: Optional[DerivationRegistry] = None,
|
|
||||||
):
|
|
||||||
"""Retrieve a derivation function by key, optionally importing it.
|
|
||||||
|
|
||||||
First attempts to find the function in the provided `registry`.
|
|
||||||
If not found and `import_derivation` is True, attempts to dynamically
|
|
||||||
import the function using the `key` as a full Python path
|
|
||||||
(e.g., 'my_module.submodule.my_func').
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
key : str
|
|
||||||
The key or Python path of the derivation function.
|
|
||||||
import_derivation : bool, default=False
|
|
||||||
If True, attempt dynamic import if key is not in the registry.
|
|
||||||
registry : DerivationRegistry, optional
|
|
||||||
The registry instance to check first. Defaults to the global
|
|
||||||
`derivation_registry`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Derivation
|
|
||||||
The requested derivation function.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If the key is not found in the registry and either
|
|
||||||
`import_derivation` is False or the dynamic import fails.
|
|
||||||
ImportError
|
|
||||||
If dynamic import fails specifically due to module not found.
|
|
||||||
AttributeError
|
|
||||||
If dynamic import fails because the function name isn't in the module.
|
|
||||||
"""
|
|
||||||
registry = registry or default_derivation_registry
|
|
||||||
|
|
||||||
if not import_derivation or key in registry:
|
|
||||||
return registry.get_derivation(key)
|
|
||||||
|
|
||||||
try:
|
|
||||||
module_path, func_name = key.rsplit(".", 1)
|
|
||||||
module = importlib.import_module(module_path)
|
|
||||||
func = getattr(module, func_name)
|
|
||||||
return func
|
|
||||||
except ImportError as err:
|
|
||||||
raise KeyError(
|
|
||||||
f"Unable to load derivation '{key}'. Check the path and ensure "
|
|
||||||
"it points to a valid callable function in an importable module."
|
|
||||||
) from err
|
|
||||||
|
|
||||||
|
|
||||||
TranformationRule = Annotated[
|
|
||||||
Union[ReplaceRule, MapValueRule, DeriveTagRule],
|
|
||||||
Field(discriminator="rule_type"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def build_transform_from_rule(
|
|
||||||
rule: TranformationRule,
|
|
||||||
derivation_registry: Optional[DerivationRegistry] = None,
|
|
||||||
term_registry: Optional[TermRegistry] = None,
|
|
||||||
) -> SoundEventTransformation:
|
|
||||||
"""Build a specific SoundEventTransformation function from a rule config.
|
|
||||||
|
|
||||||
Selects the appropriate transformation logic based on the rule's
|
|
||||||
`rule_type`, fetches necessary terms and derivation functions, and
|
|
||||||
returns a partially applied function ready to transform an annotation.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
rule : Union[ReplaceRule, MapValueRule, DeriveTagRule]
|
|
||||||
The configuration object for a single transformation rule.
|
|
||||||
registry : DerivationRegistry, optional
|
|
||||||
The derivation registry to use for `DeriveTagRule`. Defaults to the
|
|
||||||
global `derivation_registry`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
SoundEventTransformation
|
|
||||||
A callable that applies the specified rule to a SoundEventAnnotation.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If required term keys or derivation keys are not found.
|
|
||||||
ValueError
|
|
||||||
If the rule has an unknown `rule_type`.
|
|
||||||
ImportError, AttributeError, TypeError
|
|
||||||
If dynamic import of a derivation function fails.
|
|
||||||
"""
|
|
||||||
if rule.rule_type == "replace":
|
|
||||||
source = get_tag_from_info(
|
|
||||||
rule.original,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
target = get_tag_from_info(
|
|
||||||
rule.replacement,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
return partial(replace_tag_transform, source=source, target=target)
|
|
||||||
|
|
||||||
if rule.rule_type == "derive_tag":
|
|
||||||
source_term = get_term_from_key(
|
|
||||||
rule.source_term_key,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
target_term = (
|
|
||||||
get_term_from_key(
|
|
||||||
rule.target_term_key,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
if rule.target_term_key
|
|
||||||
else source_term
|
|
||||||
)
|
|
||||||
derivation = get_derivation(
|
|
||||||
key=rule.derivation_function,
|
|
||||||
import_derivation=rule.import_derivation,
|
|
||||||
registry=derivation_registry,
|
|
||||||
)
|
|
||||||
return partial(
|
|
||||||
derivation_tag_transform,
|
|
||||||
source_term=source_term,
|
|
||||||
target_term=target_term,
|
|
||||||
derivation=derivation,
|
|
||||||
keep_source=rule.keep_source,
|
|
||||||
)
|
|
||||||
|
|
||||||
if rule.rule_type == "map_value":
|
|
||||||
source_term = get_term_from_key(
|
|
||||||
rule.source_term_key,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
target_term = (
|
|
||||||
get_term_from_key(
|
|
||||||
rule.target_term_key,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
if rule.target_term_key
|
|
||||||
else source_term
|
|
||||||
)
|
|
||||||
return partial(
|
|
||||||
map_value_transform,
|
|
||||||
source_term=source_term,
|
|
||||||
target_term=target_term,
|
|
||||||
mapping=rule.value_mapping,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle unknown rule type
|
|
||||||
valid_options = ["replace", "derive_tag", "map_value"]
|
|
||||||
# Should be caught by Pydantic validation, but good practice
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid transform rule type '{getattr(rule, 'rule_type', 'N/A')}'. "
|
|
||||||
f"Valid options are: {valid_options}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_transformation_from_config(
|
|
||||||
config: TransformConfig,
|
|
||||||
derivation_registry: Optional[DerivationRegistry] = None,
|
|
||||||
term_registry: Optional[TermRegistry] = None,
|
|
||||||
) -> SoundEventTransformation:
|
|
||||||
"""Build a composite transformation function from a TransformConfig.
|
|
||||||
|
|
||||||
Creates a sequence of individual transformation functions based on the
|
|
||||||
rules defined in the configuration. Returns a single function that
|
|
||||||
applies these transformations sequentially to an annotation.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
config : TransformConfig
|
|
||||||
The configuration object containing the list of transformation rules.
|
|
||||||
derivation_reg : DerivationRegistry, optional
|
|
||||||
The derivation registry to use when building `DeriveTagRule`
|
|
||||||
transformations. Defaults to the global `derivation_registry`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
SoundEventTransformation
|
|
||||||
A single function that applies all configured transformations in order.
|
|
||||||
"""
|
|
||||||
|
|
||||||
transforms = [
|
|
||||||
build_transform_from_rule(
|
|
||||||
rule,
|
|
||||||
derivation_registry=derivation_registry,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
for rule in config.rules
|
|
||||||
]
|
|
||||||
|
|
||||||
return partial(apply_sequence_of_transforms, transforms=transforms)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_sequence_of_transforms(
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
transforms: list[SoundEventTransformation],
|
|
||||||
) -> data.SoundEventAnnotation:
|
|
||||||
for transform in transforms:
|
|
||||||
sound_event_annotation = transform(sound_event_annotation)
|
|
||||||
return sound_event_annotation
|
|
||||||
|
|
||||||
|
|
||||||
def load_transformation_config(
|
|
||||||
path: data.PathLike, field: Optional[str] = None
|
|
||||||
) -> TransformConfig:
|
|
||||||
"""Load the transformation configuration from a file.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
path : data.PathLike
|
|
||||||
Path to the configuration file (YAML).
|
|
||||||
field : str, optional
|
|
||||||
If the transformation configuration is nested under a specific key
|
|
||||||
in the file, specify the key here. Defaults to None.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
TransformConfig
|
|
||||||
The loaded and validated transformation configuration object.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
FileNotFoundError
|
|
||||||
If the config file path does not exist.
|
|
||||||
pydantic.ValidationError
|
|
||||||
If the config file structure does not match the TransformConfig schema.
|
|
||||||
"""
|
|
||||||
return load_config(path=path, schema=TransformConfig, field=field)
|
|
||||||
|
|
||||||
|
|
||||||
def load_transformation_from_config(
|
|
||||||
path: data.PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
derivation_registry: Optional[DerivationRegistry] = None,
|
|
||||||
term_registry: Optional[TermRegistry] = None,
|
|
||||||
) -> SoundEventTransformation:
|
|
||||||
"""Load transformation config from a file and build the final function.
|
|
||||||
|
|
||||||
This is a convenience function that combines loading the configuration
|
|
||||||
and building the final callable transformation function that applies
|
|
||||||
all rules sequentially.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
path : data.PathLike
|
|
||||||
Path to the configuration file (YAML).
|
|
||||||
field : str, optional
|
|
||||||
If the transformation configuration is nested under a specific key
|
|
||||||
in the file, specify the key here. Defaults to None.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
SoundEventTransformation
|
|
||||||
The final composite transformation function ready to be used.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
FileNotFoundError
|
|
||||||
If the config file path does not exist.
|
|
||||||
pydantic.ValidationError
|
|
||||||
If the config file structure does not match the TransformConfig schema.
|
|
||||||
KeyError
|
|
||||||
If required term keys or derivation keys specified in the config
|
|
||||||
are not found during the build process.
|
|
||||||
ImportError, AttributeError, TypeError
|
|
||||||
If dynamic import of a derivation function specified in the config
|
|
||||||
fails.
|
|
||||||
"""
|
|
||||||
config = load_transformation_config(path=path, field=field)
|
|
||||||
return build_transformation_from_config(
|
|
||||||
config,
|
|
||||||
derivation_registry=derivation_registry,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def register_derivation(
|
|
||||||
key: str,
|
|
||||||
derivation: Derivation,
|
|
||||||
derivation_registry: Optional[DerivationRegistry] = None,
|
|
||||||
) -> None:
|
|
||||||
"""Register a new derivation function in the global registry.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
key : str
|
|
||||||
The unique key to associate with the derivation function.
|
|
||||||
derivation : Derivation
|
|
||||||
The callable derivation function (takes str, returns str).
|
|
||||||
derivation_registry : DerivationRegistry, optional
|
|
||||||
The registry instance to register the derivation function with.
|
|
||||||
Defaults to the global `derivation_registry`.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If a derivation function with the same key is already registered.
|
|
||||||
"""
|
|
||||||
derivation_registry = derivation_registry or default_derivation_registry
|
|
||||||
derivation_registry.register(key, derivation)
|
|
||||||
@ -33,10 +33,6 @@ from batdetect2.train.losses import (
|
|||||||
SizeLossConfig,
|
SizeLossConfig,
|
||||||
build_loss,
|
build_loss,
|
||||||
)
|
)
|
||||||
from batdetect2.train.preprocess import (
|
|
||||||
generate_train_example,
|
|
||||||
preprocess_annotations,
|
|
||||||
)
|
|
||||||
from batdetect2.train.train import (
|
from batdetect2.train.train import (
|
||||||
build_train_dataset,
|
build_train_dataset,
|
||||||
build_train_loader,
|
build_train_loader,
|
||||||
@ -74,14 +70,12 @@ __all__ = [
|
|||||||
"build_trainer",
|
"build_trainer",
|
||||||
"build_val_dataset",
|
"build_val_dataset",
|
||||||
"build_val_loader",
|
"build_val_loader",
|
||||||
"generate_train_example",
|
|
||||||
"load_full_training_config",
|
"load_full_training_config",
|
||||||
"load_label_config",
|
"load_label_config",
|
||||||
"load_train_config",
|
"load_train_config",
|
||||||
"mask_frequency",
|
"mask_frequency",
|
||||||
"mask_time",
|
"mask_time",
|
||||||
"mix_audio",
|
"mix_audio",
|
||||||
"preprocess_annotations",
|
|
||||||
"scale_volume",
|
"scale_volume",
|
||||||
"select_subclip",
|
"select_subclip",
|
||||||
"train",
|
"train",
|
||||||
|
|||||||
@ -44,7 +44,7 @@ AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]
|
|||||||
class MixAugmentationConfig(BaseConfig):
|
class MixAugmentationConfig(BaseConfig):
|
||||||
"""Configuration for MixUp augmentation (mixing two examples)."""
|
"""Configuration for MixUp augmentation (mixing two examples)."""
|
||||||
|
|
||||||
augmentation_type: Literal["mix_audio"] = "mix_audio"
|
name: Literal["mix_audio"] = "mix_audio"
|
||||||
|
|
||||||
probability: float = 0.2
|
probability: float = 0.2
|
||||||
"""Probability of applying this augmentation to an example."""
|
"""Probability of applying this augmentation to an example."""
|
||||||
@ -140,7 +140,7 @@ def combine_clip_annotations(
|
|||||||
class EchoAugmentationConfig(BaseConfig):
|
class EchoAugmentationConfig(BaseConfig):
|
||||||
"""Configuration for adding synthetic echo/reverb."""
|
"""Configuration for adding synthetic echo/reverb."""
|
||||||
|
|
||||||
augmentation_type: Literal["add_echo"] = "add_echo"
|
name: Literal["add_echo"] = "add_echo"
|
||||||
probability: float = 0.2
|
probability: float = 0.2
|
||||||
max_delay: float = 0.005
|
max_delay: float = 0.005
|
||||||
min_weight: float = 0.0
|
min_weight: float = 0.0
|
||||||
@ -187,7 +187,7 @@ def add_echo(
|
|||||||
class VolumeAugmentationConfig(BaseConfig):
|
class VolumeAugmentationConfig(BaseConfig):
|
||||||
"""Configuration for random volume scaling of the spectrogram."""
|
"""Configuration for random volume scaling of the spectrogram."""
|
||||||
|
|
||||||
augmentation_type: Literal["scale_volume"] = "scale_volume"
|
name: Literal["scale_volume"] = "scale_volume"
|
||||||
probability: float = 0.2
|
probability: float = 0.2
|
||||||
min_scaling: float = 0.0
|
min_scaling: float = 0.0
|
||||||
max_scaling: float = 2.0
|
max_scaling: float = 2.0
|
||||||
@ -214,7 +214,7 @@ def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor:
|
|||||||
|
|
||||||
|
|
||||||
class WarpAugmentationConfig(BaseConfig):
|
class WarpAugmentationConfig(BaseConfig):
|
||||||
augmentation_type: Literal["warp"] = "warp"
|
name: Literal["warp"] = "warp"
|
||||||
probability: float = 0.2
|
probability: float = 0.2
|
||||||
delta: float = 0.04
|
delta: float = 0.04
|
||||||
|
|
||||||
@ -296,7 +296,7 @@ def warp_spectrogram(
|
|||||||
|
|
||||||
|
|
||||||
class TimeMaskAugmentationConfig(BaseConfig):
|
class TimeMaskAugmentationConfig(BaseConfig):
|
||||||
augmentation_type: Literal["mask_time"] = "mask_time"
|
name: Literal["mask_time"] = "mask_time"
|
||||||
probability: float = 0.2
|
probability: float = 0.2
|
||||||
max_perc: float = 0.05
|
max_perc: float = 0.05
|
||||||
max_masks: int = 3
|
max_masks: int = 3
|
||||||
@ -353,7 +353,7 @@ def mask_time(
|
|||||||
|
|
||||||
|
|
||||||
class FrequencyMaskAugmentationConfig(BaseConfig):
|
class FrequencyMaskAugmentationConfig(BaseConfig):
|
||||||
augmentation_type: Literal["mask_freq"] = "mask_freq"
|
name: Literal["mask_freq"] = "mask_freq"
|
||||||
probability: float = 0.2
|
probability: float = 0.2
|
||||||
max_perc: float = 0.10
|
max_perc: float = 0.10
|
||||||
max_masks: int = 3
|
max_masks: int = 3
|
||||||
@ -414,7 +414,7 @@ AudioAugmentationConfig = Annotated[
|
|||||||
MixAugmentationConfig,
|
MixAugmentationConfig,
|
||||||
EchoAugmentationConfig,
|
EchoAugmentationConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="augmentation_type"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -425,7 +425,7 @@ SpectrogramAugmentationConfig = Annotated[
|
|||||||
FrequencyMaskAugmentationConfig,
|
FrequencyMaskAugmentationConfig,
|
||||||
TimeMaskAugmentationConfig,
|
TimeMaskAugmentationConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="augmentation_type"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
AugmentationConfig = Annotated[
|
AugmentationConfig = Annotated[
|
||||||
@ -437,7 +437,7 @@ AugmentationConfig = Annotated[
|
|||||||
FrequencyMaskAugmentationConfig,
|
FrequencyMaskAugmentationConfig,
|
||||||
TimeMaskAugmentationConfig,
|
TimeMaskAugmentationConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="augmentation_type"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of individual augmentation config."""
|
"""Type alias for the discriminated union of individual augmentation config."""
|
||||||
|
|
||||||
@ -485,7 +485,7 @@ def build_augmentation_from_config(
|
|||||||
audio_source: Optional[AudioSource] = None,
|
audio_source: Optional[AudioSource] = None,
|
||||||
) -> Optional[Augmentation]:
|
) -> Optional[Augmentation]:
|
||||||
"""Factory function to build a single augmentation from its config."""
|
"""Factory function to build a single augmentation from its config."""
|
||||||
if config.augmentation_type == "mix_audio":
|
if config.name == "mix_audio":
|
||||||
if audio_source is None:
|
if audio_source is None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Mix audio augmentation ('mix_audio') requires an "
|
"Mix audio augmentation ('mix_audio') requires an "
|
||||||
@ -500,31 +500,31 @@ def build_augmentation_from_config(
|
|||||||
max_weight=config.max_weight,
|
max_weight=config.max_weight,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.augmentation_type == "add_echo":
|
if config.name == "add_echo":
|
||||||
return AddEcho(
|
return AddEcho(
|
||||||
max_delay=int(config.max_delay * samplerate),
|
max_delay=int(config.max_delay * samplerate),
|
||||||
min_weight=config.min_weight,
|
min_weight=config.min_weight,
|
||||||
max_weight=config.max_weight,
|
max_weight=config.max_weight,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.augmentation_type == "scale_volume":
|
if config.name == "scale_volume":
|
||||||
return ScaleVolume(
|
return ScaleVolume(
|
||||||
max_scaling=config.max_scaling,
|
max_scaling=config.max_scaling,
|
||||||
min_scaling=config.min_scaling,
|
min_scaling=config.min_scaling,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.augmentation_type == "warp":
|
if config.name == "warp":
|
||||||
return WarpSpectrogram(
|
return WarpSpectrogram(
|
||||||
delta=config.delta,
|
delta=config.delta,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.augmentation_type == "mask_time":
|
if config.name == "mask_time":
|
||||||
return MaskTime(
|
return MaskTime(
|
||||||
max_perc=config.max_perc,
|
max_perc=config.max_perc,
|
||||||
max_masks=config.max_masks,
|
max_masks=config.max_masks,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.augmentation_type == "mask_freq":
|
if config.name == "mask_freq":
|
||||||
return MaskFrequency(
|
return MaskFrequency(
|
||||||
max_perc=config.max_perc,
|
max_perc=config.max_perc,
|
||||||
max_masks=config.max_masks,
|
max_masks=config.max_masks,
|
||||||
|
|||||||
@ -6,7 +6,6 @@ from soundevent import data
|
|||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.evaluate import EvaluationConfig
|
from batdetect2.evaluate import EvaluationConfig
|
||||||
from batdetect2.models import ModelConfig
|
from batdetect2.models import ModelConfig
|
||||||
from batdetect2.targets import TargetConfig
|
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
DEFAULT_AUGMENTATION_CONFIG,
|
DEFAULT_AUGMENTATION_CONFIG,
|
||||||
AugmentationsConfig,
|
AugmentationsConfig,
|
||||||
@ -75,7 +74,6 @@ class TrainingConfig(BaseConfig):
|
|||||||
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
||||||
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
||||||
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
|
||||||
labels: LabelConfig = Field(default_factory=LabelConfig)
|
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,9 +1,14 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import lightning as L
|
import lightning as L
|
||||||
import torch
|
import torch
|
||||||
|
from soundevent.data import PathLike
|
||||||
from torch.optim.adam import Adam
|
from torch.optim.adam import Adam
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
|
||||||
from batdetect2.models import Model
|
from batdetect2.models import Model, build_model
|
||||||
|
from batdetect2.train.config import FullTrainingConfig
|
||||||
|
from batdetect2.train.losses import build_loss
|
||||||
from batdetect2.typing import ModelOutput, TrainExample
|
from batdetect2.typing import ModelOutput, TrainExample
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -16,22 +21,28 @@ class TrainingModule(L.LightningModule):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: Model,
|
config: FullTrainingConfig,
|
||||||
loss: torch.nn.Module,
|
|
||||||
learning_rate: float = 0.001,
|
learning_rate: float = 0.001,
|
||||||
t_max: int = 100,
|
t_max: int = 100,
|
||||||
|
model: Optional[Model] = None,
|
||||||
|
loss: Optional[torch.nn.Module] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.save_hyperparameters(logger=False)
|
||||||
|
|
||||||
|
self.config = config
|
||||||
self.learning_rate = learning_rate
|
self.learning_rate = learning_rate
|
||||||
self.t_max = t_max
|
self.t_max = t_max
|
||||||
|
|
||||||
|
if loss is None:
|
||||||
|
loss = build_loss(self.config.train.loss)
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
model = build_model(self.config)
|
||||||
|
|
||||||
self.loss = loss
|
self.loss = loss
|
||||||
self.model = model
|
self.model = model
|
||||||
self.save_hyperparameters(logger=False)
|
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
|
||||||
return self.model(spec)
|
|
||||||
|
|
||||||
def training_step(self, batch: TrainExample):
|
def training_step(self, batch: TrainExample):
|
||||||
outputs = self.model.detector(batch.spec)
|
outputs = self.model.detector(batch.spec)
|
||||||
@ -59,3 +70,10 @@ class TrainingModule(L.LightningModule):
|
|||||||
optimizer = Adam(self.parameters(), lr=self.learning_rate)
|
optimizer = Adam(self.parameters(), lr=self.learning_rate)
|
||||||
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
|
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
|
||||||
return [optimizer], [scheduler]
|
return [optimizer], [scheduler]
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_from_checkpoint(
|
||||||
|
path: PathLike,
|
||||||
|
) -> Tuple[Model, FullTrainingConfig]:
|
||||||
|
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
||||||
|
return module.model, module.config
|
||||||
|
|||||||
@ -5,10 +5,11 @@ import numpy as np
|
|||||||
from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
|
from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
|
|
||||||
DEFAULT_LOGS_DIR: str = "logs"
|
DEFAULT_LOGS_DIR: str = "outputs"
|
||||||
|
|
||||||
|
|
||||||
class DVCLiveConfig(BaseConfig):
|
class DVCLiveConfig(BaseConfig):
|
||||||
@ -31,7 +32,7 @@ class CSVLoggerConfig(BaseConfig):
|
|||||||
class TensorBoardLoggerConfig(BaseConfig):
|
class TensorBoardLoggerConfig(BaseConfig):
|
||||||
logger_type: Literal["tensorboard"] = "tensorboard"
|
logger_type: Literal["tensorboard"] = "tensorboard"
|
||||||
save_dir: str = DEFAULT_LOGS_DIR
|
save_dir: str = DEFAULT_LOGS_DIR
|
||||||
name: Optional[str] = "default"
|
name: Optional[str] = "logs"
|
||||||
version: Optional[str] = None
|
version: Optional[str] = None
|
||||||
log_graph: bool = False
|
log_graph: bool = False
|
||||||
|
|
||||||
@ -57,7 +58,10 @@ LoggerConfig = Annotated[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
|
def create_dvclive_logger(
|
||||||
|
config: DVCLiveConfig,
|
||||||
|
log_dir: Optional[data.PathLike] = None,
|
||||||
|
) -> Logger:
|
||||||
try:
|
try:
|
||||||
from dvclive.lightning import DVCLiveLogger # type: ignore
|
from dvclive.lightning import DVCLiveLogger # type: ignore
|
||||||
except ImportError as error:
|
except ImportError as error:
|
||||||
@ -68,7 +72,7 @@ def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
|
|||||||
) from error
|
) from error
|
||||||
|
|
||||||
return DVCLiveLogger(
|
return DVCLiveLogger(
|
||||||
dir=config.dir,
|
dir=log_dir if log_dir is not None else config.dir,
|
||||||
run_name=config.run_name,
|
run_name=config.run_name,
|
||||||
prefix=config.prefix,
|
prefix=config.prefix,
|
||||||
log_model=config.log_model,
|
log_model=config.log_model,
|
||||||
@ -76,29 +80,38 @@ def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_csv_logger(config: CSVLoggerConfig) -> Logger:
|
def create_csv_logger(
|
||||||
|
config: CSVLoggerConfig,
|
||||||
|
log_dir: Optional[data.PathLike] = None,
|
||||||
|
) -> Logger:
|
||||||
from lightning.pytorch.loggers import CSVLogger
|
from lightning.pytorch.loggers import CSVLogger
|
||||||
|
|
||||||
return CSVLogger(
|
return CSVLogger(
|
||||||
save_dir=config.save_dir,
|
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
|
||||||
name=config.name,
|
name=config.name,
|
||||||
version=config.version,
|
version=config.version,
|
||||||
flush_logs_every_n_steps=config.flush_logs_every_n_steps,
|
flush_logs_every_n_steps=config.flush_logs_every_n_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
|
def create_tensorboard_logger(
|
||||||
|
config: TensorBoardLoggerConfig,
|
||||||
|
log_dir: Optional[data.PathLike] = None,
|
||||||
|
) -> Logger:
|
||||||
from lightning.pytorch.loggers import TensorBoardLogger
|
from lightning.pytorch.loggers import TensorBoardLogger
|
||||||
|
|
||||||
return TensorBoardLogger(
|
return TensorBoardLogger(
|
||||||
save_dir=config.save_dir,
|
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
|
||||||
name=config.name,
|
name=config.name,
|
||||||
version=config.version,
|
version=config.version,
|
||||||
log_graph=config.log_graph,
|
log_graph=config.log_graph,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
|
def create_mlflow_logger(
|
||||||
|
config: MLFlowLoggerConfig,
|
||||||
|
log_dir: Optional[data.PathLike] = None,
|
||||||
|
) -> Logger:
|
||||||
try:
|
try:
|
||||||
from lightning.pytorch.loggers import MLFlowLogger
|
from lightning.pytorch.loggers import MLFlowLogger
|
||||||
except ImportError as error:
|
except ImportError as error:
|
||||||
@ -111,7 +124,7 @@ def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
|
|||||||
return MLFlowLogger(
|
return MLFlowLogger(
|
||||||
experiment_name=config.experiment_name,
|
experiment_name=config.experiment_name,
|
||||||
run_name=config.run_name,
|
run_name=config.run_name,
|
||||||
save_dir=config.save_dir,
|
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
|
||||||
tracking_uri=config.tracking_uri,
|
tracking_uri=config.tracking_uri,
|
||||||
tags=config.tags,
|
tags=config.tags,
|
||||||
log_model=config.log_model,
|
log_model=config.log_model,
|
||||||
@ -126,7 +139,10 @@ LOGGER_FACTORY = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_logger(config: LoggerConfig) -> Logger:
|
def build_logger(
|
||||||
|
config: LoggerConfig,
|
||||||
|
log_dir: Optional[data.PathLike] = None,
|
||||||
|
) -> Logger:
|
||||||
"""
|
"""
|
||||||
Creates a logger instance from a validated Pydantic config object.
|
Creates a logger instance from a validated Pydantic config object.
|
||||||
"""
|
"""
|
||||||
@ -141,7 +157,7 @@ def build_logger(config: LoggerConfig) -> Logger:
|
|||||||
|
|
||||||
creation_func = LOGGER_FACTORY[logger_type]
|
creation_func = LOGGER_FACTORY[logger_type]
|
||||||
|
|
||||||
return creation_func(config)
|
return creation_func(config, log_dir=log_dir)
|
||||||
|
|
||||||
|
|
||||||
def get_image_plotter(logger: Logger):
|
def get_image_plotter(logger: Logger):
|
||||||
|
|||||||
@ -1,243 +0,0 @@
|
|||||||
"""Preprocesses datasets for BatDetect2 model training."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable, List, Optional, Sequence, TypedDict
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.utils.data
|
|
||||||
from loguru import logger
|
|
||||||
from pydantic import Field
|
|
||||||
from soundevent import data
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
|
||||||
from batdetect2.data.datasets import Dataset
|
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
|
||||||
from batdetect2.preprocess.audio import build_audio_loader
|
|
||||||
from batdetect2.targets import TargetConfig, build_targets
|
|
||||||
from batdetect2.train.labels import LabelConfig, build_clip_labeler
|
|
||||||
from batdetect2.typing import ClipLabeller, PreprocessorProtocol
|
|
||||||
from batdetect2.typing.preprocess import AudioLoader
|
|
||||||
from batdetect2.typing.train import PreprocessedExample
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"preprocess_annotations",
|
|
||||||
"generate_train_example",
|
|
||||||
"preprocess_dataset",
|
|
||||||
"TrainPreprocessConfig",
|
|
||||||
"load_train_preprocessing_config",
|
|
||||||
"save_preprocessed_example",
|
|
||||||
"load_preprocessed_example",
|
|
||||||
]
|
|
||||||
|
|
||||||
FilenameFn = Callable[[data.ClipAnnotation], str]
|
|
||||||
"""Type alias for a function that generates an output filename."""
|
|
||||||
|
|
||||||
|
|
||||||
class TrainPreprocessConfig(BaseConfig):
|
|
||||||
preprocess: PreprocessingConfig = Field(
|
|
||||||
default_factory=PreprocessingConfig
|
|
||||||
)
|
|
||||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
|
||||||
labels: LabelConfig = Field(default_factory=LabelConfig)
|
|
||||||
|
|
||||||
|
|
||||||
def load_train_preprocessing_config(
|
|
||||||
path: data.PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
) -> TrainPreprocessConfig:
|
|
||||||
return load_config(path=path, schema=TrainPreprocessConfig, field=field)
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_dataset(
|
|
||||||
dataset: Dataset,
|
|
||||||
config: TrainPreprocessConfig,
|
|
||||||
output: Path,
|
|
||||||
max_workers: Optional[int] = None,
|
|
||||||
) -> None:
|
|
||||||
targets = build_targets(config=config.targets)
|
|
||||||
preprocessor = build_preprocessor(config=config.preprocess)
|
|
||||||
labeller = build_clip_labeler(
|
|
||||||
targets,
|
|
||||||
min_freq=preprocessor.min_freq,
|
|
||||||
max_freq=preprocessor.max_freq,
|
|
||||||
config=config.labels,
|
|
||||||
)
|
|
||||||
audio_loader = build_audio_loader(config=config.preprocess.audio)
|
|
||||||
|
|
||||||
if not output.exists():
|
|
||||||
logger.debug("Creating directory {directory}", directory=output)
|
|
||||||
output.mkdir(parents=True)
|
|
||||||
|
|
||||||
preprocess_annotations(
|
|
||||||
dataset,
|
|
||||||
output_dir=output,
|
|
||||||
audio_loader=audio_loader,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
labeller=labeller,
|
|
||||||
max_workers=max_workers,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Example(TypedDict):
|
|
||||||
audio: torch.Tensor
|
|
||||||
spectrogram: torch.Tensor
|
|
||||||
detection_heatmap: torch.Tensor
|
|
||||||
class_heatmap: torch.Tensor
|
|
||||||
size_heatmap: torch.Tensor
|
|
||||||
|
|
||||||
|
|
||||||
def generate_train_example(
|
|
||||||
clip_annotation: data.ClipAnnotation,
|
|
||||||
audio_loader: AudioLoader,
|
|
||||||
preprocessor: PreprocessorProtocol,
|
|
||||||
labeller: ClipLabeller,
|
|
||||||
) -> PreprocessedExample:
|
|
||||||
"""Generate a complete training example for one annotation."""
|
|
||||||
wave = torch.tensor(
|
|
||||||
audio_loader.load_clip(clip_annotation.clip)
|
|
||||||
).unsqueeze(0)
|
|
||||||
spectrogram = preprocessor(wave.unsqueeze(0)).squeeze(0)
|
|
||||||
heatmaps = labeller(clip_annotation, spectrogram)
|
|
||||||
return PreprocessedExample(
|
|
||||||
audio=wave,
|
|
||||||
spectrogram=spectrogram,
|
|
||||||
detection_heatmap=heatmaps.detection,
|
|
||||||
class_heatmap=heatmaps.classes,
|
|
||||||
size_heatmap=heatmaps.size,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PreprocessingDataset(torch.utils.data.Dataset):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
clips: Dataset,
|
|
||||||
audio_loader: AudioLoader,
|
|
||||||
preprocessor: PreprocessorProtocol,
|
|
||||||
labeller: ClipLabeller,
|
|
||||||
filename_fn: FilenameFn,
|
|
||||||
output_dir: Path,
|
|
||||||
force: bool = False,
|
|
||||||
):
|
|
||||||
self.clips = clips
|
|
||||||
self.audio_loader = audio_loader
|
|
||||||
self.preprocessor = preprocessor
|
|
||||||
self.labeller = labeller
|
|
||||||
self.filename_fn = filename_fn
|
|
||||||
self.output_dir = output_dir
|
|
||||||
self.force = force
|
|
||||||
|
|
||||||
def __getitem__(self, idx) -> int:
|
|
||||||
clip_annotation = self.clips[idx]
|
|
||||||
|
|
||||||
filename = self.filename_fn(clip_annotation)
|
|
||||||
|
|
||||||
path = self.output_dir / filename
|
|
||||||
|
|
||||||
if path.exists() and not self.force:
|
|
||||||
return idx
|
|
||||||
|
|
||||||
if not path.parent.exists():
|
|
||||||
path.parent.mkdir()
|
|
||||||
|
|
||||||
example = generate_train_example(
|
|
||||||
clip_annotation,
|
|
||||||
audio_loader=self.audio_loader,
|
|
||||||
preprocessor=self.preprocessor,
|
|
||||||
labeller=self.labeller,
|
|
||||||
)
|
|
||||||
|
|
||||||
save_preprocessed_example(example, clip_annotation, path)
|
|
||||||
|
|
||||||
return idx
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self.clips)
|
|
||||||
|
|
||||||
|
|
||||||
def save_preprocessed_example(
|
|
||||||
example: PreprocessedExample,
|
|
||||||
clip_annotation: data.ClipAnnotation,
|
|
||||||
path: data.PathLike,
|
|
||||||
) -> None:
|
|
||||||
np.savez_compressed(
|
|
||||||
path,
|
|
||||||
audio=example.audio.numpy(),
|
|
||||||
spectrogram=example.spectrogram.numpy(),
|
|
||||||
detection_heatmap=example.detection_heatmap.numpy(),
|
|
||||||
class_heatmap=example.class_heatmap.numpy(),
|
|
||||||
size_heatmap=example.size_heatmap.numpy(),
|
|
||||||
clip_annotation=clip_annotation,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
|
|
||||||
item = np.load(path, mmap_mode="r+")
|
|
||||||
return PreprocessedExample(
|
|
||||||
audio=torch.tensor(item["audio"]),
|
|
||||||
spectrogram=torch.tensor(item["spectrogram"]),
|
|
||||||
size_heatmap=torch.tensor(item["size_heatmap"]),
|
|
||||||
detection_heatmap=torch.tensor(item["detection_heatmap"]),
|
|
||||||
class_heatmap=torch.tensor(item["class_heatmap"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def list_preprocessed_files(
|
|
||||||
directory: data.PathLike, extension: str = ".npz"
|
|
||||||
) -> List[Path]:
|
|
||||||
return list(Path(directory).glob(f"*{extension}"))
|
|
||||||
|
|
||||||
|
|
||||||
def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
|
|
||||||
"""Generate a default output filename based on the annotation UUID."""
|
|
||||||
return f"{clip_annotation.uuid}"
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_annotations(
|
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
|
||||||
output_dir: data.PathLike,
|
|
||||||
preprocessor: PreprocessorProtocol,
|
|
||||||
audio_loader: AudioLoader,
|
|
||||||
labeller: ClipLabeller,
|
|
||||||
filename_fn: FilenameFn = _get_filename,
|
|
||||||
max_workers: Optional[int] = None,
|
|
||||||
) -> None:
|
|
||||||
"""Preprocess a sequence of ClipAnnotations and save results to disk."""
|
|
||||||
output_dir = Path(output_dir)
|
|
||||||
|
|
||||||
if not output_dir.is_dir():
|
|
||||||
logger.info(
|
|
||||||
"Creating output directory: {output_dir}", output_dir=output_dir
|
|
||||||
)
|
|
||||||
output_dir.mkdir(parents=True)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Starting preprocessing of {num_annotations} annotations with {max_workers} workers.",
|
|
||||||
num_annotations=len(clip_annotations),
|
|
||||||
max_workers=max_workers or "all available",
|
|
||||||
)
|
|
||||||
|
|
||||||
if max_workers is None:
|
|
||||||
max_workers = os.cpu_count() or 0
|
|
||||||
|
|
||||||
dataset = PreprocessingDataset(
|
|
||||||
clips=list(clip_annotations),
|
|
||||||
audio_loader=audio_loader,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
labeller=labeller,
|
|
||||||
output_dir=Path(output_dir),
|
|
||||||
filename_fn=filename_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
loader = torch.utils.data.DataLoader(
|
|
||||||
dataset,
|
|
||||||
batch_size=1,
|
|
||||||
shuffle=False,
|
|
||||||
num_workers=max_workers,
|
|
||||||
prefetch_factor=16,
|
|
||||||
)
|
|
||||||
|
|
||||||
for _ in tqdm(loader, total=len(dataset)):
|
|
||||||
pass
|
|
||||||
@ -14,9 +14,9 @@ from batdetect2.evaluate.metrics import (
|
|||||||
ClassificationMeanAveragePrecision,
|
ClassificationMeanAveragePrecision,
|
||||||
DetectionAveragePrecision,
|
DetectionAveragePrecision,
|
||||||
)
|
)
|
||||||
from batdetect2.models import Model, build_model
|
|
||||||
from batdetect2.plotting.clips import AudioLoader, build_audio_loader
|
from batdetect2.plotting.clips import AudioLoader, build_audio_loader
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
RandomAudioSource,
|
RandomAudioSource,
|
||||||
build_augmentations,
|
build_augmentations,
|
||||||
@ -28,7 +28,6 @@ from batdetect2.train.dataset import TrainingDataset, ValidationDataset
|
|||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
from batdetect2.train.logging import build_logger
|
from batdetect2.train.logging import build_logger
|
||||||
from batdetect2.train.losses import build_loss
|
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
PreprocessorProtocol,
|
PreprocessorProtocol,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
@ -54,19 +53,21 @@ def train(
|
|||||||
model_path: Optional[data.PathLike] = None,
|
model_path: Optional[data.PathLike] = None,
|
||||||
train_workers: Optional[int] = None,
|
train_workers: Optional[int] = None,
|
||||||
val_workers: Optional[int] = None,
|
val_workers: Optional[int] = None,
|
||||||
|
checkpoint_dir: Optional[data.PathLike] = None,
|
||||||
|
log_dir: Optional[data.PathLike] = None,
|
||||||
):
|
):
|
||||||
config = config or FullTrainingConfig()
|
config = config or FullTrainingConfig()
|
||||||
|
|
||||||
model = build_model(config=config)
|
targets = build_targets(config.targets)
|
||||||
|
|
||||||
trainer = build_trainer(config, targets=model.targets)
|
preprocessor = build_preprocessor(config.preprocess)
|
||||||
|
|
||||||
audio_loader = build_audio_loader(config=config.preprocess.audio)
|
audio_loader = build_audio_loader(config=config.preprocess.audio)
|
||||||
|
|
||||||
labeller = build_clip_labeler(
|
labeller = build_clip_labeler(
|
||||||
model.targets,
|
targets,
|
||||||
min_freq=model.preprocessor.min_freq,
|
min_freq=preprocessor.min_freq,
|
||||||
max_freq=model.preprocessor.max_freq,
|
max_freq=preprocessor.max_freq,
|
||||||
config=config.train.labels,
|
config=config.train.labels,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -74,7 +75,7 @@ def train(
|
|||||||
train_annotations,
|
train_annotations,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
labeller=labeller,
|
labeller=labeller,
|
||||||
preprocessor=build_preprocessor(config.preprocess),
|
preprocessor=preprocessor,
|
||||||
config=config.train,
|
config=config.train,
|
||||||
num_workers=train_workers,
|
num_workers=train_workers,
|
||||||
)
|
)
|
||||||
@ -84,7 +85,7 @@ def train(
|
|||||||
val_annotations,
|
val_annotations,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
labeller=labeller,
|
labeller=labeller,
|
||||||
preprocessor=build_preprocessor(config.preprocess),
|
preprocessor=preprocessor,
|
||||||
config=config.train,
|
config=config.train,
|
||||||
num_workers=val_workers,
|
num_workers=val_workers,
|
||||||
)
|
)
|
||||||
@ -97,11 +98,17 @@ def train(
|
|||||||
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
||||||
else:
|
else:
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
model,
|
|
||||||
config,
|
config,
|
||||||
batches_per_epoch=len(train_dataloader),
|
t_max=config.train.t_max * len(train_dataloader),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
trainer = build_trainer(
|
||||||
|
config,
|
||||||
|
targets=targets,
|
||||||
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
log_dir=log_dir,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Starting main training loop...")
|
logger.info("Starting main training loop...")
|
||||||
trainer.fit(
|
trainer.fit(
|
||||||
module,
|
module,
|
||||||
@ -112,16 +119,14 @@ def train(
|
|||||||
|
|
||||||
|
|
||||||
def build_training_module(
|
def build_training_module(
|
||||||
model: Model,
|
config: Optional[FullTrainingConfig] = None,
|
||||||
config: FullTrainingConfig,
|
t_max: int = 200,
|
||||||
batches_per_epoch: int,
|
|
||||||
) -> TrainingModule:
|
) -> TrainingModule:
|
||||||
loss = build_loss(config=config.train.loss)
|
config = config or FullTrainingConfig()
|
||||||
return TrainingModule(
|
return TrainingModule(
|
||||||
model=model,
|
config=config,
|
||||||
loss=loss,
|
|
||||||
learning_rate=config.train.learning_rate,
|
learning_rate=config.train.learning_rate,
|
||||||
t_max=config.train.t_max * batches_per_epoch,
|
t_max=t_max,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -129,10 +134,14 @@ def build_trainer_callbacks(
|
|||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
config: EvaluationConfig,
|
config: EvaluationConfig,
|
||||||
|
checkpoint_dir: Optional[data.PathLike] = None,
|
||||||
) -> List[Callback]:
|
) -> List[Callback]:
|
||||||
|
if checkpoint_dir is None:
|
||||||
|
checkpoint_dir = "outputs/checkpoints"
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ModelCheckpoint(
|
ModelCheckpoint(
|
||||||
dirpath="outputs/checkpoints",
|
dirpath=str(checkpoint_dir),
|
||||||
save_top_k=1,
|
save_top_k=1,
|
||||||
monitor="total_loss/val",
|
monitor="total_loss/val",
|
||||||
),
|
),
|
||||||
@ -153,15 +162,22 @@ def build_trainer_callbacks(
|
|||||||
def build_trainer(
|
def build_trainer(
|
||||||
conf: FullTrainingConfig,
|
conf: FullTrainingConfig,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
|
checkpoint_dir: Optional[data.PathLike] = None,
|
||||||
|
log_dir: Optional[data.PathLike] = None,
|
||||||
) -> Trainer:
|
) -> Trainer:
|
||||||
trainer_conf = conf.train.trainer
|
trainer_conf = conf.train.trainer
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Building trainer with config: \n{config}",
|
"Building trainer with config: \n{config}",
|
||||||
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
||||||
)
|
)
|
||||||
train_logger = build_logger(conf.train.logger)
|
train_logger = build_logger(conf.train.logger, log_dir=log_dir)
|
||||||
|
|
||||||
train_logger.log_hyperparams(conf.model_dump(mode="json"))
|
train_logger.log_hyperparams(
|
||||||
|
conf.model_dump(
|
||||||
|
mode="json",
|
||||||
|
exclude_none=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return Trainer(
|
return Trainer(
|
||||||
**trainer_conf.model_dump(exclude_none=True),
|
**trainer_conf.model_dump(exclude_none=True),
|
||||||
@ -170,6 +186,7 @@ def build_trainer(
|
|||||||
targets,
|
targets,
|
||||||
config=conf.evaluation,
|
config=conf.evaluation,
|
||||||
preprocessor=build_preprocessor(conf.preprocess),
|
preprocessor=build_preprocessor(conf.preprocess),
|
||||||
|
checkpoint_dir=checkpoint_dir,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -67,8 +67,7 @@ class TargetProtocol(Protocol):
|
|||||||
This protocol outlines the standard attributes and methods for an object
|
This protocol outlines the standard attributes and methods for an object
|
||||||
that encapsulates the complete, configured process for handling sound event
|
that encapsulates the complete, configured process for handling sound event
|
||||||
annotations (both tags and geometry). It defines how to:
|
annotations (both tags and geometry). It defines how to:
|
||||||
- Filter relevant annotations.
|
- Select relevant annotations.
|
||||||
- Transform annotation tags.
|
|
||||||
- Encode an annotation into a specific target class name.
|
- Encode an annotation into a specific target class name.
|
||||||
- Decode a class name back into representative tags.
|
- Decode a class name back into representative tags.
|
||||||
- Extract a target reference position from an annotation's geometry (ROI).
|
- Extract a target reference position from an annotation's geometry (ROI).
|
||||||
@ -121,26 +120,6 @@ class TargetProtocol(Protocol):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def transform(
|
|
||||||
self,
|
|
||||||
sound_event: data.SoundEventAnnotation,
|
|
||||||
) -> data.SoundEventAnnotation:
|
|
||||||
"""Apply tag transformations to an annotation.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event : data.SoundEventAnnotation
|
|
||||||
The annotation whose tags should be transformed.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
data.SoundEventAnnotation
|
|
||||||
A new annotation object with the transformed tags. Implementations
|
|
||||||
should return the original annotation object if no transformations
|
|
||||||
were configured.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def encode_class(
|
def encode_class(
|
||||||
self,
|
self,
|
||||||
sound_event: data.SoundEventAnnotation,
|
sound_event: data.SoundEventAnnotation,
|
||||||
@ -248,3 +227,70 @@ class TargetProtocol(Protocol):
|
|||||||
if reconstruction fails based on the configured position type.
|
if reconstruction fails based on the configured position type.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class ROITargetMapper(Protocol):
|
||||||
|
"""Protocol defining the interface for ROI-to-target mapping.
|
||||||
|
|
||||||
|
Specifies the `encode` and `decode` methods required for converting a
|
||||||
|
`soundevent.data.SoundEvent` into a target representation (a reference
|
||||||
|
position and a size vector) and for recovering an approximate ROI from that
|
||||||
|
representation.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
dimension_names : List[str]
|
||||||
|
A list containing the names of the dimensions in the `Size` array
|
||||||
|
returned by `encode` and expected by `decode`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dimension_names: List[str]
|
||||||
|
|
||||||
|
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
|
||||||
|
"""Encode a SoundEvent's geometry into a position and size.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event : data.SoundEvent
|
||||||
|
The input sound event, which must have a geometry attribute.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tuple[Position, Size]
|
||||||
|
A tuple containing:
|
||||||
|
- The reference position as (time, frequency) coordinates.
|
||||||
|
- A NumPy array with the calculated size dimensions.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the sound event does not have a geometry.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def decode(self, position: Position, size: Size) -> data.Geometry:
|
||||||
|
"""Decode a position and size back into a geometric ROI.
|
||||||
|
|
||||||
|
Performs the inverse mapping: takes a reference position and size
|
||||||
|
dimensions and reconstructs a geometric representation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
position : Position
|
||||||
|
The reference position (time, frequency).
|
||||||
|
size : Size
|
||||||
|
NumPy array containing the size dimensions, matching the order
|
||||||
|
and meaning specified by `dimension_names`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
soundevent.data.Geometry
|
||||||
|
The reconstructed geometry, typically a `BoundingBox`.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the `size` array has an unexpected shape or if reconstruction
|
||||||
|
fails.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|||||||
@ -15,13 +15,10 @@ from batdetect2.preprocess import build_preprocessor
|
|||||||
from batdetect2.preprocess.audio import build_audio_loader
|
from batdetect2.preprocess.audio import build_audio_loader
|
||||||
from batdetect2.targets import (
|
from batdetect2.targets import (
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
TermRegistry,
|
|
||||||
build_targets,
|
build_targets,
|
||||||
call_type,
|
call_type,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.classes import ClassesConfig, TargetClass
|
from batdetect2.targets.classes import TargetClassConfig
|
||||||
from batdetect2.targets.filtering import FilterConfig, FilterRule
|
|
||||||
from batdetect2.targets.terms import TagInfo
|
|
||||||
from batdetect2.train.clips import build_clipper
|
from batdetect2.train.clips import build_clipper
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
@ -355,18 +352,6 @@ def create_annotation_project():
|
|||||||
return factory
|
return factory
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_term_registry() -> TermRegistry:
|
|
||||||
"""Fixture for a sample TermRegistry."""
|
|
||||||
registry = TermRegistry()
|
|
||||||
registry.add_custom_term("class")
|
|
||||||
registry.add_custom_term("order")
|
|
||||||
registry.add_custom_term("species")
|
|
||||||
registry.add_custom_term("call_type")
|
|
||||||
registry.add_custom_term("quality")
|
|
||||||
return registry
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_preprocessor() -> PreprocessorProtocol:
|
def sample_preprocessor() -> PreprocessorProtocol:
|
||||||
return build_preprocessor()
|
return build_preprocessor()
|
||||||
@ -378,56 +363,45 @@ def sample_audio_loader() -> AudioLoader:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def bat_tag() -> TagInfo:
|
def bat_tag() -> data.Tag:
|
||||||
return TagInfo(key="class", value="bat")
|
return data.Tag(key="class", value="bat")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def noise_tag() -> TagInfo:
|
def noise_tag() -> data.Tag:
|
||||||
return TagInfo(key="class", value="noise")
|
return data.Tag(key="class", value="noise")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def myomyo_tag() -> TagInfo:
|
def myomyo_tag() -> data.Tag:
|
||||||
return TagInfo(key="species", value="Myotis myotis")
|
return data.Tag(key="species", value="Myotis myotis")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def pippip_tag() -> TagInfo:
|
def pippip_tag() -> data.Tag:
|
||||||
return TagInfo(key="species", value="Pipistrellus pipistrellus")
|
return data.Tag(key="species", value="Pipistrellus pipistrellus")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_target_config(
|
def sample_target_config(
|
||||||
sample_term_registry: TermRegistry,
|
bat_tag: data.Tag,
|
||||||
bat_tag: TagInfo,
|
myomyo_tag: data.Tag,
|
||||||
noise_tag: TagInfo,
|
pippip_tag: data.Tag,
|
||||||
myomyo_tag: TagInfo,
|
|
||||||
pippip_tag: TagInfo,
|
|
||||||
) -> TargetConfig:
|
) -> TargetConfig:
|
||||||
return TargetConfig(
|
return TargetConfig(
|
||||||
filtering=FilterConfig(
|
detection_target=TargetClassConfig(name="bat", tags=[bat_tag]),
|
||||||
rules=[FilterRule(match_type="exclude", tags=[noise_tag])]
|
classification_targets=[
|
||||||
),
|
TargetClassConfig(name="pippip", tags=[pippip_tag]),
|
||||||
classes=ClassesConfig(
|
TargetClassConfig(name="myomyo", tags=[myomyo_tag]),
|
||||||
classes=[
|
],
|
||||||
TargetClass(name="pippip", tags=[pippip_tag]),
|
|
||||||
TargetClass(name="myomyo", tags=[myomyo_tag]),
|
|
||||||
],
|
|
||||||
generic_class=[bat_tag],
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_targets(
|
def sample_targets(
|
||||||
sample_target_config: TargetConfig,
|
sample_target_config: TargetConfig,
|
||||||
sample_term_registry: TermRegistry,
|
|
||||||
) -> TargetProtocol:
|
) -> TargetProtocol:
|
||||||
return build_targets(
|
return build_targets(sample_target_config)
|
||||||
sample_target_config,
|
|
||||||
term_registry=sample_term_registry,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -443,10 +417,8 @@ def sample_labeller(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_clipper(
|
def sample_clipper() -> ClipperProtocol:
|
||||||
sample_preprocessor: PreprocessorProtocol,
|
return build_clipper()
|
||||||
) -> ClipperProtocol:
|
|
||||||
return build_clipper(preprocessor=sample_preprocessor)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
0
tests/test_data/test_transforms/__init__.py
Normal file
0
tests/test_data/test_transforms/__init__.py
Normal file
516
tests/test_data/test_transforms/test_conditions.py
Normal file
516
tests/test_data/test_transforms/test_conditions.py
Normal file
@ -0,0 +1,516 @@
|
|||||||
|
import textwrap
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.data.conditions import (
|
||||||
|
SoundEventConditionConfig,
|
||||||
|
build_sound_event_condition,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_condition_from_str(content):
|
||||||
|
content = textwrap.dedent(content)
|
||||||
|
content = yaml.safe_load(content)
|
||||||
|
config = TypeAdapter(SoundEventConditionConfig).validate_python(content)
|
||||||
|
return build_sound_event_condition(config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_tag(sound_event: data.SoundEvent):
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: has_tag
|
||||||
|
tag:
|
||||||
|
key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
""")
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||||
|
)
|
||||||
|
assert condition(sound_event_annotation)
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
||||||
|
)
|
||||||
|
assert not condition(sound_event_annotation)
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_all_tags(sound_event: data.SoundEvent):
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: has_all_tags
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
- key: event
|
||||||
|
value: Echolocation
|
||||||
|
""")
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||||
|
)
|
||||||
|
assert not condition(sound_event_annotation)
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[
|
||||||
|
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
|
||||||
|
data.Tag(key="event", value="Echolocation"), # type: ignore
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert not condition(sound_event_annotation)
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[
|
||||||
|
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
||||||
|
data.Tag(key="event", value="Echolocation"), # type: ignore
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert condition(sound_event_annotation)
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[
|
||||||
|
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
||||||
|
data.Tag(key="event", value="Echolocation"), # type: ignore
|
||||||
|
data.Tag(key="sex", value="Female"), # type: ignore
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert condition(sound_event_annotation)
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_any_tags(sound_event: data.SoundEvent):
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: has_any_tag
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
- key: event
|
||||||
|
value: Echolocation
|
||||||
|
""")
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||||
|
)
|
||||||
|
assert condition(sound_event_annotation)
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[
|
||||||
|
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
|
||||||
|
data.Tag(key="event", value="Echolocation"), # type: ignore
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert condition(sound_event_annotation)
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[
|
||||||
|
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
||||||
|
data.Tag(key="event", value="Echolocation"), # type: ignore
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert condition(sound_event_annotation)
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[
|
||||||
|
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
|
||||||
|
data.Tag(key="event", value="Social"), # type: ignore
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert not condition(sound_event_annotation)
|
||||||
|
|
||||||
|
|
||||||
|
def test_not(sound_event: data.SoundEvent):
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: not
|
||||||
|
condition:
|
||||||
|
name: has_tag
|
||||||
|
tag:
|
||||||
|
key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
""")
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||||
|
)
|
||||||
|
assert not condition(sound_event_annotation)
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
||||||
|
)
|
||||||
|
assert condition(sound_event_annotation)
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[
|
||||||
|
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
||||||
|
data.Tag(key="event", value="Echolocation"), # type: ignore
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert not condition(sound_event_annotation)
|
||||||
|
|
||||||
|
|
||||||
|
def test_duration(recording: data.Recording):
|
||||||
|
se1 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
recording=recording, geometry=data.TimeInterval(coordinates=[0, 1])
|
||||||
|
),
|
||||||
|
)
|
||||||
|
se2 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
recording=recording, geometry=data.TimeInterval(coordinates=[0, 2])
|
||||||
|
),
|
||||||
|
)
|
||||||
|
se3 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
recording=recording, geometry=data.TimeInterval(coordinates=[0, 3])
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: duration
|
||||||
|
operator: lt
|
||||||
|
seconds: 2
|
||||||
|
""")
|
||||||
|
assert condition(se1)
|
||||||
|
assert not condition(se2)
|
||||||
|
assert not condition(se3)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: duration
|
||||||
|
operator: lte
|
||||||
|
seconds: 2
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert condition(se1)
|
||||||
|
assert condition(se2)
|
||||||
|
assert not condition(se3)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: duration
|
||||||
|
operator: gt
|
||||||
|
seconds: 2
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert not condition(se1)
|
||||||
|
assert not condition(se2)
|
||||||
|
assert condition(se3)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: duration
|
||||||
|
operator: gte
|
||||||
|
seconds: 2
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert not condition(se1)
|
||||||
|
assert condition(se2)
|
||||||
|
assert condition(se3)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: duration
|
||||||
|
operator: eq
|
||||||
|
seconds: 2
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert not condition(se1)
|
||||||
|
assert condition(se2)
|
||||||
|
assert not condition(se3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_frequency(recording: data.Recording):
|
||||||
|
se12 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
recording=recording,
|
||||||
|
geometry=data.BoundingBox(coordinates=[0, 100, 1, 200]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
se13 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
recording=recording,
|
||||||
|
geometry=data.BoundingBox(coordinates=[0, 100, 2, 300]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
se14 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
recording=recording,
|
||||||
|
geometry=data.BoundingBox(coordinates=[0, 100, 3, 400]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
se24 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
recording=recording,
|
||||||
|
geometry=data.BoundingBox(coordinates=[0, 200, 3, 400]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
se34 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
recording=recording,
|
||||||
|
geometry=data.BoundingBox(coordinates=[0, 300, 3, 400]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: high
|
||||||
|
operator: lt
|
||||||
|
hertz: 300
|
||||||
|
""")
|
||||||
|
assert condition(se12)
|
||||||
|
assert not condition(se13)
|
||||||
|
assert not condition(se14)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: high
|
||||||
|
operator: lte
|
||||||
|
hertz: 300
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert condition(se12)
|
||||||
|
assert condition(se13)
|
||||||
|
assert not condition(se14)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: high
|
||||||
|
operator: gt
|
||||||
|
hertz: 300
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert not condition(se12)
|
||||||
|
assert not condition(se13)
|
||||||
|
assert condition(se14)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: high
|
||||||
|
operator: gte
|
||||||
|
hertz: 300
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert not condition(se12)
|
||||||
|
assert condition(se13)
|
||||||
|
assert condition(se14)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: high
|
||||||
|
operator: eq
|
||||||
|
hertz: 300
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert not condition(se12)
|
||||||
|
assert condition(se13)
|
||||||
|
assert not condition(se14)
|
||||||
|
|
||||||
|
# LOW
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: low
|
||||||
|
operator: lt
|
||||||
|
hertz: 200
|
||||||
|
""")
|
||||||
|
assert condition(se14)
|
||||||
|
assert not condition(se24)
|
||||||
|
assert not condition(se34)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: low
|
||||||
|
operator: lte
|
||||||
|
hertz: 200
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert condition(se14)
|
||||||
|
assert condition(se24)
|
||||||
|
assert not condition(se34)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: low
|
||||||
|
operator: gt
|
||||||
|
hertz: 200
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert not condition(se14)
|
||||||
|
assert not condition(se24)
|
||||||
|
assert condition(se34)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: low
|
||||||
|
operator: gte
|
||||||
|
hertz: 200
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert not condition(se14)
|
||||||
|
assert condition(se24)
|
||||||
|
assert condition(se34)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: low
|
||||||
|
operator: eq
|
||||||
|
hertz: 200
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert not condition(se14)
|
||||||
|
assert condition(se24)
|
||||||
|
assert not condition(se34)
|
||||||
|
|
||||||
|
|
||||||
|
def test_frequency_is_false_for_temporal_geometries(recording: data.Recording):
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: low
|
||||||
|
operator: eq
|
||||||
|
hertz: 200
|
||||||
|
""")
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.TimeInterval(coordinates=[0, 3]),
|
||||||
|
recording=recording,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert not condition(se)
|
||||||
|
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.TimeStamp(coordinates=3),
|
||||||
|
recording=recording,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert not condition(se)
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_tags_fails_if_empty():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
build_condition_from_str("""
|
||||||
|
name: has_tags
|
||||||
|
tags: []
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def test_frequency_is_false_if_no_geometry(recording: data.Recording):
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: low
|
||||||
|
operator: eq
|
||||||
|
hertz: 200
|
||||||
|
""")
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(geometry=None, recording=recording)
|
||||||
|
)
|
||||||
|
assert not condition(se)
|
||||||
|
|
||||||
|
|
||||||
|
def test_duration_is_false_if_no_geometry(recording: data.Recording):
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: duration
|
||||||
|
operator: eq
|
||||||
|
seconds: 1
|
||||||
|
""")
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(geometry=None, recording=recording)
|
||||||
|
)
|
||||||
|
assert not condition(se)
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_of(recording: data.Recording):
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: all_of
|
||||||
|
conditions:
|
||||||
|
- name: has_tag
|
||||||
|
tag:
|
||||||
|
key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
- name: duration
|
||||||
|
operator: lt
|
||||||
|
seconds: 1
|
||||||
|
""")
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||||
|
recording=recording,
|
||||||
|
),
|
||||||
|
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||||
|
)
|
||||||
|
assert condition(se)
|
||||||
|
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.TimeInterval(coordinates=[0, 2]),
|
||||||
|
recording=recording,
|
||||||
|
),
|
||||||
|
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||||
|
)
|
||||||
|
assert not condition(se)
|
||||||
|
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||||
|
recording=recording,
|
||||||
|
),
|
||||||
|
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
||||||
|
)
|
||||||
|
assert not condition(se)
|
||||||
|
|
||||||
|
|
||||||
|
def test_any_of(recording: data.Recording):
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: any_of
|
||||||
|
conditions:
|
||||||
|
- name: has_tag
|
||||||
|
tag:
|
||||||
|
key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
- name: duration
|
||||||
|
operator: lt
|
||||||
|
seconds: 1
|
||||||
|
""")
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.TimeInterval(coordinates=[0, 2]),
|
||||||
|
recording=recording,
|
||||||
|
),
|
||||||
|
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
||||||
|
)
|
||||||
|
assert not condition(se)
|
||||||
|
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||||
|
recording=recording,
|
||||||
|
),
|
||||||
|
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||||
|
)
|
||||||
|
assert condition(se)
|
||||||
|
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.TimeInterval(coordinates=[0, 2]),
|
||||||
|
recording=recording,
|
||||||
|
),
|
||||||
|
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||||
|
)
|
||||||
|
assert condition(se)
|
||||||
|
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||||
|
recording=recording,
|
||||||
|
),
|
||||||
|
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
||||||
|
)
|
||||||
|
assert condition(se)
|
||||||
@ -3,26 +3,15 @@ from typing import Callable
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.terms import get_term
|
from soundevent.terms import get_term
|
||||||
|
|
||||||
from batdetect2.targets.classes import (
|
from batdetect2.targets.classes import (
|
||||||
DEFAULT_SPECIES_LIST,
|
TargetClassConfig,
|
||||||
ClassesConfig,
|
|
||||||
TargetClass,
|
|
||||||
_get_default_class_name,
|
|
||||||
_get_default_classes,
|
|
||||||
build_generic_class_tags,
|
|
||||||
build_sound_event_decoder,
|
build_sound_event_decoder,
|
||||||
build_sound_event_encoder,
|
build_sound_event_encoder,
|
||||||
get_class_names_from_config,
|
get_class_names_from_config,
|
||||||
is_target_class,
|
|
||||||
load_classes_config,
|
|
||||||
load_decoder_from_config,
|
|
||||||
load_encoder_from_config,
|
|
||||||
)
|
)
|
||||||
from batdetect2.targets.terms import TagInfo
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -33,8 +22,8 @@ def sample_annotation(
|
|||||||
return data.SoundEventAnnotation(
|
return data.SoundEventAnnotation(
|
||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
data.Tag(key="species", value="Pipistrellus pipistrellus"),
|
||||||
data.Tag(key="quality", value="Good"), # type: ignore
|
data.Tag(key="quality", value="Good"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -51,291 +40,71 @@ def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
|
|||||||
return factory
|
return factory
|
||||||
|
|
||||||
|
|
||||||
def test_target_class_creation():
|
|
||||||
target_class = TargetClass(
|
|
||||||
name="pippip",
|
|
||||||
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
|
||||||
)
|
|
||||||
assert target_class.name == "pippip"
|
|
||||||
assert target_class.tags[0].key == "species"
|
|
||||||
assert target_class.tags[0].value == "Pipistrellus pipistrellus"
|
|
||||||
assert target_class.match_type == "all"
|
|
||||||
|
|
||||||
|
|
||||||
def test_classes_config_creation():
|
|
||||||
target_class = TargetClass(
|
|
||||||
name="pippip",
|
|
||||||
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
|
||||||
)
|
|
||||||
config = ClassesConfig(classes=[target_class])
|
|
||||||
assert len(config.classes) == 1
|
|
||||||
assert config.classes[0].name == "pippip"
|
|
||||||
|
|
||||||
|
|
||||||
def test_classes_config_unique_names():
|
|
||||||
target_class1 = TargetClass(
|
|
||||||
name="pippip",
|
|
||||||
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
|
||||||
)
|
|
||||||
target_class2 = TargetClass(
|
|
||||||
name="myodau",
|
|
||||||
tags=[TagInfo(key="species", value="Myotis daubentonii")],
|
|
||||||
)
|
|
||||||
ClassesConfig(classes=[target_class1, target_class2]) # No error
|
|
||||||
|
|
||||||
|
|
||||||
def test_classes_config_non_unique_names():
|
|
||||||
target_class1 = TargetClass(
|
|
||||||
name="pippip",
|
|
||||||
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
|
||||||
)
|
|
||||||
target_class2 = TargetClass(
|
|
||||||
name="pippip",
|
|
||||||
tags=[TagInfo(key="species", value="Myotis daubentonii")],
|
|
||||||
)
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
ClassesConfig(classes=[target_class1, target_class2])
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_classes_config_valid(create_temp_yaml: Callable[[str], Path]):
|
|
||||||
yaml_content = """
|
|
||||||
classes:
|
|
||||||
- name: pippip
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Pipistrellus pipistrellus
|
|
||||||
"""
|
|
||||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
|
||||||
config = load_classes_config(temp_yaml_path)
|
|
||||||
assert len(config.classes) == 1
|
|
||||||
assert config.classes[0].name == "pippip"
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_classes_config_invalid(create_temp_yaml: Callable[[str], Path]):
|
|
||||||
yaml_content = """
|
|
||||||
classes:
|
|
||||||
- name: pippip
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Pipistrellus pipistrellus
|
|
||||||
- name: pippip
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Myotis daubentonii
|
|
||||||
"""
|
|
||||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
load_classes_config(temp_yaml_path)
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_target_class_match_all(
|
|
||||||
sample_annotation: data.SoundEventAnnotation,
|
|
||||||
):
|
|
||||||
tags = {
|
|
||||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
|
||||||
data.Tag(key="quality", value="Good"), # type: ignore
|
|
||||||
}
|
|
||||||
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
|
||||||
|
|
||||||
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
|
|
||||||
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
|
||||||
|
|
||||||
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
|
|
||||||
assert is_target_class(sample_annotation, tags, match_all=True) is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_target_class_match_any(
|
|
||||||
sample_annotation: data.SoundEventAnnotation,
|
|
||||||
):
|
|
||||||
tags = {
|
|
||||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
|
||||||
data.Tag(key="quality", value="Good"), # type: ignore
|
|
||||||
}
|
|
||||||
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
|
||||||
|
|
||||||
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
|
|
||||||
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
|
||||||
|
|
||||||
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
|
|
||||||
assert is_target_class(sample_annotation, tags, match_all=False) is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_class_names_from_config():
|
def test_get_class_names_from_config():
|
||||||
target_class1 = TargetClass(
|
target_class1 = TargetClassConfig(
|
||||||
name="pippip",
|
name="pippip",
|
||||||
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
|
||||||
)
|
)
|
||||||
target_class2 = TargetClass(
|
target_class2 = TargetClassConfig(
|
||||||
name="myodau",
|
name="myodau",
|
||||||
tags=[TagInfo(key="species", value="Myotis daubentonii")],
|
tags=[data.Tag(key="species", value="Myotis daubentonii")],
|
||||||
)
|
)
|
||||||
config = ClassesConfig(classes=[target_class1, target_class2])
|
names = get_class_names_from_config([target_class1, target_class2])
|
||||||
names = get_class_names_from_config(config)
|
|
||||||
assert names == ["pippip", "myodau"]
|
assert names == ["pippip", "myodau"]
|
||||||
|
|
||||||
|
|
||||||
def test_build_encoder_from_config(
|
def test_build_encoder_from_config(
|
||||||
sample_annotation: data.SoundEventAnnotation,
|
sample_annotation: data.SoundEventAnnotation,
|
||||||
):
|
):
|
||||||
config = ClassesConfig(
|
classes = [
|
||||||
classes=[
|
TargetClassConfig(
|
||||||
TargetClass(
|
name="pippip",
|
||||||
name="pippip",
|
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
|
||||||
tags=[
|
)
|
||||||
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
]
|
||||||
],
|
encoder = build_sound_event_encoder(classes)
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
encoder = build_sound_event_encoder(config)
|
|
||||||
result = encoder(sample_annotation)
|
result = encoder(sample_annotation)
|
||||||
assert result == "pippip"
|
assert result == "pippip"
|
||||||
|
|
||||||
config = ClassesConfig(classes=[])
|
classes = []
|
||||||
encoder = build_sound_event_encoder(config)
|
encoder = build_sound_event_encoder(classes)
|
||||||
result = encoder(sample_annotation)
|
result = encoder(sample_annotation)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
def test_load_encoder_from_config_valid(
|
|
||||||
sample_annotation: data.SoundEventAnnotation,
|
|
||||||
create_temp_yaml: Callable[[str], Path],
|
|
||||||
):
|
|
||||||
yaml_content = """
|
|
||||||
classes:
|
|
||||||
- name: pippip
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Pipistrellus pipistrellus
|
|
||||||
"""
|
|
||||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
|
||||||
encoder = load_encoder_from_config(temp_yaml_path)
|
|
||||||
# We cannot directly compare the function, so we test it.
|
|
||||||
result = encoder(sample_annotation) # type: ignore
|
|
||||||
assert result == "pippip"
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_encoder_from_config_invalid(
|
|
||||||
create_temp_yaml: Callable[[str], Path],
|
|
||||||
):
|
|
||||||
yaml_content = """
|
|
||||||
classes:
|
|
||||||
- name: pippip
|
|
||||||
tags:
|
|
||||||
- key: invalid_key
|
|
||||||
value: Pipistrellus pipistrellus
|
|
||||||
"""
|
|
||||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
load_encoder_from_config(temp_yaml_path)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_default_class_name():
|
|
||||||
assert _get_default_class_name("Myotis daubentonii") == "myodau"
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_default_classes():
|
|
||||||
default_classes = _get_default_classes()
|
|
||||||
assert len(default_classes) == len(DEFAULT_SPECIES_LIST)
|
|
||||||
first_class = default_classes[0]
|
|
||||||
assert isinstance(first_class, TargetClass)
|
|
||||||
assert first_class.name == _get_default_class_name(DEFAULT_SPECIES_LIST[0])
|
|
||||||
assert first_class.tags[0].key == "class"
|
|
||||||
assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0]
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_decoder_from_config():
|
def test_build_decoder_from_config():
|
||||||
config = ClassesConfig(
|
classes = [
|
||||||
classes=[
|
TargetClassConfig(
|
||||||
TargetClass(
|
name="pippip",
|
||||||
name="pippip",
|
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
|
||||||
tags=[
|
assign_tags=[data.Tag(key="call_type", value="Echolocation")],
|
||||||
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
)
|
||||||
],
|
]
|
||||||
output_tags=[TagInfo(key="call_type", value="Echolocation")],
|
decoder = build_sound_event_decoder(classes)
|
||||||
)
|
|
||||||
],
|
|
||||||
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
|
||||||
)
|
|
||||||
decoder = build_sound_event_decoder(config)
|
|
||||||
tags = decoder("pippip")
|
tags = decoder("pippip")
|
||||||
assert len(tags) == 1
|
assert len(tags) == 1
|
||||||
assert tags[0].term == get_term("event")
|
assert tags[0].term == get_term("event")
|
||||||
assert tags[0].value == "Echolocation"
|
assert tags[0].value == "Echolocation"
|
||||||
|
|
||||||
# Test when output_tags is None, should fall back to tags
|
# Test when output_tags is None, should fall back to tags
|
||||||
config = ClassesConfig(
|
classes = [
|
||||||
classes=[
|
TargetClassConfig(
|
||||||
TargetClass(
|
name="pippip",
|
||||||
name="pippip",
|
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
|
||||||
tags=[
|
)
|
||||||
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
]
|
||||||
],
|
decoder = build_sound_event_decoder(classes)
|
||||||
)
|
|
||||||
],
|
|
||||||
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
|
||||||
)
|
|
||||||
decoder = build_sound_event_decoder(config)
|
|
||||||
tags = decoder("pippip")
|
tags = decoder("pippip")
|
||||||
assert len(tags) == 1
|
assert len(tags) == 1
|
||||||
assert tags[0].term == get_term("species")
|
assert tags[0].term == get_term("species")
|
||||||
assert tags[0].value == "Pipistrellus pipistrellus"
|
assert tags[0].value == "Pipistrellus pipistrellus"
|
||||||
|
|
||||||
# Test raise_on_unmapped=True
|
# Test raise_on_unmapped=True
|
||||||
decoder = build_sound_event_decoder(config, raise_on_unmapped=True)
|
decoder = build_sound_event_decoder(classes, raise_on_unmapped=True)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
decoder("unknown_class")
|
decoder("unknown_class")
|
||||||
|
|
||||||
# Test raise_on_unmapped=False
|
# Test raise_on_unmapped=False
|
||||||
decoder = build_sound_event_decoder(config, raise_on_unmapped=False)
|
decoder = build_sound_event_decoder(classes, raise_on_unmapped=False)
|
||||||
tags = decoder("unknown_class")
|
tags = decoder("unknown_class")
|
||||||
assert len(tags) == 0
|
assert len(tags) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_load_decoder_from_config_valid(
|
|
||||||
create_temp_yaml: Callable[[str], Path],
|
|
||||||
):
|
|
||||||
yaml_content = """
|
|
||||||
classes:
|
|
||||||
- name: pippip
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Pipistrellus pipistrellus
|
|
||||||
output_tags:
|
|
||||||
- key: call_type
|
|
||||||
value: Echolocation
|
|
||||||
generic_class:
|
|
||||||
- key: order
|
|
||||||
value: Chiroptera
|
|
||||||
"""
|
|
||||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
|
||||||
decoder = load_decoder_from_config(
|
|
||||||
temp_yaml_path,
|
|
||||||
)
|
|
||||||
tags = decoder("pippip")
|
|
||||||
assert len(tags) == 1
|
|
||||||
assert tags[0].term == get_term("call_type")
|
|
||||||
assert tags[0].value == "Echolocation"
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_generic_class_tags_from_config():
|
|
||||||
config = ClassesConfig(
|
|
||||||
classes=[
|
|
||||||
TargetClass(
|
|
||||||
name="pippip",
|
|
||||||
tags=[
|
|
||||||
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
|
||||||
],
|
|
||||||
)
|
|
||||||
],
|
|
||||||
generic_class=[
|
|
||||||
TagInfo(key="order", value="Chiroptera"),
|
|
||||||
TagInfo(key="call_type", value="Echolocation"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
generic_tags = build_generic_class_tags(config)
|
|
||||||
assert len(generic_tags) == 2
|
|
||||||
assert generic_tags[0].term == get_term("order")
|
|
||||||
assert generic_tags[0].value == "Chiroptera"
|
|
||||||
assert generic_tags[1].term == get_term("call_type")
|
|
||||||
assert generic_tags[1].value == "Echolocation"
|
|
||||||
|
|||||||
@ -1,210 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
from typing import Callable, List, Set
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.targets import build_targets
|
|
||||||
from batdetect2.targets.filtering import (
|
|
||||||
FilterConfig,
|
|
||||||
FilterRule,
|
|
||||||
build_filter_from_rule,
|
|
||||||
build_sound_event_filter,
|
|
||||||
contains_tags,
|
|
||||||
does_not_have_tags,
|
|
||||||
equal_tags,
|
|
||||||
has_any_tag,
|
|
||||||
load_filter_config,
|
|
||||||
load_filter_from_config,
|
|
||||||
)
|
|
||||||
from batdetect2.targets.terms import TagInfo, generic_class
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def create_annotation(
|
|
||||||
sound_event: data.SoundEvent,
|
|
||||||
) -> Callable[[List[str]], data.SoundEventAnnotation]:
|
|
||||||
"""Helper function to create a SoundEventAnnotation with given tags."""
|
|
||||||
|
|
||||||
def factory(tags: List[str]) -> data.SoundEventAnnotation:
|
|
||||||
return data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[
|
|
||||||
data.Tag(
|
|
||||||
term=generic_class,
|
|
||||||
value=tag,
|
|
||||||
)
|
|
||||||
for tag in tags
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
return factory
|
|
||||||
|
|
||||||
|
|
||||||
def create_tag_set(tags: List[str]) -> Set[data.Tag]:
|
|
||||||
"""Helper function to create a set of data.Tag objects from a list of strings."""
|
|
||||||
return {
|
|
||||||
data.Tag(
|
|
||||||
term=generic_class,
|
|
||||||
value=tag,
|
|
||||||
)
|
|
||||||
for tag in tags
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_has_any_tag(create_annotation):
|
|
||||||
annotation = create_annotation(["tag1", "tag2"])
|
|
||||||
tags = create_tag_set(["tag1", "tag3"])
|
|
||||||
assert has_any_tag(annotation, tags) is True
|
|
||||||
|
|
||||||
annotation = create_annotation(["tag2", "tag4"])
|
|
||||||
tags = create_tag_set(["tag1", "tag3"])
|
|
||||||
assert has_any_tag(annotation, tags) is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_contains_tags(create_annotation):
|
|
||||||
annotation = create_annotation(["tag1", "tag2", "tag3"])
|
|
||||||
tags = create_tag_set(["tag1", "tag2"])
|
|
||||||
assert contains_tags(annotation, tags) is True
|
|
||||||
|
|
||||||
annotation = create_annotation(["tag1", "tag2"])
|
|
||||||
tags = create_tag_set(["tag1", "tag2", "tag3"])
|
|
||||||
assert contains_tags(annotation, tags) is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_does_not_have_tags(create_annotation):
|
|
||||||
annotation = create_annotation(["tag1", "tag2"])
|
|
||||||
tags = create_tag_set(["tag3", "tag4"])
|
|
||||||
assert does_not_have_tags(annotation, tags) is True
|
|
||||||
|
|
||||||
annotation = create_annotation(["tag1", "tag2"])
|
|
||||||
tags = create_tag_set(["tag1", "tag3"])
|
|
||||||
assert does_not_have_tags(annotation, tags) is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_equal_tags(create_annotation):
|
|
||||||
annotation = create_annotation(["tag1", "tag2"])
|
|
||||||
tags = create_tag_set(["tag1", "tag2"])
|
|
||||||
assert equal_tags(annotation, tags) is True
|
|
||||||
|
|
||||||
annotation = create_annotation(["tag1", "tag2", "tag3"])
|
|
||||||
tags = create_tag_set(["tag1", "tag2"])
|
|
||||||
assert equal_tags(annotation, tags) is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_filter_from_rule():
|
|
||||||
rule_any = FilterRule(match_type="any", tags=[TagInfo(value="tag1")])
|
|
||||||
build_filter_from_rule(rule_any)
|
|
||||||
|
|
||||||
rule_all = FilterRule(match_type="all", tags=[TagInfo(value="tag1")])
|
|
||||||
build_filter_from_rule(rule_all)
|
|
||||||
|
|
||||||
rule_exclude = FilterRule(
|
|
||||||
match_type="exclude", tags=[TagInfo(value="tag1")]
|
|
||||||
)
|
|
||||||
build_filter_from_rule(rule_exclude)
|
|
||||||
|
|
||||||
rule_equal = FilterRule(match_type="equal", tags=[TagInfo(value="tag1")])
|
|
||||||
build_filter_from_rule(rule_equal)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
FilterRule(match_type="invalid", tags=[TagInfo(value="tag1")]) # type: ignore
|
|
||||||
build_filter_from_rule(
|
|
||||||
FilterRule(match_type="invalid", tags=[TagInfo(value="tag1")]) # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_filter_from_config(create_annotation):
|
|
||||||
config = FilterConfig(
|
|
||||||
rules=[
|
|
||||||
FilterRule(match_type="any", tags=[TagInfo(value="tag1")]),
|
|
||||||
FilterRule(match_type="any", tags=[TagInfo(value="tag2")]),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
filter_from_config = build_sound_event_filter(config)
|
|
||||||
|
|
||||||
annotation_pass = create_annotation(["tag1", "tag2"])
|
|
||||||
assert filter_from_config(annotation_pass)
|
|
||||||
|
|
||||||
annotation_fail = create_annotation(["tag1"])
|
|
||||||
assert not filter_from_config(annotation_fail)
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_filter_config(tmp_path: Path):
|
|
||||||
test_config_path = tmp_path / "filtering.yaml"
|
|
||||||
test_config_path.write_text(
|
|
||||||
"""
|
|
||||||
rules:
|
|
||||||
- match_type: any
|
|
||||||
tags:
|
|
||||||
- value: tag1
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
config = load_filter_config(test_config_path)
|
|
||||||
assert isinstance(config, FilterConfig)
|
|
||||||
assert len(config.rules) == 1
|
|
||||||
rule = config.rules[0]
|
|
||||||
assert rule.match_type == "any"
|
|
||||||
assert len(rule.tags) == 1
|
|
||||||
assert rule.tags[0].value == "tag1"
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_filter_from_config(tmp_path: Path, create_annotation):
|
|
||||||
test_config_path = tmp_path / "filtering.yaml"
|
|
||||||
test_config_path.write_text(
|
|
||||||
"""
|
|
||||||
rules:
|
|
||||||
- match_type: any
|
|
||||||
tags:
|
|
||||||
- value: tag1
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
filter_result = load_filter_from_config(test_config_path)
|
|
||||||
annotation = create_annotation(["tag1", "tag3"])
|
|
||||||
assert filter_result(annotation)
|
|
||||||
|
|
||||||
test_config_path = tmp_path / "filtering.yaml"
|
|
||||||
test_config_path.write_text(
|
|
||||||
"""
|
|
||||||
rules:
|
|
||||||
- match_type: any
|
|
||||||
tags:
|
|
||||||
- value: tag2
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
filter_result = load_filter_from_config(test_config_path)
|
|
||||||
annotation = create_annotation(["tag1", "tag3"])
|
|
||||||
assert filter_result(annotation) is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_default_filtering_over_example_dataset(
|
|
||||||
example_annotations: List[data.ClipAnnotation],
|
|
||||||
):
|
|
||||||
targets = build_targets()
|
|
||||||
|
|
||||||
clip1 = example_annotations[0]
|
|
||||||
clip2 = example_annotations[1]
|
|
||||||
clip3 = example_annotations[2]
|
|
||||||
|
|
||||||
assert (
|
|
||||||
sum(
|
|
||||||
[targets.filter(sound_event) for sound_event in clip1.sound_events]
|
|
||||||
)
|
|
||||||
== 9
|
|
||||||
)
|
|
||||||
|
|
||||||
assert (
|
|
||||||
sum(
|
|
||||||
[targets.filter(sound_event) for sound_event in clip2.sound_events]
|
|
||||||
)
|
|
||||||
== 15
|
|
||||||
)
|
|
||||||
|
|
||||||
assert (
|
|
||||||
sum(
|
|
||||||
[targets.filter(sound_event) for sound_event in clip3.sound_events]
|
|
||||||
)
|
|
||||||
== 20
|
|
||||||
)
|
|
||||||
@ -8,6 +8,11 @@ from batdetect2.preprocess import (
|
|||||||
build_preprocessor,
|
build_preprocessor,
|
||||||
)
|
)
|
||||||
from batdetect2.preprocess.audio import build_audio_loader
|
from batdetect2.preprocess.audio import build_audio_loader
|
||||||
|
from batdetect2.preprocess.spectrogram import (
|
||||||
|
ScaleAmplitudeConfig,
|
||||||
|
SpectralMeanSubstractionConfig,
|
||||||
|
SpectrogramConfig,
|
||||||
|
)
|
||||||
from batdetect2.targets.rois import (
|
from batdetect2.targets.rois import (
|
||||||
DEFAULT_ANCHOR,
|
DEFAULT_ANCHOR,
|
||||||
DEFAULT_FREQUENCY_SCALE,
|
DEFAULT_FREQUENCY_SCALE,
|
||||||
@ -548,14 +553,7 @@ def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle):
|
|||||||
|
|
||||||
# Instantiate the mapper.
|
# Instantiate the mapper.
|
||||||
preprocessor = build_preprocessor(
|
preprocessor = build_preprocessor(
|
||||||
PreprocessingConfig.model_validate(
|
PreprocessingConfig(spectrogram=SpectrogramConfig(transforms=[]))
|
||||||
{
|
|
||||||
"spectrogram": {
|
|
||||||
"pcen": None,
|
|
||||||
"spectral_mean_substraction": False,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
audio_loader = build_audio_loader()
|
audio_loader = build_audio_loader()
|
||||||
mapper = PeakEnergyBBoxMapper(
|
mapper = PeakEnergyBBoxMapper(
|
||||||
@ -597,14 +595,13 @@ def test_build_roi_mapper_for_anchor_bbox():
|
|||||||
|
|
||||||
def test_build_roi_mapper_for_peak_energy_bbox():
|
def test_build_roi_mapper_for_peak_energy_bbox():
|
||||||
# Given
|
# Given
|
||||||
preproc_config = PreprocessingConfig.model_validate(
|
preproc_config = PreprocessingConfig(
|
||||||
{
|
spectrogram=SpectrogramConfig(
|
||||||
"spectrogram": {
|
transforms=[
|
||||||
"pcen": None,
|
ScaleAmplitudeConfig(scale="db"),
|
||||||
"spectral_mean_substraction": True,
|
SpectralMeanSubstractionConfig(),
|
||||||
"scale": "dB",
|
]
|
||||||
}
|
),
|
||||||
}
|
|
||||||
)
|
)
|
||||||
config = PeakEnergyBBoxMapperConfig(
|
config = PeakEnergyBBoxMapperConfig(
|
||||||
loading_buffer=0.99,
|
loading_buffer=0.99,
|
||||||
|
|||||||
@ -1,46 +1,55 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from soundevent import data
|
from soundevent import data, terms
|
||||||
|
|
||||||
from batdetect2.targets import build_targets, load_target_config
|
from batdetect2.targets import build_targets, load_target_config
|
||||||
from batdetect2.targets.terms import get_term_from_key
|
|
||||||
|
|
||||||
|
|
||||||
def test_can_override_default_roi_mapper_per_class(
|
def test_can_override_default_roi_mapper_per_class(
|
||||||
create_temp_yaml: Callable[..., Path],
|
create_temp_yaml: Callable[..., Path],
|
||||||
recording: data.Recording,
|
recording: data.Recording,
|
||||||
sample_term_registry,
|
|
||||||
):
|
):
|
||||||
yaml_content = """
|
yaml_content = """
|
||||||
|
detection_target:
|
||||||
|
name: bat
|
||||||
|
match_if:
|
||||||
|
name: has_tag
|
||||||
|
tag:
|
||||||
|
key: order
|
||||||
|
value: Chiroptera
|
||||||
|
assign_tags:
|
||||||
|
- key: order
|
||||||
|
value: Chiroptera
|
||||||
|
|
||||||
|
classification_targets:
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
|
||||||
|
- name: myomyo
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
roi:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: top-left
|
||||||
|
|
||||||
roi:
|
roi:
|
||||||
name: anchor_bbox
|
name: anchor_bbox
|
||||||
anchor: bottom-left
|
anchor: bottom-left
|
||||||
classes:
|
|
||||||
classes:
|
|
||||||
- name: pippip
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Pipistrellus pipistrellus
|
|
||||||
- name: myomyo
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Myotis myotis
|
|
||||||
roi:
|
|
||||||
name: anchor_bbox
|
|
||||||
anchor: top-left
|
|
||||||
generic_class:
|
|
||||||
- key: order
|
|
||||||
value: Chiroptera
|
|
||||||
"""
|
"""
|
||||||
config_path = create_temp_yaml(yaml_content)
|
config_path = create_temp_yaml(yaml_content)
|
||||||
|
|
||||||
config = load_target_config(config_path)
|
config = load_target_config(config_path)
|
||||||
targets = build_targets(config, term_registry=sample_term_registry)
|
targets = build_targets(config)
|
||||||
|
|
||||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||||
|
|
||||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
species = terms.get_term("species")
|
||||||
|
assert species is not None
|
||||||
|
|
||||||
se1 = data.SoundEventAnnotation(
|
se1 = data.SoundEventAnnotation(
|
||||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||||
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||||
@ -62,38 +71,47 @@ def test_can_override_default_roi_mapper_per_class(
|
|||||||
# TODO: rename this test function
|
# TODO: rename this test function
|
||||||
def test_roi_is_recovered_roundtrip_even_with_overriders(
|
def test_roi_is_recovered_roundtrip_even_with_overriders(
|
||||||
create_temp_yaml,
|
create_temp_yaml,
|
||||||
sample_term_registry,
|
|
||||||
recording,
|
recording,
|
||||||
):
|
):
|
||||||
yaml_content = """
|
yaml_content = """
|
||||||
|
detection_target:
|
||||||
|
name: bat
|
||||||
|
match_if:
|
||||||
|
name: has_tag
|
||||||
|
tag:
|
||||||
|
key: order
|
||||||
|
value: Chiroptera
|
||||||
|
assign_tags:
|
||||||
|
- key: order
|
||||||
|
value: Chiroptera
|
||||||
|
|
||||||
|
classification_targets:
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
|
||||||
|
- name: myomyo
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
roi:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: top-left
|
||||||
|
|
||||||
roi:
|
roi:
|
||||||
name: anchor_bbox
|
name: anchor_bbox
|
||||||
anchor: bottom-left
|
anchor: bottom-left
|
||||||
classes:
|
|
||||||
classes:
|
|
||||||
- name: pippip
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Pipistrellus pipistrellus
|
|
||||||
- name: myomyo
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Myotis myotis
|
|
||||||
roi:
|
|
||||||
name: anchor_bbox
|
|
||||||
anchor: top-left
|
|
||||||
generic_class:
|
|
||||||
- key: order
|
|
||||||
value: Chiroptera
|
|
||||||
"""
|
"""
|
||||||
config_path = create_temp_yaml(yaml_content)
|
config_path = create_temp_yaml(yaml_content)
|
||||||
|
|
||||||
config = load_target_config(config_path)
|
config = load_target_config(config_path)
|
||||||
targets = build_targets(config, term_registry=sample_term_registry)
|
targets = build_targets(config)
|
||||||
|
|
||||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||||
|
|
||||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
species = terms.get_term("species")
|
||||||
|
assert species is not None
|
||||||
se1 = data.SoundEventAnnotation(
|
se1 = data.SoundEventAnnotation(
|
||||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||||
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||||
|
|||||||
@ -1,179 +0,0 @@
|
|||||||
import pytest
|
|
||||||
import yaml
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.targets import terms
|
|
||||||
from batdetect2.targets.terms import (
|
|
||||||
TagInfo,
|
|
||||||
TermRegistry,
|
|
||||||
load_terms_from_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_term_registry_initialization():
|
|
||||||
registry = TermRegistry()
|
|
||||||
assert registry._terms == {}
|
|
||||||
|
|
||||||
initial_terms = {
|
|
||||||
"test_term": data.Term(name="test", label="Test", definition="test")
|
|
||||||
}
|
|
||||||
registry = TermRegistry(terms=initial_terms)
|
|
||||||
assert registry._terms == initial_terms
|
|
||||||
|
|
||||||
|
|
||||||
def test_term_registry_add_term():
|
|
||||||
registry = TermRegistry()
|
|
||||||
term = data.Term(name="test", label="Test", definition="test")
|
|
||||||
registry.add_term("test_key", term)
|
|
||||||
assert registry._terms["test_key"] == term
|
|
||||||
|
|
||||||
|
|
||||||
def test_term_registry_get_term():
|
|
||||||
registry = TermRegistry()
|
|
||||||
term = data.Term(name="test", label="Test", definition="test")
|
|
||||||
registry.add_term("test_key", term)
|
|
||||||
retrieved_term = registry.get_term("test_key")
|
|
||||||
assert retrieved_term == term
|
|
||||||
|
|
||||||
|
|
||||||
def test_term_registry_add_custom_term():
|
|
||||||
registry = TermRegistry()
|
|
||||||
term = registry.add_custom_term(
|
|
||||||
"custom_key", name="custom", label="Custom", definition="A custom term"
|
|
||||||
)
|
|
||||||
assert registry._terms["custom_key"] == term
|
|
||||||
assert term.name == "custom"
|
|
||||||
assert term.label == "Custom"
|
|
||||||
assert term.definition == "A custom term"
|
|
||||||
|
|
||||||
|
|
||||||
def test_term_registry_add_duplicate_term():
|
|
||||||
registry = TermRegistry()
|
|
||||||
term = data.Term(name="test", label="Test", definition="test")
|
|
||||||
registry.add_term("test_key", term)
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
registry.add_term("test_key", term)
|
|
||||||
|
|
||||||
|
|
||||||
def test_term_registry_get_term_not_found():
|
|
||||||
registry = TermRegistry()
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
registry.get_term("non_existent_key")
|
|
||||||
|
|
||||||
|
|
||||||
def test_term_registry_get_keys():
|
|
||||||
registry = TermRegistry()
|
|
||||||
term1 = data.Term(name="test1", label="Test1", definition="test")
|
|
||||||
term2 = data.Term(name="test2", label="Test2", definition="test")
|
|
||||||
registry.add_term("key1", term1)
|
|
||||||
registry.add_term("key2", term2)
|
|
||||||
keys = registry.get_keys()
|
|
||||||
assert set(keys) == {"key1", "key2"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_term_from_key():
|
|
||||||
term = terms.get_term_from_key("event")
|
|
||||||
assert term == terms.call_type
|
|
||||||
|
|
||||||
custom_registry = TermRegistry()
|
|
||||||
custom_term = data.Term(name="custom", label="Custom", definition="test")
|
|
||||||
custom_registry.add_term("custom_key", custom_term)
|
|
||||||
term = terms.get_term_from_key("custom_key", term_registry=custom_registry)
|
|
||||||
assert term == custom_term
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_term_keys():
|
|
||||||
keys = terms.get_term_keys()
|
|
||||||
assert "event" in keys
|
|
||||||
assert "individual" in keys
|
|
||||||
assert terms.GENERIC_CLASS_KEY in keys
|
|
||||||
|
|
||||||
custom_registry = TermRegistry()
|
|
||||||
custom_term = data.Term(name="custom", label="Custom", definition="test")
|
|
||||||
custom_registry.add_term("custom_key", custom_term)
|
|
||||||
keys = terms.get_term_keys(term_registry=custom_registry)
|
|
||||||
assert "custom_key" in keys
|
|
||||||
|
|
||||||
|
|
||||||
def test_tag_info_and_get_tag_from_info():
|
|
||||||
tag_info = TagInfo(value="Myotis myotis", key="event")
|
|
||||||
tag = terms.get_tag_from_info(tag_info)
|
|
||||||
assert tag.value == "Myotis myotis"
|
|
||||||
assert tag.term == terms.call_type
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_tag_from_info_key_not_found():
|
|
||||||
tag_info = TagInfo(value="test", key="non_existent_key")
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
terms.get_tag_from_info(tag_info)
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_terms_from_config(tmp_path):
|
|
||||||
term_registry = TermRegistry()
|
|
||||||
config_data = {
|
|
||||||
"terms": [
|
|
||||||
{
|
|
||||||
"key": "species",
|
|
||||||
"name": "dwc:scientificName",
|
|
||||||
"label": "Scientific Name",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "my_custom_term",
|
|
||||||
"name": "soundevent:custom_term",
|
|
||||||
"definition": "Describes a specific project attribute",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
config_file = tmp_path / "config.yaml"
|
|
||||||
with open(config_file, "w") as f:
|
|
||||||
yaml.dump(config_data, f)
|
|
||||||
|
|
||||||
loaded_terms = load_terms_from_config(
|
|
||||||
config_file,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
assert "species" in loaded_terms
|
|
||||||
assert "my_custom_term" in loaded_terms
|
|
||||||
assert loaded_terms["species"].name == "dwc:scientificName"
|
|
||||||
assert loaded_terms["my_custom_term"].name == "soundevent:custom_term"
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_terms_from_config_file_not_found():
|
|
||||||
with pytest.raises(FileNotFoundError):
|
|
||||||
load_terms_from_config("non_existent_file.yaml")
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_terms_from_config_validation_error(tmp_path):
|
|
||||||
config_data = {
|
|
||||||
"terms": [
|
|
||||||
{
|
|
||||||
"key": "species",
|
|
||||||
"uri": "dwc:scientificName",
|
|
||||||
"label": 123,
|
|
||||||
}, # Invalid label type
|
|
||||||
]
|
|
||||||
}
|
|
||||||
config_file = tmp_path / "config.yaml"
|
|
||||||
with open(config_file, "w") as f:
|
|
||||||
yaml.dump(config_data, f)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
load_terms_from_config(config_file)
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_terms_from_config_key_already_exists(tmp_path):
|
|
||||||
config_data = {
|
|
||||||
"terms": [
|
|
||||||
{
|
|
||||||
"key": "event",
|
|
||||||
"uri": "dwc:scientificName",
|
|
||||||
"label": "Scientific Name",
|
|
||||||
}, # Duplicate key
|
|
||||||
]
|
|
||||||
}
|
|
||||||
config_file = tmp_path / "config.yaml"
|
|
||||||
with open(config_file, "w") as f:
|
|
||||||
yaml.dump(config_data, f)
|
|
||||||
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
load_terms_from_config(config_file)
|
|
||||||
@ -1,363 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.targets import (
|
|
||||||
DeriveTagRule,
|
|
||||||
MapValueRule,
|
|
||||||
ReplaceRule,
|
|
||||||
TagInfo,
|
|
||||||
TransformConfig,
|
|
||||||
build_transformation_from_config,
|
|
||||||
)
|
|
||||||
from batdetect2.targets.terms import TermRegistry
|
|
||||||
from batdetect2.targets.transform import (
|
|
||||||
DerivationRegistry,
|
|
||||||
build_transform_from_rule,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def term_registry():
|
|
||||||
return TermRegistry()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def derivation_registry():
|
|
||||||
return DerivationRegistry()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def term1(term_registry: TermRegistry) -> data.Term:
|
|
||||||
return term_registry.add_custom_term(key="term1")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def term2(term_registry: TermRegistry) -> data.Term:
|
|
||||||
return term_registry.add_custom_term(key="term2")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def annotation(
|
|
||||||
sound_event: data.SoundEvent,
|
|
||||||
term1: data.Term,
|
|
||||||
) -> data.SoundEventAnnotation:
|
|
||||||
return data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event, tags=[data.Tag(term=term1, value="value1")]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_map_value_rule(
|
|
||||||
annotation: data.SoundEventAnnotation,
|
|
||||||
term_registry: TermRegistry,
|
|
||||||
):
|
|
||||||
rule = MapValueRule(
|
|
||||||
rule_type="map_value",
|
|
||||||
source_term_key="term1",
|
|
||||||
value_mapping={"value1": "value2"},
|
|
||||||
)
|
|
||||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
|
||||||
transformed_annotation = transform_fn(annotation)
|
|
||||||
assert transformed_annotation.tags[0].value == "value2"
|
|
||||||
|
|
||||||
|
|
||||||
def test_map_value_rule_no_match(
|
|
||||||
annotation: data.SoundEventAnnotation,
|
|
||||||
term_registry: TermRegistry,
|
|
||||||
):
|
|
||||||
rule = MapValueRule(
|
|
||||||
rule_type="map_value",
|
|
||||||
source_term_key="term1",
|
|
||||||
value_mapping={"other_value": "value2"},
|
|
||||||
)
|
|
||||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
|
||||||
transformed_annotation = transform_fn(annotation)
|
|
||||||
assert transformed_annotation.tags[0].value == "value1"
|
|
||||||
|
|
||||||
|
|
||||||
def test_replace_rule(
|
|
||||||
annotation: data.SoundEventAnnotation,
|
|
||||||
term2: data.Term,
|
|
||||||
term_registry: TermRegistry,
|
|
||||||
):
|
|
||||||
rule = ReplaceRule(
|
|
||||||
rule_type="replace",
|
|
||||||
original=TagInfo(key="term1", value="value1"),
|
|
||||||
replacement=TagInfo(key="term2", value="value2"),
|
|
||||||
)
|
|
||||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
|
||||||
transformed_annotation = transform_fn(annotation)
|
|
||||||
assert transformed_annotation.tags[0].term == term2
|
|
||||||
assert transformed_annotation.tags[0].value == "value2"
|
|
||||||
|
|
||||||
|
|
||||||
def test_replace_rule_no_match(
|
|
||||||
annotation: data.SoundEventAnnotation,
|
|
||||||
term_registry: TermRegistry,
|
|
||||||
term2: data.Term,
|
|
||||||
):
|
|
||||||
rule = ReplaceRule(
|
|
||||||
rule_type="replace",
|
|
||||||
original=TagInfo(key="term1", value="wrong_value"),
|
|
||||||
replacement=TagInfo(key="term2", value="value2"),
|
|
||||||
)
|
|
||||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
|
||||||
transformed_annotation = transform_fn(annotation)
|
|
||||||
assert transformed_annotation.tags[0].key == "term1"
|
|
||||||
assert transformed_annotation.tags[0].term != term2
|
|
||||||
assert transformed_annotation.tags[0].value == "value1"
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_transformation_from_config(
|
|
||||||
annotation: data.SoundEventAnnotation,
|
|
||||||
term_registry: TermRegistry,
|
|
||||||
):
|
|
||||||
config = TransformConfig(
|
|
||||||
rules=[
|
|
||||||
MapValueRule(
|
|
||||||
rule_type="map_value",
|
|
||||||
source_term_key="term1",
|
|
||||||
value_mapping={"value1": "value2"},
|
|
||||||
),
|
|
||||||
ReplaceRule(
|
|
||||||
rule_type="replace",
|
|
||||||
original=TagInfo(key="term2", value="value2"),
|
|
||||||
replacement=TagInfo(key="term3", value="value3"),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
term_registry.add_custom_term("term2")
|
|
||||||
term_registry.add_custom_term("term3")
|
|
||||||
transform = build_transformation_from_config(
|
|
||||||
config,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
transformed_annotation = transform(annotation)
|
|
||||||
assert transformed_annotation.tags[0].key == "term1"
|
|
||||||
assert transformed_annotation.tags[0].value == "value2"
|
|
||||||
|
|
||||||
|
|
||||||
def test_derive_tag_rule(
|
|
||||||
annotation: data.SoundEventAnnotation,
|
|
||||||
term_registry: TermRegistry,
|
|
||||||
derivation_registry: DerivationRegistry,
|
|
||||||
term1: data.Term,
|
|
||||||
):
|
|
||||||
def derivation_func(x: str) -> str:
|
|
||||||
return x + "_derived"
|
|
||||||
|
|
||||||
derivation_registry.register("my_derivation", derivation_func)
|
|
||||||
|
|
||||||
rule = DeriveTagRule(
|
|
||||||
rule_type="derive_tag",
|
|
||||||
source_term_key="term1",
|
|
||||||
derivation_function="my_derivation",
|
|
||||||
)
|
|
||||||
transform_fn = build_transform_from_rule(
|
|
||||||
rule,
|
|
||||||
term_registry=term_registry,
|
|
||||||
derivation_registry=derivation_registry,
|
|
||||||
)
|
|
||||||
transformed_annotation = transform_fn(annotation)
|
|
||||||
|
|
||||||
assert len(transformed_annotation.tags) == 2
|
|
||||||
assert transformed_annotation.tags[0].term == term1
|
|
||||||
assert transformed_annotation.tags[0].value == "value1"
|
|
||||||
assert transformed_annotation.tags[1].term == term1
|
|
||||||
assert transformed_annotation.tags[1].value == "value1_derived"
|
|
||||||
|
|
||||||
|
|
||||||
def test_derive_tag_rule_keep_source_false(
|
|
||||||
annotation: data.SoundEventAnnotation,
|
|
||||||
term_registry: TermRegistry,
|
|
||||||
derivation_registry: DerivationRegistry,
|
|
||||||
term1: data.Term,
|
|
||||||
):
|
|
||||||
def derivation_func(x: str) -> str:
|
|
||||||
return x + "_derived"
|
|
||||||
|
|
||||||
derivation_registry.register("my_derivation", derivation_func)
|
|
||||||
|
|
||||||
rule = DeriveTagRule(
|
|
||||||
rule_type="derive_tag",
|
|
||||||
source_term_key="term1",
|
|
||||||
derivation_function="my_derivation",
|
|
||||||
keep_source=False,
|
|
||||||
)
|
|
||||||
transform_fn = build_transform_from_rule(
|
|
||||||
rule,
|
|
||||||
term_registry=term_registry,
|
|
||||||
derivation_registry=derivation_registry,
|
|
||||||
)
|
|
||||||
transformed_annotation = transform_fn(annotation)
|
|
||||||
|
|
||||||
assert len(transformed_annotation.tags) == 1
|
|
||||||
assert transformed_annotation.tags[0].term == term1
|
|
||||||
assert transformed_annotation.tags[0].value == "value1_derived"
|
|
||||||
|
|
||||||
|
|
||||||
def test_derive_tag_rule_target_term(
|
|
||||||
annotation: data.SoundEventAnnotation,
|
|
||||||
term_registry: TermRegistry,
|
|
||||||
derivation_registry: DerivationRegistry,
|
|
||||||
term1: data.Term,
|
|
||||||
term2: data.Term,
|
|
||||||
):
|
|
||||||
def derivation_func(x: str) -> str:
|
|
||||||
return x + "_derived"
|
|
||||||
|
|
||||||
derivation_registry.register("my_derivation", derivation_func)
|
|
||||||
|
|
||||||
rule = DeriveTagRule(
|
|
||||||
rule_type="derive_tag",
|
|
||||||
source_term_key="term1",
|
|
||||||
derivation_function="my_derivation",
|
|
||||||
target_term_key="term2",
|
|
||||||
)
|
|
||||||
transform_fn = build_transform_from_rule(
|
|
||||||
rule,
|
|
||||||
term_registry=term_registry,
|
|
||||||
derivation_registry=derivation_registry,
|
|
||||||
)
|
|
||||||
transformed_annotation = transform_fn(annotation)
|
|
||||||
|
|
||||||
assert len(transformed_annotation.tags) == 2
|
|
||||||
assert transformed_annotation.tags[0].term == term1
|
|
||||||
assert transformed_annotation.tags[0].value == "value1"
|
|
||||||
assert transformed_annotation.tags[1].term == term2
|
|
||||||
assert transformed_annotation.tags[1].value == "value1_derived"
|
|
||||||
|
|
||||||
|
|
||||||
def test_derive_tag_rule_import_derivation(
|
|
||||||
annotation: data.SoundEventAnnotation,
|
|
||||||
term_registry: TermRegistry,
|
|
||||||
term1: data.Term,
|
|
||||||
tmp_path: Path,
|
|
||||||
):
|
|
||||||
# Create a dummy derivation function in a temporary file
|
|
||||||
derivation_module_path = (
|
|
||||||
tmp_path / "temp_derivation.py"
|
|
||||||
) # Changed to /tmp since /home/santiago is not writable
|
|
||||||
derivation_module_path.write_text(
|
|
||||||
"""
|
|
||||||
def my_imported_derivation(x: str) -> str:
|
|
||||||
return x + "_imported"
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
# Ensure the temporary file is importable by adding its directory to sys.path
|
|
||||||
import sys
|
|
||||||
|
|
||||||
sys.path.insert(0, str(tmp_path))
|
|
||||||
|
|
||||||
rule = DeriveTagRule(
|
|
||||||
rule_type="derive_tag",
|
|
||||||
source_term_key="term1",
|
|
||||||
derivation_function="temp_derivation.my_imported_derivation",
|
|
||||||
import_derivation=True,
|
|
||||||
)
|
|
||||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
|
||||||
transformed_annotation = transform_fn(annotation)
|
|
||||||
|
|
||||||
assert len(transformed_annotation.tags) == 2
|
|
||||||
assert transformed_annotation.tags[0].term == term1
|
|
||||||
assert transformed_annotation.tags[0].value == "value1"
|
|
||||||
assert transformed_annotation.tags[1].term == term1
|
|
||||||
assert transformed_annotation.tags[1].value == "value1_imported"
|
|
||||||
|
|
||||||
# Clean up the temporary file and sys.path
|
|
||||||
sys.path.remove(str(tmp_path))
|
|
||||||
|
|
||||||
|
|
||||||
def test_derive_tag_rule_invalid_derivation(term_registry: TermRegistry):
|
|
||||||
rule = DeriveTagRule(
|
|
||||||
rule_type="derive_tag",
|
|
||||||
source_term_key="term1",
|
|
||||||
derivation_function="nonexistent_derivation",
|
|
||||||
)
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
build_transform_from_rule(rule, term_registry=term_registry)
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_transform_from_rule_invalid_rule_type():
|
|
||||||
class InvalidRule:
|
|
||||||
rule_type = "invalid"
|
|
||||||
|
|
||||||
rule = InvalidRule() # type: ignore
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
build_transform_from_rule(rule) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
def test_map_value_rule_target_term(
|
|
||||||
annotation: data.SoundEventAnnotation,
|
|
||||||
term_registry: TermRegistry,
|
|
||||||
term2: data.Term,
|
|
||||||
):
|
|
||||||
rule = MapValueRule(
|
|
||||||
rule_type="map_value",
|
|
||||||
source_term_key="term1",
|
|
||||||
value_mapping={"value1": "value2"},
|
|
||||||
target_term_key="term2",
|
|
||||||
)
|
|
||||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
|
||||||
transformed_annotation = transform_fn(annotation)
|
|
||||||
assert transformed_annotation.tags[0].term == term2
|
|
||||||
assert transformed_annotation.tags[0].value == "value2"
|
|
||||||
|
|
||||||
|
|
||||||
def test_map_value_rule_target_term_none(
|
|
||||||
annotation: data.SoundEventAnnotation,
|
|
||||||
term_registry: TermRegistry,
|
|
||||||
term1: data.Term,
|
|
||||||
):
|
|
||||||
rule = MapValueRule(
|
|
||||||
rule_type="map_value",
|
|
||||||
source_term_key="term1",
|
|
||||||
value_mapping={"value1": "value2"},
|
|
||||||
target_term_key=None,
|
|
||||||
)
|
|
||||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
|
||||||
transformed_annotation = transform_fn(annotation)
|
|
||||||
assert transformed_annotation.tags[0].term == term1
|
|
||||||
assert transformed_annotation.tags[0].value == "value2"
|
|
||||||
|
|
||||||
|
|
||||||
def test_derive_tag_rule_target_term_none(
|
|
||||||
annotation: data.SoundEventAnnotation,
|
|
||||||
term_registry: TermRegistry,
|
|
||||||
derivation_registry: DerivationRegistry,
|
|
||||||
term1: data.Term,
|
|
||||||
):
|
|
||||||
def derivation_func(x: str) -> str:
|
|
||||||
return x + "_derived"
|
|
||||||
|
|
||||||
derivation_registry.register("my_derivation", derivation_func)
|
|
||||||
|
|
||||||
rule = DeriveTagRule(
|
|
||||||
rule_type="derive_tag",
|
|
||||||
source_term_key="term1",
|
|
||||||
derivation_function="my_derivation",
|
|
||||||
target_term_key=None,
|
|
||||||
)
|
|
||||||
transform_fn = build_transform_from_rule(
|
|
||||||
rule,
|
|
||||||
term_registry=term_registry,
|
|
||||||
derivation_registry=derivation_registry,
|
|
||||||
)
|
|
||||||
transformed_annotation = transform_fn(annotation)
|
|
||||||
|
|
||||||
assert len(transformed_annotation.tags) == 2
|
|
||||||
assert transformed_annotation.tags[0].term == term1
|
|
||||||
assert transformed_annotation.tags[0].value == "value1"
|
|
||||||
assert transformed_annotation.tags[1].term == term1
|
|
||||||
assert transformed_annotation.tags[1].value == "value1_derived"
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_transformation_from_config_empty(
|
|
||||||
annotation: data.SoundEventAnnotation,
|
|
||||||
):
|
|
||||||
config = TransformConfig(rules=[])
|
|
||||||
transform = build_transformation_from_config(config)
|
|
||||||
transformed_annotation = transform(annotation)
|
|
||||||
assert transformed_annotation == annotation
|
|
||||||
@ -1,170 +0,0 @@
|
|||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.train.augmentations import (
|
|
||||||
add_echo,
|
|
||||||
mix_audio,
|
|
||||||
)
|
|
||||||
from batdetect2.train.clips import select_subclip
|
|
||||||
from batdetect2.train.preprocess import generate_train_example
|
|
||||||
from batdetect2.typing import AudioLoader, ClipLabeller, PreprocessorProtocol
|
|
||||||
|
|
||||||
|
|
||||||
def test_mix_examples(
|
|
||||||
sample_preprocessor: PreprocessorProtocol,
|
|
||||||
sample_audio_loader: AudioLoader,
|
|
||||||
sample_labeller: ClipLabeller,
|
|
||||||
create_recording: Callable[..., data.Recording],
|
|
||||||
):
|
|
||||||
recording1 = create_recording()
|
|
||||||
recording2 = create_recording()
|
|
||||||
|
|
||||||
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
|
||||||
clip2 = data.Clip(recording=recording2, start_time=0.3, end_time=0.8)
|
|
||||||
|
|
||||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
|
||||||
clip_annotation_2 = data.ClipAnnotation(clip=clip2)
|
|
||||||
|
|
||||||
example1 = generate_train_example(
|
|
||||||
clip_annotation_1,
|
|
||||||
audio_loader=sample_audio_loader,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
labeller=sample_labeller,
|
|
||||||
)
|
|
||||||
example2 = generate_train_example(
|
|
||||||
clip_annotation_2,
|
|
||||||
audio_loader=sample_audio_loader,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
labeller=sample_labeller,
|
|
||||||
)
|
|
||||||
|
|
||||||
mixed = mix_audio(
|
|
||||||
example1,
|
|
||||||
example2,
|
|
||||||
weight=0.3,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert mixed.spectrogram.shape == example1.spectrogram.shape
|
|
||||||
assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
|
|
||||||
assert mixed.size_heatmap.shape == example1.size_heatmap.shape
|
|
||||||
assert mixed.class_heatmap.shape == example1.class_heatmap.shape
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
|
|
||||||
@pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
|
|
||||||
def test_mix_examples_of_different_durations(
|
|
||||||
sample_preprocessor: PreprocessorProtocol,
|
|
||||||
sample_audio_loader: AudioLoader,
|
|
||||||
sample_labeller: ClipLabeller,
|
|
||||||
create_recording: Callable[..., data.Recording],
|
|
||||||
duration1: float,
|
|
||||||
duration2: float,
|
|
||||||
):
|
|
||||||
recording1 = create_recording()
|
|
||||||
recording2 = create_recording()
|
|
||||||
|
|
||||||
clip1 = data.Clip(recording=recording1, start_time=0, end_time=duration1)
|
|
||||||
clip2 = data.Clip(recording=recording2, start_time=0, end_time=duration2)
|
|
||||||
|
|
||||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
|
||||||
clip_annotation_2 = data.ClipAnnotation(clip=clip2)
|
|
||||||
|
|
||||||
example1 = generate_train_example(
|
|
||||||
clip_annotation_1,
|
|
||||||
audio_loader=sample_audio_loader,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
labeller=sample_labeller,
|
|
||||||
)
|
|
||||||
example2 = generate_train_example(
|
|
||||||
clip_annotation_2,
|
|
||||||
audio_loader=sample_audio_loader,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
labeller=sample_labeller,
|
|
||||||
)
|
|
||||||
|
|
||||||
mixed = mix_audio(
|
|
||||||
example1,
|
|
||||||
example2,
|
|
||||||
weight=0.3,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert mixed.spectrogram.shape == example1.spectrogram.shape
|
|
||||||
assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
|
|
||||||
assert mixed.size_heatmap.shape == example1.size_heatmap.shape
|
|
||||||
assert mixed.class_heatmap.shape == example1.class_heatmap.shape
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_echo(
|
|
||||||
sample_preprocessor: PreprocessorProtocol,
|
|
||||||
sample_audio_loader: AudioLoader,
|
|
||||||
sample_labeller: ClipLabeller,
|
|
||||||
create_recording: Callable[..., data.Recording],
|
|
||||||
):
|
|
||||||
recording1 = create_recording()
|
|
||||||
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
|
||||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
|
||||||
|
|
||||||
original = generate_train_example(
|
|
||||||
clip_annotation_1,
|
|
||||||
audio_loader=sample_audio_loader,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
labeller=sample_labeller,
|
|
||||||
)
|
|
||||||
with_echo = add_echo(
|
|
||||||
original,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
delay=0.1,
|
|
||||||
weight=0.3,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert with_echo.spectrogram.shape == original.spectrogram.shape
|
|
||||||
torch.testing.assert_close(
|
|
||||||
with_echo.size_heatmap,
|
|
||||||
original.size_heatmap,
|
|
||||||
atol=0,
|
|
||||||
rtol=0,
|
|
||||||
)
|
|
||||||
torch.testing.assert_close(
|
|
||||||
with_echo.class_heatmap,
|
|
||||||
original.class_heatmap,
|
|
||||||
atol=0,
|
|
||||||
rtol=0,
|
|
||||||
)
|
|
||||||
torch.testing.assert_close(
|
|
||||||
with_echo.detection_heatmap,
|
|
||||||
original.detection_heatmap,
|
|
||||||
atol=0,
|
|
||||||
rtol=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_selected_random_subclip_has_the_correct_width(
|
|
||||||
sample_preprocessor: PreprocessorProtocol,
|
|
||||||
sample_audio_loader: AudioLoader,
|
|
||||||
sample_labeller: ClipLabeller,
|
|
||||||
create_recording: Callable[..., data.Recording],
|
|
||||||
):
|
|
||||||
recording1 = create_recording()
|
|
||||||
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
|
||||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
|
||||||
|
|
||||||
original = generate_train_example(
|
|
||||||
clip_annotation_1,
|
|
||||||
audio_loader=sample_audio_loader,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
labeller=sample_labeller,
|
|
||||||
)
|
|
||||||
|
|
||||||
subclip = select_subclip(
|
|
||||||
original,
|
|
||||||
input_samplerate=256_000,
|
|
||||||
output_samplerate=1000,
|
|
||||||
start=0,
|
|
||||||
duration=0.512,
|
|
||||||
)
|
|
||||||
assert subclip.spectrogram.shape[1] == 512
|
|
||||||
@ -1,27 +0,0 @@
|
|||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.train import generate_train_example
|
|
||||||
from batdetect2.typing import (
|
|
||||||
AudioLoader,
|
|
||||||
ClipLabeller,
|
|
||||||
ClipperProtocol,
|
|
||||||
PreprocessorProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_default_clip_size_is_correct(
|
|
||||||
sample_clipper: ClipperProtocol,
|
|
||||||
sample_labeller: ClipLabeller,
|
|
||||||
sample_audio_loader: AudioLoader,
|
|
||||||
clip_annotation: data.ClipAnnotation,
|
|
||||||
sample_preprocessor: PreprocessorProtocol,
|
|
||||||
):
|
|
||||||
example = generate_train_example(
|
|
||||||
clip_annotation=clip_annotation,
|
|
||||||
audio_loader=sample_audio_loader,
|
|
||||||
preprocessor=sample_preprocessor,
|
|
||||||
labeller=sample_labeller,
|
|
||||||
)
|
|
||||||
|
|
||||||
clip, _, _ = sample_clipper(example)
|
|
||||||
assert clip.spectrogram.shape == (1, 128, 256)
|
|
||||||
@ -5,7 +5,6 @@ from soundevent import data
|
|||||||
|
|
||||||
from batdetect2.targets import TargetConfig, build_targets
|
from batdetect2.targets import TargetConfig, build_targets
|
||||||
from batdetect2.targets.rois import AnchorBBoxMapperConfig
|
from batdetect2.targets.rois import AnchorBBoxMapperConfig
|
||||||
from batdetect2.targets.terms import TagInfo
|
|
||||||
from batdetect2.train.labels import generate_heatmaps
|
from batdetect2.train.labels import generate_heatmaps
|
||||||
|
|
||||||
recording = data.Recording(
|
recording = data.Recording(
|
||||||
@ -26,7 +25,8 @@ clip = data.Clip(
|
|||||||
|
|
||||||
def test_generated_heatmap_are_non_zero_at_correct_positions(
|
def test_generated_heatmap_are_non_zero_at_correct_positions(
|
||||||
sample_target_config: TargetConfig,
|
sample_target_config: TargetConfig,
|
||||||
pippip_tag: TagInfo,
|
pippip_tag: data.Tag,
|
||||||
|
bat_tag: data.Tag,
|
||||||
):
|
):
|
||||||
config = sample_target_config.model_copy(
|
config = sample_target_config.model_copy(
|
||||||
update=dict(
|
update=dict(
|
||||||
@ -49,14 +49,14 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
|||||||
coordinates=[10, 10, 20, 30],
|
coordinates=[10, 10, 20, 30],
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
tags=[data.Tag(key=pippip_tag.key, value=pippip_tag.value)], # type: ignore
|
tags=[pippip_tag, bat_tag],
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
||||||
clip_annotation,
|
clip_annotation,
|
||||||
torch.rand([100, 100]),
|
torch.rand([1, 100, 100]),
|
||||||
min_freq=0,
|
min_freq=0,
|
||||||
max_freq=100,
|
max_freq=100,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
@ -67,4 +67,4 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
|||||||
assert size_heatmap[1, 10, 10] == 20
|
assert size_heatmap[1, 10, 10] == 20
|
||||||
assert class_heatmap[pippip_index, 10, 10] == 1.0
|
assert class_heatmap[pippip_index, 10, 10] == 1.0
|
||||||
assert class_heatmap[myomyo_index, 10, 10] == 0.0
|
assert class_heatmap[myomyo_index, 10, 10] == 0.0
|
||||||
assert detection_heatmap[10, 10] == 1.0
|
assert detection_heatmap[0, 10, 10] == 1.0
|
||||||
|
|||||||
@ -4,14 +4,16 @@ import lightning as L
|
|||||||
import torch
|
import torch
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.models import build_model
|
||||||
from batdetect2.train import FullTrainingConfig, TrainingModule
|
from batdetect2.train import FullTrainingConfig, TrainingModule
|
||||||
from batdetect2.train.train import build_training_module
|
from batdetect2.train.train import build_training_module
|
||||||
from batdetect2.typing.preprocess import AudioLoader
|
from batdetect2.typing.preprocess import AudioLoader
|
||||||
|
|
||||||
|
|
||||||
def build_default_module():
|
def build_default_module():
|
||||||
|
model = build_model()
|
||||||
config = FullTrainingConfig()
|
config = FullTrainingConfig()
|
||||||
return build_training_module(config)
|
return build_training_module(model, config=config)
|
||||||
|
|
||||||
|
|
||||||
def test_can_initialize_default_module():
|
def test_can_initialize_default_module():
|
||||||
@ -32,14 +34,14 @@ def test_can_save_checkpoint(
|
|||||||
|
|
||||||
recovered = TrainingModule.load_from_checkpoint(path)
|
recovered = TrainingModule.load_from_checkpoint(path)
|
||||||
|
|
||||||
wav = torch.tensor(sample_audio_loader.load_clip(clip))
|
wav = torch.tensor(sample_audio_loader.load_clip(clip)).unsqueeze(0)
|
||||||
|
|
||||||
spec1 = module.model.preprocessor(wav)
|
spec1 = module.model.preprocessor(wav)
|
||||||
spec2 = recovered.model.preprocessor(wav)
|
spec2 = recovered.model.preprocessor(wav)
|
||||||
|
|
||||||
torch.testing.assert_close(spec1, spec2, rtol=0, atol=0)
|
torch.testing.assert_close(spec1, spec2, rtol=0, atol=0)
|
||||||
|
|
||||||
output1 = module(spec1.unsqueeze(0).unsqueeze(0))
|
output1 = module(spec1.unsqueeze(0))
|
||||||
output2 = recovered(spec2.unsqueeze(0).unsqueeze(0))
|
output2 = recovered(spec2.unsqueeze(0))
|
||||||
|
|
||||||
torch.testing.assert_close(output1, output2, rtol=0, atol=0)
|
torch.testing.assert_close(output1, output2, rtol=0, atol=0)
|
||||||
|
|||||||
@ -1,230 +0,0 @@
|
|||||||
import pytest
|
|
||||||
from soundevent import data
|
|
||||||
from soundevent.terms import get_term
|
|
||||||
|
|
||||||
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
|
|
||||||
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
|
||||||
from batdetect2.targets import build_targets, load_target_config
|
|
||||||
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
|
||||||
from batdetect2.train.preprocess import generate_train_example
|
|
||||||
from batdetect2.typing import ModelOutput
|
|
||||||
from batdetect2.typing.preprocess import AudioLoader
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def build_from_config(
|
|
||||||
create_temp_yaml,
|
|
||||||
):
|
|
||||||
def build(yaml_content):
|
|
||||||
config_path = create_temp_yaml(yaml_content)
|
|
||||||
|
|
||||||
targets_config = load_target_config(config_path, field="targets")
|
|
||||||
preprocessing_config = load_preprocessing_config(
|
|
||||||
config_path,
|
|
||||||
field="preprocessing",
|
|
||||||
)
|
|
||||||
labels_config = load_label_config(config_path, field="labels")
|
|
||||||
postprocessing_config = load_postprocess_config(
|
|
||||||
config_path,
|
|
||||||
field="postprocessing",
|
|
||||||
)
|
|
||||||
|
|
||||||
targets = build_targets(targets_config)
|
|
||||||
preprocessor = build_preprocessor(preprocessing_config)
|
|
||||||
labeller = build_clip_labeler(
|
|
||||||
targets=targets,
|
|
||||||
config=labels_config,
|
|
||||||
min_freq=preprocessor.min_freq,
|
|
||||||
max_freq=preprocessor.max_freq,
|
|
||||||
)
|
|
||||||
postprocessor = build_postprocessor(
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
config=postprocessing_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
return targets, preprocessor, labeller, postprocessor
|
|
||||||
|
|
||||||
return build
|
|
||||||
|
|
||||||
|
|
||||||
def test_encoding_decoding_roundtrip_recovers_object(
|
|
||||||
sample_audio_loader: AudioLoader,
|
|
||||||
build_from_config,
|
|
||||||
recording,
|
|
||||||
):
|
|
||||||
yaml_content = """
|
|
||||||
labels:
|
|
||||||
targets:
|
|
||||||
roi:
|
|
||||||
name: anchor_bbox
|
|
||||||
anchor: bottom-left
|
|
||||||
classes:
|
|
||||||
classes:
|
|
||||||
- name: pippip
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Pipistrellus pipistrellus
|
|
||||||
generic_class:
|
|
||||||
- key: order
|
|
||||||
value: Chiroptera
|
|
||||||
preprocessing:
|
|
||||||
"""
|
|
||||||
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
|
||||||
|
|
||||||
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
|
||||||
se1 = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
|
||||||
tags=[
|
|
||||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
|
||||||
],
|
|
||||||
)
|
|
||||||
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
|
||||||
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
|
||||||
|
|
||||||
encoded = generate_train_example(
|
|
||||||
clip_annotation,
|
|
||||||
sample_audio_loader,
|
|
||||||
preprocessor,
|
|
||||||
labeller,
|
|
||||||
)
|
|
||||||
predictions = postprocessor.get_predictions(
|
|
||||||
ModelOutput(
|
|
||||||
detection_probs=encoded.detection_heatmap.unsqueeze(0).unsqueeze(
|
|
||||||
0
|
|
||||||
),
|
|
||||||
size_preds=encoded.size_heatmap.unsqueeze(0),
|
|
||||||
class_probs=encoded.class_heatmap.unsqueeze(0),
|
|
||||||
features=encoded.spectrogram.unsqueeze(0).unsqueeze(0),
|
|
||||||
),
|
|
||||||
[clip],
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
assert isinstance(predictions, data.ClipPrediction)
|
|
||||||
assert len(predictions.sound_events) == 1
|
|
||||||
|
|
||||||
recovered = predictions.sound_events[0]
|
|
||||||
assert recovered.sound_event.geometry is not None
|
|
||||||
assert isinstance(recovered.sound_event.geometry, data.BoundingBox)
|
|
||||||
start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
|
|
||||||
recovered.sound_event.geometry.coordinates
|
|
||||||
)
|
|
||||||
start_time_or, low_freq_or, end_time_or, high_freq_or = (
|
|
||||||
geometry.coordinates
|
|
||||||
)
|
|
||||||
|
|
||||||
assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
|
|
||||||
assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
|
|
||||||
assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
|
|
||||||
assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)
|
|
||||||
|
|
||||||
assert len(recovered.tags) == 2
|
|
||||||
|
|
||||||
predicted_species_tag = next(
|
|
||||||
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
assert predicted_species_tag is not None
|
|
||||||
assert predicted_species_tag.score == 1
|
|
||||||
assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"
|
|
||||||
|
|
||||||
predicted_order_tag = next(
|
|
||||||
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
assert predicted_order_tag is not None
|
|
||||||
assert predicted_order_tag.score == 1
|
|
||||||
assert predicted_order_tag.tag.value == "Chiroptera"
|
|
||||||
|
|
||||||
|
|
||||||
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
|
|
||||||
sample_audio_loader: AudioLoader,
|
|
||||||
build_from_config,
|
|
||||||
recording,
|
|
||||||
):
|
|
||||||
yaml_content = """
|
|
||||||
labels:
|
|
||||||
targets:
|
|
||||||
roi:
|
|
||||||
name: anchor_bbox
|
|
||||||
anchor: bottom-left
|
|
||||||
classes:
|
|
||||||
classes:
|
|
||||||
- name: pippip
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Pipistrellus pipistrellus
|
|
||||||
- name: myomyo
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Myotis myotis
|
|
||||||
roi:
|
|
||||||
name: anchor_bbox
|
|
||||||
anchor: top-left
|
|
||||||
generic_class:
|
|
||||||
- key: order
|
|
||||||
value: Chiroptera
|
|
||||||
preprocessing:
|
|
||||||
"""
|
|
||||||
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
|
||||||
|
|
||||||
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
|
||||||
se1 = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
|
||||||
)
|
|
||||||
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
|
||||||
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
|
||||||
|
|
||||||
encoded = generate_train_example(
|
|
||||||
clip_annotation,
|
|
||||||
sample_audio_loader,
|
|
||||||
preprocessor,
|
|
||||||
labeller,
|
|
||||||
)
|
|
||||||
predictions = postprocessor.get_predictions(
|
|
||||||
ModelOutput(
|
|
||||||
detection_probs=encoded.detection_heatmap.unsqueeze(0).unsqueeze(
|
|
||||||
0
|
|
||||||
),
|
|
||||||
size_preds=encoded.size_heatmap.unsqueeze(0),
|
|
||||||
class_probs=encoded.class_heatmap.unsqueeze(0),
|
|
||||||
features=encoded.spectrogram.unsqueeze(0).unsqueeze(0),
|
|
||||||
),
|
|
||||||
[clip],
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
assert isinstance(predictions, data.ClipPrediction)
|
|
||||||
assert len(predictions.sound_events) == 1
|
|
||||||
|
|
||||||
recovered = predictions.sound_events[0]
|
|
||||||
assert recovered.sound_event.geometry is not None
|
|
||||||
assert isinstance(recovered.sound_event.geometry, data.BoundingBox)
|
|
||||||
start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
|
|
||||||
recovered.sound_event.geometry.coordinates
|
|
||||||
)
|
|
||||||
start_time_or, low_freq_or, end_time_or, high_freq_or = (
|
|
||||||
geometry.coordinates
|
|
||||||
)
|
|
||||||
|
|
||||||
assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
|
|
||||||
assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
|
|
||||||
assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
|
|
||||||
assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)
|
|
||||||
|
|
||||||
assert len(recovered.tags) == 2
|
|
||||||
|
|
||||||
predicted_species_tag = next(
|
|
||||||
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
assert predicted_species_tag is not None
|
|
||||||
assert predicted_species_tag.score == 1
|
|
||||||
assert predicted_species_tag.tag.value == "Myotis myotis"
|
|
||||||
|
|
||||||
predicted_order_tag = next(
|
|
||||||
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
assert predicted_order_tag is not None
|
|
||||||
assert predicted_order_tag.score == 1
|
|
||||||
assert predicted_order_tag.tag.value == "Chiroptera"
|
|
||||||
Loading…
Reference in New Issue
Block a user