mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-11 17:29:34 +01:00
Compare commits
No commits in common. "cd4955d4f32a44dce8cd7632ac2daa5bd3c60e31" and "db2ad1174363f1c228829c7a139e229bf9171c42" have entirely different histories.
cd4955d4f3
...
db2ad11743
@ -133,7 +133,7 @@ When you need to specify a tag, you typically use a structure with two fields:
|
|||||||
**It defaults to `class`** if you omit it, which is common when defining the main target classes.
|
**It defaults to `class`** if you omit it, which is common when defining the main target classes.
|
||||||
- `value`: The specific _value_ of the tag (e.g., `Myotis daubentonii`, `Good`, `Rain`).
|
- `value`: The specific _value_ of the tag (e.g., `Myotis daubentonii`, `Good`, `Rain`).
|
||||||
|
|
||||||
**Example YAML Configuration (e.g., inside a filter rule):**
|
**Example YAML Configuration using TagInfo (e.g., inside a filter rule):**
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
# ... inside a filtering configuration section ...
|
# ... inside a filtering configuration section ...
|
||||||
|
|||||||
@ -1,47 +1,44 @@
|
|||||||
targets:
|
targets:
|
||||||
detection_target:
|
classes:
|
||||||
name: bat
|
classes:
|
||||||
match_if:
|
- name: myomys
|
||||||
name: all_of
|
tags:
|
||||||
conditions:
|
- value: Myotis mystacinus
|
||||||
- name: has_tag
|
- name: pippip
|
||||||
tag: { key: event, value: Echolocation }
|
tags:
|
||||||
- name: not
|
- value: Pipistrellus pipistrellus
|
||||||
condition:
|
- name: eptser
|
||||||
name: has_tag
|
tags:
|
||||||
tag: { key: class, value: Unknown }
|
- value: Eptesicus serotinus
|
||||||
assign_tags:
|
- name: rhifer
|
||||||
|
tags:
|
||||||
|
- value: Rhinolophus ferrumequinum
|
||||||
|
roi:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: top-left
|
||||||
|
generic_class:
|
||||||
- key: class
|
- key: class
|
||||||
value: Bat
|
value: Bat
|
||||||
|
|
||||||
classification_targets:
|
filtering:
|
||||||
- name: myomys
|
rules:
|
||||||
tags:
|
- match_type: all
|
||||||
- key: class
|
tags:
|
||||||
value: Myotis mystacinus
|
- key: event
|
||||||
- name: pippip
|
value: Echolocation
|
||||||
tags:
|
- match_type: exclude
|
||||||
- key: class
|
tags:
|
||||||
value: Pipistrellus pipistrellus
|
- key: class
|
||||||
- name: eptser
|
value: Unknown
|
||||||
tags:
|
|
||||||
- key: class
|
|
||||||
value: Eptesicus serotinus
|
|
||||||
- name: rhifer
|
|
||||||
tags:
|
|
||||||
- key: class
|
|
||||||
value: Rhinolophus ferrumequinum
|
|
||||||
|
|
||||||
roi:
|
|
||||||
name: anchor_bbox
|
|
||||||
anchor: top-left
|
|
||||||
|
|
||||||
preprocess:
|
preprocess:
|
||||||
audio:
|
audio:
|
||||||
samplerate: 256000
|
|
||||||
resample:
|
resample:
|
||||||
enabled: True
|
samplerate: 256000
|
||||||
method: "poly"
|
method: "poly"
|
||||||
|
scale: false
|
||||||
|
center: true
|
||||||
|
duration: null
|
||||||
|
|
||||||
spectrogram:
|
spectrogram:
|
||||||
stft:
|
stft:
|
||||||
@ -51,66 +48,66 @@ preprocess:
|
|||||||
frequencies:
|
frequencies:
|
||||||
max_freq: 120000
|
max_freq: 120000
|
||||||
min_freq: 10000
|
min_freq: 10000
|
||||||
|
pcen:
|
||||||
|
time_constant: 0.1
|
||||||
|
gain: 0.98
|
||||||
|
bias: 2
|
||||||
|
power: 0.5
|
||||||
|
scale: "amplitude"
|
||||||
size:
|
size:
|
||||||
height: 128
|
height: 128
|
||||||
resize_factor: 0.5
|
resize_factor: 0.5
|
||||||
transforms:
|
spectral_mean_substraction: true
|
||||||
- name: pcen
|
peak_normalize: false
|
||||||
time_constant: 0.1
|
|
||||||
gain: 0.98
|
|
||||||
bias: 2
|
|
||||||
power: 0.5
|
|
||||||
- name: spectral_mean_substraction
|
|
||||||
|
|
||||||
postprocess:
|
postprocess:
|
||||||
nms_kernel_size: 9
|
nms_kernel_size: 9
|
||||||
detection_threshold: 0.01
|
detection_threshold: 0.01
|
||||||
|
min_freq: 10000
|
||||||
|
max_freq: 120000
|
||||||
top_k_per_sec: 200
|
top_k_per_sec: 200
|
||||||
|
|
||||||
|
labels:
|
||||||
|
sigma: 3
|
||||||
|
|
||||||
model:
|
model:
|
||||||
input_height: 128
|
input_height: 128
|
||||||
in_channels: 1
|
in_channels: 1
|
||||||
out_channels: 32
|
out_channels: 32
|
||||||
encoder:
|
encoder:
|
||||||
layers:
|
layers:
|
||||||
- name: FreqCoordConvDown
|
- block_type: FreqCoordConvDown
|
||||||
out_channels: 32
|
out_channels: 32
|
||||||
- name: FreqCoordConvDown
|
- block_type: FreqCoordConvDown
|
||||||
out_channels: 64
|
out_channels: 64
|
||||||
- name: LayerGroup
|
- block_type: LayerGroup
|
||||||
layers:
|
layers:
|
||||||
- name: FreqCoordConvDown
|
- block_type: FreqCoordConvDown
|
||||||
out_channels: 128
|
out_channels: 128
|
||||||
- name: ConvBlock
|
- block_type: ConvBlock
|
||||||
out_channels: 256
|
out_channels: 256
|
||||||
bottleneck:
|
bottleneck:
|
||||||
channels: 256
|
channels: 256
|
||||||
layers:
|
layers:
|
||||||
- name: SelfAttention
|
- block_type: SelfAttention
|
||||||
attention_channels: 256
|
attention_channels: 256
|
||||||
decoder:
|
decoder:
|
||||||
layers:
|
layers:
|
||||||
- name: FreqCoordConvUp
|
- block_type: FreqCoordConvUp
|
||||||
out_channels: 64
|
out_channels: 64
|
||||||
- name: FreqCoordConvUp
|
- block_type: FreqCoordConvUp
|
||||||
out_channels: 32
|
out_channels: 32
|
||||||
- name: LayerGroup
|
- block_type: LayerGroup
|
||||||
layers:
|
layers:
|
||||||
- name: FreqCoordConvUp
|
- block_type: FreqCoordConvUp
|
||||||
out_channels: 32
|
out_channels: 32
|
||||||
- name: ConvBlock
|
- block_type: ConvBlock
|
||||||
out_channels: 32
|
out_channels: 32
|
||||||
|
|
||||||
train:
|
train:
|
||||||
learning_rate: 0.001
|
learning_rate: 0.001
|
||||||
t_max: 100
|
t_max: 100
|
||||||
|
|
||||||
labels:
|
|
||||||
sigma: 3
|
|
||||||
|
|
||||||
trainer:
|
|
||||||
max_epochs: 40
|
|
||||||
|
|
||||||
dataloaders:
|
dataloaders:
|
||||||
train:
|
train:
|
||||||
batch_size: 8
|
batch_size: 8
|
||||||
@ -118,7 +115,7 @@ train:
|
|||||||
shuffle: True
|
shuffle: True
|
||||||
|
|
||||||
val:
|
val:
|
||||||
batch_size: 1
|
batch_size: 8
|
||||||
num_workers: 2
|
num_workers: 2
|
||||||
|
|
||||||
loss:
|
loss:
|
||||||
@ -137,34 +134,32 @@ train:
|
|||||||
|
|
||||||
logger:
|
logger:
|
||||||
logger_type: csv
|
logger_type: csv
|
||||||
# save_dir: outputs/log/
|
save_dir: outputs/log/
|
||||||
# name: logs
|
name: logs
|
||||||
|
|
||||||
augmentations:
|
augmentations:
|
||||||
enabled: true
|
steps:
|
||||||
audio:
|
- augmentation_type: mix_audio
|
||||||
- name: mix_audio
|
|
||||||
probability: 0.2
|
probability: 0.2
|
||||||
min_weight: 0.3
|
min_weight: 0.3
|
||||||
max_weight: 0.7
|
max_weight: 0.7
|
||||||
- name: add_echo
|
- augmentation_type: add_echo
|
||||||
probability: 0.2
|
probability: 0.2
|
||||||
max_delay: 0.005
|
max_delay: 0.005
|
||||||
min_weight: 0.0
|
min_weight: 0.0
|
||||||
max_weight: 1.0
|
max_weight: 1.0
|
||||||
spectrogram:
|
- augmentation_type: scale_volume
|
||||||
- name: scale_volume
|
|
||||||
probability: 0.2
|
probability: 0.2
|
||||||
min_scaling: 0.0
|
min_scaling: 0.0
|
||||||
max_scaling: 2.0
|
max_scaling: 2.0
|
||||||
- name: warp
|
- augmentation_type: warp
|
||||||
probability: 0.2
|
probability: 0.2
|
||||||
delta: 0.04
|
delta: 0.04
|
||||||
- name: mask_time
|
- augmentation_type: mask_time
|
||||||
probability: 0.2
|
probability: 0.2
|
||||||
max_perc: 0.05
|
max_perc: 0.05
|
||||||
max_masks: 3
|
max_masks: 3
|
||||||
- name: mask_freq
|
- augmentation_type: mask_freq
|
||||||
probability: 0.2
|
probability: 0.2
|
||||||
max_perc: 0.10
|
max_perc: 0.10
|
||||||
max_masks: 3
|
max_masks: 3
|
||||||
|
|||||||
12
justfile
12
justfile
@ -92,11 +92,19 @@ clean-build:
|
|||||||
clean: clean-build clean-pyc clean-test clean-docs
|
clean: clean-build clean-pyc clean-test clean-docs
|
||||||
|
|
||||||
# Examples
|
# Examples
|
||||||
|
# Preprocess example data.
|
||||||
|
example-preprocess OPTIONS="":
|
||||||
|
batdetect2 preprocess \
|
||||||
|
--base-dir . \
|
||||||
|
--dataset-field datasets.train \
|
||||||
|
--config example_data/config.yaml \
|
||||||
|
{{OPTIONS}} \
|
||||||
|
example_data/config.yaml example_data/preprocessed
|
||||||
|
|
||||||
# Train on example data.
|
# Train on example data.
|
||||||
example-train OPTIONS="":
|
example-train OPTIONS="":
|
||||||
batdetect2 train \
|
batdetect2 train \
|
||||||
--val-dataset example_data/dataset.yaml \
|
--val-dir example_data/preprocessed \
|
||||||
--config example_data/config.yaml \
|
--config example_data/config.yaml \
|
||||||
{{OPTIONS}} \
|
{{OPTIONS}} \
|
||||||
example_data/dataset.yaml
|
example_data/preprocessed
|
||||||
|
|||||||
@ -17,7 +17,7 @@ dependencies = [
|
|||||||
"torch>=1.13.1,<2.5.0",
|
"torch>=1.13.1,<2.5.0",
|
||||||
"torchaudio>=1.13.1,<2.5.0",
|
"torchaudio>=1.13.1,<2.5.0",
|
||||||
"torchvision>=0.14.0",
|
"torchvision>=0.14.0",
|
||||||
"soundevent[audio,geometry,plot]>=2.9.1",
|
"soundevent[audio,geometry,plot]>=2.8.1",
|
||||||
"click>=8.1.7",
|
"click>=8.1.7",
|
||||||
"netcdf4>=1.6.5",
|
"netcdf4>=1.6.5",
|
||||||
"tqdm>=4.66.2",
|
"tqdm>=4.66.2",
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
from batdetect2.cli.compat import detect
|
from batdetect2.cli.compat import detect
|
||||||
from batdetect2.cli.data import data
|
from batdetect2.cli.data import data
|
||||||
from batdetect2.cli.evaluate import evaluate_command
|
from batdetect2.cli.preprocess import preprocess
|
||||||
from batdetect2.cli.train import train_command
|
from batdetect2.cli.train import train_command
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -9,7 +9,7 @@ __all__ = [
|
|||||||
"detect",
|
"detect",
|
||||||
"data",
|
"data",
|
||||||
"train_command",
|
"train_command",
|
||||||
"evaluate_command",
|
"preprocess",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,63 +0,0 @@
|
|||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import click
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from batdetect2.cli.base import cli
|
|
||||||
from batdetect2.data import load_dataset_from_config
|
|
||||||
from batdetect2.evaluate.evaluate import evaluate
|
|
||||||
from batdetect2.train.lightning import load_model_from_checkpoint
|
|
||||||
|
|
||||||
__all__ = ["evaluate_command"]
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command(name="evaluate")
|
|
||||||
@click.argument("model-path", type=click.Path(exists=True))
|
|
||||||
@click.argument("test_dataset", type=click.Path(exists=True))
|
|
||||||
@click.option("--output-dir", type=click.Path())
|
|
||||||
@click.option("--workers", type=int)
|
|
||||||
@click.option(
|
|
||||||
"-v",
|
|
||||||
"--verbose",
|
|
||||||
count=True,
|
|
||||||
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
|
||||||
)
|
|
||||||
def evaluate_command(
|
|
||||||
model_path: Path,
|
|
||||||
test_dataset: Path,
|
|
||||||
output_dir: Optional[Path] = None,
|
|
||||||
workers: Optional[int] = None,
|
|
||||||
verbose: int = 0,
|
|
||||||
):
|
|
||||||
logger.remove()
|
|
||||||
if verbose == 0:
|
|
||||||
log_level = "WARNING"
|
|
||||||
elif verbose == 1:
|
|
||||||
log_level = "INFO"
|
|
||||||
else:
|
|
||||||
log_level = "DEBUG"
|
|
||||||
logger.add(sys.stderr, level=log_level)
|
|
||||||
|
|
||||||
logger.info("Initiating evaluation process...")
|
|
||||||
|
|
||||||
test_annotations = load_dataset_from_config(test_dataset)
|
|
||||||
logger.debug(
|
|
||||||
"Loaded {num_annotations} test examples",
|
|
||||||
num_annotations=len(test_annotations),
|
|
||||||
)
|
|
||||||
|
|
||||||
model, train_config = load_model_from_checkpoint(model_path)
|
|
||||||
|
|
||||||
df, results = evaluate(
|
|
||||||
model,
|
|
||||||
test_annotations,
|
|
||||||
config=train_config,
|
|
||||||
num_workers=workers,
|
|
||||||
)
|
|
||||||
|
|
||||||
print(results)
|
|
||||||
|
|
||||||
if output_dir:
|
|
||||||
df.to_csv(output_dir / "results.csv")
|
|
||||||
142
src/batdetect2/cli/preprocess.py
Normal file
142
src/batdetect2/cli/preprocess.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import click
|
||||||
|
import yaml
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from batdetect2.cli.base import cli
|
||||||
|
from batdetect2.data import load_dataset_from_config
|
||||||
|
from batdetect2.train.preprocess import (
|
||||||
|
TrainPreprocessConfig,
|
||||||
|
load_train_preprocessing_config,
|
||||||
|
preprocess_dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["preprocess"]
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
@click.argument(
|
||||||
|
"dataset_config",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
)
|
||||||
|
@click.argument(
|
||||||
|
"output",
|
||||||
|
type=click.Path(),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--dataset-field",
|
||||||
|
type=str,
|
||||||
|
help=(
|
||||||
|
"Specifies the key to access the dataset information within the "
|
||||||
|
"dataset configuration file, if the information is nested inside a "
|
||||||
|
"dictionary. If the dataset information is at the top level of the "
|
||||||
|
"config file, you don't need to specify this."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--base-dir",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help=(
|
||||||
|
"The main directory where your audio recordings and annotation "
|
||||||
|
"files are stored. This helps the program find your data, "
|
||||||
|
"especially if the paths in your dataset configuration file "
|
||||||
|
"are relative."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--config",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help=(
|
||||||
|
"Path to the configuration file. This file tells "
|
||||||
|
"the program how to prepare your audio data before training, such "
|
||||||
|
"as resampling or applying filters."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--config-field",
|
||||||
|
type=str,
|
||||||
|
help=(
|
||||||
|
"If the preprocessing settings are inside a nested dictionary "
|
||||||
|
"within the preprocessing configuration file, specify the key "
|
||||||
|
"here to access them. If the preprocessing settings are at the "
|
||||||
|
"top level, you don't need to specify this."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--num-workers",
|
||||||
|
type=int,
|
||||||
|
help=(
|
||||||
|
"The maximum number of computer cores to use when processing "
|
||||||
|
"your audio data. Using more cores can speed up the preprocessing, "
|
||||||
|
"but don't use more than your computer has available. By default, "
|
||||||
|
"the program will use all available cores."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"-v",
|
||||||
|
"--verbose",
|
||||||
|
count=True,
|
||||||
|
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
||||||
|
)
|
||||||
|
def preprocess(
|
||||||
|
dataset_config: Path,
|
||||||
|
output: Path,
|
||||||
|
base_dir: Optional[Path] = None,
|
||||||
|
config: Optional[Path] = None,
|
||||||
|
config_field: Optional[str] = None,
|
||||||
|
num_workers: Optional[int] = None,
|
||||||
|
dataset_field: Optional[str] = None,
|
||||||
|
verbose: int = 0,
|
||||||
|
):
|
||||||
|
logger.remove()
|
||||||
|
if verbose == 0:
|
||||||
|
log_level = "WARNING"
|
||||||
|
elif verbose == 1:
|
||||||
|
log_level = "INFO"
|
||||||
|
else:
|
||||||
|
log_level = "DEBUG"
|
||||||
|
logger.add(sys.stderr, level=log_level)
|
||||||
|
|
||||||
|
logger.info("Starting preprocessing.")
|
||||||
|
|
||||||
|
output = Path(output)
|
||||||
|
logger.info("Will save outputs to {output}", output=output)
|
||||||
|
|
||||||
|
base_dir = base_dir or Path.cwd()
|
||||||
|
logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
|
||||||
|
|
||||||
|
if config:
|
||||||
|
logger.info(
|
||||||
|
"Loading preprocessing config from: {config}", config=config
|
||||||
|
)
|
||||||
|
|
||||||
|
conf = (
|
||||||
|
load_train_preprocessing_config(config, field=config_field)
|
||||||
|
if config is not None
|
||||||
|
else TrainPreprocessConfig()
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
"Preprocessing config:\n{conf}",
|
||||||
|
conf=yaml.dump(conf.model_dump()),
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = load_dataset_from_config(
|
||||||
|
dataset_config,
|
||||||
|
field=dataset_field,
|
||||||
|
base_dir=base_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Loaded {num_examples} annotated clips from the configured dataset",
|
||||||
|
num_examples=len(dataset),
|
||||||
|
)
|
||||||
|
|
||||||
|
preprocess_dataset(
|
||||||
|
dataset,
|
||||||
|
conf,
|
||||||
|
output=output,
|
||||||
|
max_workers=num_workers,
|
||||||
|
)
|
||||||
@ -20,8 +20,6 @@ __all__ = ["train_command"]
|
|||||||
@click.argument("train_dataset", type=click.Path(exists=True))
|
@click.argument("train_dataset", type=click.Path(exists=True))
|
||||||
@click.option("--val-dataset", type=click.Path(exists=True))
|
@click.option("--val-dataset", type=click.Path(exists=True))
|
||||||
@click.option("--model-path", type=click.Path(exists=True))
|
@click.option("--model-path", type=click.Path(exists=True))
|
||||||
@click.option("--ckpt-dir", type=click.Path(exists=True))
|
|
||||||
@click.option("--log-dir", type=click.Path(exists=True))
|
|
||||||
@click.option("--config", type=click.Path(exists=True))
|
@click.option("--config", type=click.Path(exists=True))
|
||||||
@click.option("--config-field", type=str)
|
@click.option("--config-field", type=str)
|
||||||
@click.option("--train-workers", type=int)
|
@click.option("--train-workers", type=int)
|
||||||
@ -36,8 +34,6 @@ def train_command(
|
|||||||
train_dataset: Path,
|
train_dataset: Path,
|
||||||
val_dataset: Optional[Path] = None,
|
val_dataset: Optional[Path] = None,
|
||||||
model_path: Optional[Path] = None,
|
model_path: Optional[Path] = None,
|
||||||
ckpt_dir: Optional[Path] = None,
|
|
||||||
log_dir: Optional[Path] = None,
|
|
||||||
config: Optional[Path] = None,
|
config: Optional[Path] = None,
|
||||||
config_field: Optional[str] = None,
|
config_field: Optional[str] = None,
|
||||||
train_workers: int = 0,
|
train_workers: int = 0,
|
||||||
@ -87,6 +83,4 @@ def train_command(
|
|||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
train_workers=train_workers,
|
train_workers=train_workers,
|
||||||
val_workers=val_workers,
|
val_workers=val_workers,
|
||||||
log_dir=log_dir,
|
|
||||||
checkpoint_dir=ckpt_dir,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,61 +0,0 @@
|
|||||||
from typing import Generic, Protocol, Type, TypeVar
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Registry",
|
|
||||||
]
|
|
||||||
|
|
||||||
T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
|
|
||||||
T_Type = TypeVar("T_Type", covariant=True)
|
|
||||||
|
|
||||||
|
|
||||||
class LogicProtocol(Generic[T_Config, T_Type], Protocol):
|
|
||||||
"""A generic protocol for the logic classes (conditions or transforms)."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: T_Config) -> T_Type: ...
|
|
||||||
|
|
||||||
|
|
||||||
T_Proto = TypeVar("T_Proto", bound=LogicProtocol)
|
|
||||||
|
|
||||||
|
|
||||||
class Registry(Generic[T_Type]):
|
|
||||||
"""A generic class to create and manage a registry of items."""
|
|
||||||
|
|
||||||
def __init__(self, name: str):
|
|
||||||
self._name = name
|
|
||||||
self._registry = {}
|
|
||||||
|
|
||||||
def register(self, config_cls: Type[T_Config]):
|
|
||||||
"""A decorator factory to register a new item."""
|
|
||||||
fields = config_cls.model_fields
|
|
||||||
|
|
||||||
if "name" not in fields:
|
|
||||||
raise ValueError("Configuration object must have a 'name' field.")
|
|
||||||
|
|
||||||
name = fields["name"].default
|
|
||||||
|
|
||||||
if not isinstance(name, str):
|
|
||||||
raise ValueError("'name' field must be a string literal.")
|
|
||||||
|
|
||||||
def decorator(logic_cls: Type[T_Proto]) -> Type[T_Proto]:
|
|
||||||
self._registry[name] = logic_cls
|
|
||||||
return logic_cls
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
def build(self, config: BaseModel) -> T_Type:
|
|
||||||
"""Builds a logic instance from a config object."""
|
|
||||||
|
|
||||||
name = getattr(config, "name") # noqa: B009
|
|
||||||
|
|
||||||
if name is None:
|
|
||||||
raise ValueError("Config does not have a name field")
|
|
||||||
|
|
||||||
if name not in self._registry:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"No {self._name} with name '{name}' is registered."
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._registry[name].from_config(config)
|
|
||||||
@ -14,9 +14,8 @@ format-specific loading function to retrieve the annotations as a standard
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.data.annotations.aoef import (
|
from batdetect2.data.annotations.aoef import (
|
||||||
@ -43,13 +42,10 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
AnnotationFormats = Annotated[
|
AnnotationFormats = Union[
|
||||||
Union[
|
BatDetect2MergedAnnotations,
|
||||||
BatDetect2MergedAnnotations,
|
BatDetect2FilesAnnotations,
|
||||||
BatDetect2FilesAnnotations,
|
AOEFAnnotations,
|
||||||
AOEFAnnotations,
|
|
||||||
],
|
|
||||||
Field(discriminator="format"),
|
|
||||||
]
|
]
|
||||||
"""Type Alias representing all supported data source configurations.
|
"""Type Alias representing all supported data source configurations.
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,8 @@ from typing import Callable, List, Optional, Union
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.targets import get_term_from_key
|
||||||
|
|
||||||
PathLike = Union[Path, str, os.PathLike]
|
PathLike = Union[Path, str, os.PathLike]
|
||||||
|
|
||||||
__all__ = []
|
__all__ = []
|
||||||
@ -90,15 +92,15 @@ def annotation_to_sound_event(
|
|||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(
|
data.Tag(
|
||||||
key=label_key, # type: ignore
|
term=get_term_from_key(label_key),
|
||||||
value=annotation.label,
|
value=annotation.label,
|
||||||
),
|
),
|
||||||
data.Tag(
|
data.Tag(
|
||||||
key=event_key, # type: ignore
|
term=get_term_from_key(event_key),
|
||||||
value=annotation.event,
|
value=annotation.event,
|
||||||
),
|
),
|
||||||
data.Tag(
|
data.Tag(
|
||||||
key=individual_key, # type: ignore
|
term=get_term_from_key(individual_key),
|
||||||
value=str(annotation.individual),
|
value=str(annotation.individual),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -123,7 +125,7 @@ def file_annotation_to_clip(
|
|||||||
time_expansion=file_annotation.time_exp,
|
time_expansion=file_annotation.time_exp,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(
|
data.Tag(
|
||||||
key=label_key, # type: ignore
|
term=get_term_from_key(label_key),
|
||||||
value=file_annotation.label,
|
value=file_annotation.label,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
@ -155,8 +157,7 @@ def file_annotation_to_clip_annotation(
|
|||||||
notes=notes,
|
notes=notes,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(
|
data.Tag(
|
||||||
key=label_key, # type: ignore
|
term=get_term_from_key(label_key), value=file_annotation.label
|
||||||
value=file_annotation.label,
|
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
sound_events=[
|
sound_events=[
|
||||||
|
|||||||
@ -1,287 +0,0 @@
|
|||||||
from collections.abc import Callable
|
|
||||||
from typing import Annotated, List, Literal, Sequence, Union
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
from soundevent import data
|
|
||||||
from soundevent.geometry import compute_bounds
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
|
||||||
from batdetect2.data._core import Registry
|
|
||||||
|
|
||||||
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
|
|
||||||
|
|
||||||
_conditions: Registry[SoundEventCondition] = Registry("condition")
|
|
||||||
|
|
||||||
|
|
||||||
class HasTagConfig(BaseConfig):
|
|
||||||
name: Literal["has_tag"] = "has_tag"
|
|
||||||
tag: data.Tag
|
|
||||||
|
|
||||||
|
|
||||||
@_conditions.register(HasTagConfig)
|
|
||||||
class HasTag:
|
|
||||||
def __init__(self, tag: data.Tag):
|
|
||||||
self.tag = tag
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, sound_event_annotation: data.SoundEventAnnotation
|
|
||||||
) -> bool:
|
|
||||||
return self.tag in sound_event_annotation.tags
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: HasTagConfig):
|
|
||||||
return cls(tag=config.tag)
|
|
||||||
|
|
||||||
|
|
||||||
class HasAllTagsConfig(BaseConfig):
|
|
||||||
name: Literal["has_all_tags"] = "has_all_tags"
|
|
||||||
tags: List[data.Tag]
|
|
||||||
|
|
||||||
|
|
||||||
@_conditions.register(HasAllTagsConfig)
|
|
||||||
class HasAllTags:
|
|
||||||
def __init__(self, tags: List[data.Tag]):
|
|
||||||
if not tags:
|
|
||||||
raise ValueError("Need to specify at least one tag")
|
|
||||||
|
|
||||||
self.tags = set(tags)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, sound_event_annotation: data.SoundEventAnnotation
|
|
||||||
) -> bool:
|
|
||||||
return self.tags.issubset(sound_event_annotation.tags)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: HasAllTagsConfig):
|
|
||||||
return cls(tags=config.tags)
|
|
||||||
|
|
||||||
|
|
||||||
class HasAnyTagConfig(BaseConfig):
|
|
||||||
name: Literal["has_any_tag"] = "has_any_tag"
|
|
||||||
tags: List[data.Tag]
|
|
||||||
|
|
||||||
|
|
||||||
@_conditions.register(HasAnyTagConfig)
|
|
||||||
class HasAnyTag:
|
|
||||||
def __init__(self, tags: List[data.Tag]):
|
|
||||||
if not tags:
|
|
||||||
raise ValueError("Need to specify at least one tag")
|
|
||||||
|
|
||||||
self.tags = set(tags)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, sound_event_annotation: data.SoundEventAnnotation
|
|
||||||
) -> bool:
|
|
||||||
return bool(self.tags.intersection(sound_event_annotation.tags))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: HasAnyTagConfig):
|
|
||||||
return cls(tags=config.tags)
|
|
||||||
|
|
||||||
|
|
||||||
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
|
|
||||||
|
|
||||||
|
|
||||||
class DurationConfig(BaseConfig):
|
|
||||||
name: Literal["duration"] = "duration"
|
|
||||||
operator: Operator
|
|
||||||
seconds: float
|
|
||||||
|
|
||||||
|
|
||||||
def _build_comparator(
|
|
||||||
operator: Operator, value: float
|
|
||||||
) -> Callable[[float], bool]:
|
|
||||||
if operator == "gt":
|
|
||||||
return lambda x: x > value
|
|
||||||
|
|
||||||
if operator == "gte":
|
|
||||||
return lambda x: x >= value
|
|
||||||
|
|
||||||
if operator == "lt":
|
|
||||||
return lambda x: x < value
|
|
||||||
|
|
||||||
if operator == "lte":
|
|
||||||
return lambda x: x <= value
|
|
||||||
|
|
||||||
if operator == "eq":
|
|
||||||
return lambda x: x == value
|
|
||||||
|
|
||||||
raise ValueError(f"Invalid operator {operator}")
|
|
||||||
|
|
||||||
|
|
||||||
@_conditions.register(DurationConfig)
|
|
||||||
class Duration:
|
|
||||||
def __init__(self, operator: Operator, seconds: float):
|
|
||||||
self.operator = operator
|
|
||||||
self.seconds = seconds
|
|
||||||
self._comparator = _build_comparator(self.operator, self.seconds)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
) -> bool:
|
|
||||||
geometry = sound_event_annotation.sound_event.geometry
|
|
||||||
|
|
||||||
if geometry is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
start_time, _, end_time, _ = compute_bounds(geometry)
|
|
||||||
duration = end_time - start_time
|
|
||||||
|
|
||||||
return self._comparator(duration)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: DurationConfig):
|
|
||||||
return cls(operator=config.operator, seconds=config.seconds)
|
|
||||||
|
|
||||||
|
|
||||||
class FrequencyConfig(BaseConfig):
|
|
||||||
name: Literal["frequency"] = "frequency"
|
|
||||||
boundary: Literal["low", "high"]
|
|
||||||
operator: Operator
|
|
||||||
hertz: float
|
|
||||||
|
|
||||||
|
|
||||||
@_conditions.register(FrequencyConfig)
|
|
||||||
class Frequency:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
operator: Operator,
|
|
||||||
boundary: Literal["low", "high"],
|
|
||||||
hertz: float,
|
|
||||||
):
|
|
||||||
self.operator = operator
|
|
||||||
self.hertz = hertz
|
|
||||||
self.boundary = boundary
|
|
||||||
self._comparator = _build_comparator(self.operator, self.hertz)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
) -> bool:
|
|
||||||
geometry = sound_event_annotation.sound_event.geometry
|
|
||||||
|
|
||||||
if geometry is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Automatically false if geometry does not have a frequency range
|
|
||||||
if isinstance(geometry, (data.TimeInterval, data.TimeStamp)):
|
|
||||||
return False
|
|
||||||
|
|
||||||
_, low_freq, _, high_freq = compute_bounds(geometry)
|
|
||||||
|
|
||||||
if self.boundary == "low":
|
|
||||||
return self._comparator(low_freq)
|
|
||||||
|
|
||||||
return self._comparator(high_freq)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: FrequencyConfig):
|
|
||||||
return cls(
|
|
||||||
operator=config.operator,
|
|
||||||
boundary=config.boundary,
|
|
||||||
hertz=config.hertz,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AllOfConfig(BaseConfig):
|
|
||||||
name: Literal["all_of"] = "all_of"
|
|
||||||
conditions: Sequence["SoundEventConditionConfig"]
|
|
||||||
|
|
||||||
|
|
||||||
@_conditions.register(AllOfConfig)
|
|
||||||
class AllOf:
|
|
||||||
def __init__(self, conditions: List[SoundEventCondition]):
|
|
||||||
self.conditions = conditions
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, sound_event_annotation: data.SoundEventAnnotation
|
|
||||||
) -> bool:
|
|
||||||
return all(c(sound_event_annotation) for c in self.conditions)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: AllOfConfig):
|
|
||||||
conditions = [
|
|
||||||
build_sound_event_condition(cond) for cond in config.conditions
|
|
||||||
]
|
|
||||||
return cls(conditions)
|
|
||||||
|
|
||||||
|
|
||||||
class AnyOfConfig(BaseConfig):
|
|
||||||
name: Literal["any_of"] = "any_of"
|
|
||||||
conditions: List["SoundEventConditionConfig"]
|
|
||||||
|
|
||||||
|
|
||||||
@_conditions.register(AnyOfConfig)
|
|
||||||
class AnyOf:
|
|
||||||
def __init__(self, conditions: List[SoundEventCondition]):
|
|
||||||
self.conditions = conditions
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, sound_event_annotation: data.SoundEventAnnotation
|
|
||||||
) -> bool:
|
|
||||||
return any(c(sound_event_annotation) for c in self.conditions)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: AnyOfConfig):
|
|
||||||
conditions = [
|
|
||||||
build_sound_event_condition(cond) for cond in config.conditions
|
|
||||||
]
|
|
||||||
return cls(conditions)
|
|
||||||
|
|
||||||
|
|
||||||
class NotConfig(BaseConfig):
|
|
||||||
name: Literal["not"] = "not"
|
|
||||||
condition: "SoundEventConditionConfig"
|
|
||||||
|
|
||||||
|
|
||||||
@_conditions.register(NotConfig)
|
|
||||||
class Not:
|
|
||||||
def __init__(self, condition: SoundEventCondition):
|
|
||||||
self.condition = condition
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, sound_event_annotation: data.SoundEventAnnotation
|
|
||||||
) -> bool:
|
|
||||||
return not self.condition(sound_event_annotation)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: NotConfig):
|
|
||||||
condition = build_sound_event_condition(config.condition)
|
|
||||||
return cls(condition)
|
|
||||||
|
|
||||||
|
|
||||||
SoundEventConditionConfig = Annotated[
|
|
||||||
Union[
|
|
||||||
HasTagConfig,
|
|
||||||
HasAllTagsConfig,
|
|
||||||
HasAnyTagConfig,
|
|
||||||
DurationConfig,
|
|
||||||
FrequencyConfig,
|
|
||||||
AllOfConfig,
|
|
||||||
AnyOfConfig,
|
|
||||||
NotConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def build_sound_event_condition(
|
|
||||||
config: SoundEventConditionConfig,
|
|
||||||
) -> SoundEventCondition:
|
|
||||||
return _conditions.build(config)
|
|
||||||
|
|
||||||
|
|
||||||
def filter_clip_annotation(
|
|
||||||
clip_annotation: data.ClipAnnotation,
|
|
||||||
condition: SoundEventCondition,
|
|
||||||
) -> data.ClipAnnotation:
|
|
||||||
return clip_annotation.model_copy(
|
|
||||||
update=dict(
|
|
||||||
sound_events=[
|
|
||||||
sound_event
|
|
||||||
for sound_event in clip_annotation.sound_events
|
|
||||||
if condition(sound_event)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@ -19,7 +19,7 @@ The core components are:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
from typing import Annotated, List, Optional
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -31,17 +31,6 @@ from batdetect2.data.annotations import (
|
|||||||
AnnotationFormats,
|
AnnotationFormats,
|
||||||
load_annotated_dataset,
|
load_annotated_dataset,
|
||||||
)
|
)
|
||||||
from batdetect2.data.conditions import (
|
|
||||||
SoundEventConditionConfig,
|
|
||||||
build_sound_event_condition,
|
|
||||||
filter_clip_annotation,
|
|
||||||
)
|
|
||||||
from batdetect2.data.transforms import (
|
|
||||||
ApplyAll,
|
|
||||||
SoundEventTransformConfig,
|
|
||||||
build_sound_event_transform,
|
|
||||||
transform_clip_annotation,
|
|
||||||
)
|
|
||||||
from batdetect2.targets.terms import data_source
|
from batdetect2.targets.terms import data_source
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -63,68 +52,79 @@ sources.
|
|||||||
|
|
||||||
|
|
||||||
class DatasetConfig(BaseConfig):
|
class DatasetConfig(BaseConfig):
|
||||||
"""Configuration model defining the structure of a BatDetect2 dataset."""
|
"""Configuration model defining the structure of a BatDetect2 dataset.
|
||||||
|
|
||||||
|
This class is typically loaded from a YAML file and describes the components
|
||||||
|
of the dataset, including metadata and a list of data sources.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
A descriptive name for the dataset (e.g., "UK_Bats_Project_2024").
|
||||||
|
description : str
|
||||||
|
A longer description of the dataset's contents, origin, purpose, etc.
|
||||||
|
sources : List[AnnotationFormats]
|
||||||
|
A list defining the different data sources contributing to this
|
||||||
|
dataset. Each item in the list must conform to one of the Pydantic
|
||||||
|
models defined in the `AnnotationFormats` type union. The specific
|
||||||
|
model used for each source is determined by the mandatory `format`
|
||||||
|
field within the source's configuration, allowing BatDetect2 to use the
|
||||||
|
correct parser for different annotation styles.
|
||||||
|
"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
sources: List[AnnotationFormats]
|
sources: List[
|
||||||
|
Annotated[AnnotationFormats, Field(..., discriminator="format")]
|
||||||
sound_event_filter: Optional[SoundEventConditionConfig] = None
|
]
|
||||||
sound_event_transforms: List[SoundEventTransformConfig] = Field(
|
|
||||||
default_factory=list
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(
|
def load_dataset(
|
||||||
config: DatasetConfig,
|
dataset: DatasetConfig,
|
||||||
base_dir: Optional[Path] = None,
|
base_dir: Optional[Path] = None,
|
||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
"""Load all clip annotations from the sources defined in a DatasetConfig."""
|
"""Load all clip annotations from the sources defined in a DatasetConfig.
|
||||||
|
|
||||||
|
Iterates through each data source specified in the `dataset_config`,
|
||||||
|
delegates the loading and parsing of that source's annotations to
|
||||||
|
`batdetect2.data.annotations.load_annotated_dataset` (which handles
|
||||||
|
different data formats), and aggregates all resulting `ClipAnnotation`
|
||||||
|
objects into a single flat list.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
dataset_config : DatasetConfig
|
||||||
|
The configuration object describing the dataset and its sources.
|
||||||
|
base_dir : Path, optional
|
||||||
|
An optional base directory path. If provided, relative paths for
|
||||||
|
metadata files or data directories within the `dataset_config`'s
|
||||||
|
sources might be resolved relative to this directory. Defaults to None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Dataset (List[data.ClipAnnotation])
|
||||||
|
A flat list containing all loaded `ClipAnnotation` metadata objects
|
||||||
|
from all specified sources.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
Exception
|
||||||
|
Can raise various exceptions during the delegated loading process
|
||||||
|
(`load_annotated_dataset`) if files are not found, cannot be parsed
|
||||||
|
according to the specified format, or other I/O errors occur.
|
||||||
|
"""
|
||||||
clip_annotations = []
|
clip_annotations = []
|
||||||
|
for source in dataset.sources:
|
||||||
condition = (
|
|
||||||
build_sound_event_condition(config.sound_event_filter)
|
|
||||||
if config.sound_event_filter is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
transform = (
|
|
||||||
ApplyAll(
|
|
||||||
[
|
|
||||||
build_sound_event_transform(step)
|
|
||||||
for step in config.sound_event_transforms
|
|
||||||
]
|
|
||||||
)
|
|
||||||
if config.sound_event_transforms
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
for source in config.sources:
|
|
||||||
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
|
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Loaded {num_examples} from dataset source '{source_name}'",
|
"Loaded {num_examples} from dataset source '{source_name}'",
|
||||||
num_examples=len(annotated_source.clip_annotations),
|
num_examples=len(annotated_source.clip_annotations),
|
||||||
source_name=source.name,
|
source_name=source.name,
|
||||||
)
|
)
|
||||||
|
clip_annotations.extend(
|
||||||
for clip_annotation in annotated_source.clip_annotations:
|
insert_source_tag(clip_annotation, source)
|
||||||
clip_annotation = insert_source_tag(clip_annotation, source)
|
for clip_annotation in annotated_source.clip_annotations
|
||||||
|
)
|
||||||
if condition is not None:
|
|
||||||
clip_annotation = filter_clip_annotation(
|
|
||||||
clip_annotation,
|
|
||||||
condition,
|
|
||||||
)
|
|
||||||
|
|
||||||
if transform is not None:
|
|
||||||
clip_annotation = transform_clip_annotation(
|
|
||||||
clip_annotation,
|
|
||||||
transform,
|
|
||||||
)
|
|
||||||
|
|
||||||
clip_annotations.append(clip_annotation)
|
|
||||||
|
|
||||||
return clip_annotations
|
return clip_annotations
|
||||||
|
|
||||||
|
|
||||||
@ -161,6 +161,7 @@ def insert_source_tag(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: add documentation
|
||||||
def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
|
def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
|
||||||
return load_config(path=path, schema=DatasetConfig, field=field)
|
return load_config(path=path, schema=DatasetConfig, field=field)
|
||||||
|
|
||||||
|
|||||||
@ -10,8 +10,16 @@ from batdetect2.typing.targets import TargetProtocol
|
|||||||
def iterate_over_sound_events(
|
def iterate_over_sound_events(
|
||||||
dataset: Dataset,
|
dataset: Dataset,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
|
apply_filter: bool = True,
|
||||||
|
apply_transform: bool = True,
|
||||||
|
exclude_generic: bool = True,
|
||||||
) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]:
|
) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]:
|
||||||
"""Iterate over sound events in a dataset.
|
"""Iterate over sound events in a dataset, applying filtering and
|
||||||
|
transformations.
|
||||||
|
|
||||||
|
This generator function processes sound event annotations from a given
|
||||||
|
dataset, allowing for optional filtering, transformation, and exclusion of
|
||||||
|
unclassifiable (generic) events based on the provided target definitions.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -21,6 +29,18 @@ def iterate_over_sound_events(
|
|||||||
targets : TargetProtocol
|
targets : TargetProtocol
|
||||||
An object implementing the `TargetProtocol`, which provides methods
|
An object implementing the `TargetProtocol`, which provides methods
|
||||||
for filtering, transforming, and encoding sound events.
|
for filtering, transforming, and encoding sound events.
|
||||||
|
apply_filter : bool, optional
|
||||||
|
If True, sound events will be filtered using `targets.filter()`.
|
||||||
|
Only events for which `targets.filter()` returns True will be yielded.
|
||||||
|
Defaults to True.
|
||||||
|
apply_transform : bool, optional
|
||||||
|
If True, sound events will be transformed using `targets.transform()`
|
||||||
|
before being yielded. Defaults to True.
|
||||||
|
exclude_generic : bool, optional
|
||||||
|
If True, sound events that result in a `None` class name after
|
||||||
|
`targets.encode()` will be excluded. This is typically used to
|
||||||
|
filter out events that cannot be mapped to a specific target class.
|
||||||
|
Defaults to True.
|
||||||
|
|
||||||
Yields
|
Yields
|
||||||
------
|
------
|
||||||
@ -43,9 +63,17 @@ def iterate_over_sound_events(
|
|||||||
"""
|
"""
|
||||||
for clip_annotation in dataset:
|
for clip_annotation in dataset:
|
||||||
for sound_event_annotation in clip_annotation.sound_events:
|
for sound_event_annotation in clip_annotation.sound_events:
|
||||||
if not targets.filter(sound_event_annotation):
|
if apply_filter:
|
||||||
continue
|
if not targets.filter(sound_event_annotation):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if apply_transform:
|
||||||
|
sound_event_annotation = targets.transform(
|
||||||
|
sound_event_annotation
|
||||||
|
)
|
||||||
|
|
||||||
class_name = targets.encode_class(sound_event_annotation)
|
class_name = targets.encode_class(sound_event_annotation)
|
||||||
|
if class_name is None and exclude_generic:
|
||||||
|
continue
|
||||||
|
|
||||||
yield class_name, sound_event_annotation
|
yield class_name, sound_event_annotation
|
||||||
|
|||||||
@ -1,250 +0,0 @@
|
|||||||
from collections.abc import Callable
|
|
||||||
from typing import Annotated, Dict, List, Literal, Optional, Union
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
|
||||||
from batdetect2.data._core import Registry
|
|
||||||
from batdetect2.data.conditions import (
|
|
||||||
SoundEventCondition,
|
|
||||||
SoundEventConditionConfig,
|
|
||||||
build_sound_event_condition,
|
|
||||||
)
|
|
||||||
|
|
||||||
SoundEventTransform = Callable[
|
|
||||||
[data.SoundEventAnnotation],
|
|
||||||
data.SoundEventAnnotation,
|
|
||||||
]
|
|
||||||
|
|
||||||
_transforms: Registry[SoundEventTransform] = Registry("transform")
|
|
||||||
|
|
||||||
|
|
||||||
class SetFrequencyBoundConfig(BaseConfig):
|
|
||||||
name: Literal["set_frequency"] = "set_frequency"
|
|
||||||
boundary: Literal["low", "high"] = "low"
|
|
||||||
hertz: float
|
|
||||||
|
|
||||||
|
|
||||||
@_transforms.register(SetFrequencyBoundConfig)
|
|
||||||
class SetFrequencyBound:
|
|
||||||
def __init__(self, hertz: float, boundary: Literal["low", "high"] = "low"):
|
|
||||||
self.hertz = hertz
|
|
||||||
self.boundary = boundary
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
) -> data.SoundEventAnnotation:
|
|
||||||
sound_event = sound_event_annotation.sound_event
|
|
||||||
geometry = sound_event.geometry
|
|
||||||
|
|
||||||
if geometry is None:
|
|
||||||
return sound_event_annotation
|
|
||||||
|
|
||||||
if not isinstance(geometry, data.BoundingBox):
|
|
||||||
return sound_event_annotation
|
|
||||||
|
|
||||||
start_time, low_freq, end_time, high_freq = geometry.coordinates
|
|
||||||
|
|
||||||
if self.boundary == "low":
|
|
||||||
low_freq = self.hertz
|
|
||||||
high_freq = max(high_freq, low_freq)
|
|
||||||
|
|
||||||
elif self.boundary == "high":
|
|
||||||
high_freq = self.hertz
|
|
||||||
low_freq = min(high_freq, low_freq)
|
|
||||||
|
|
||||||
geometry = data.BoundingBox(
|
|
||||||
coordinates=[start_time, low_freq, end_time, high_freq],
|
|
||||||
)
|
|
||||||
|
|
||||||
sound_event = sound_event.model_copy(update=dict(geometry=geometry))
|
|
||||||
return sound_event_annotation.model_copy(
|
|
||||||
update=dict(sound_event=sound_event)
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: SetFrequencyBoundConfig):
|
|
||||||
return cls(hertz=config.hertz, boundary=config.boundary)
|
|
||||||
|
|
||||||
|
|
||||||
class ApplyIfConfig(BaseConfig):
|
|
||||||
name: Literal["apply_if"] = "apply_if"
|
|
||||||
transform: "SoundEventTransformConfig"
|
|
||||||
condition: SoundEventConditionConfig
|
|
||||||
|
|
||||||
|
|
||||||
@_transforms.register(ApplyIfConfig)
|
|
||||||
class ApplyIf:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
condition: SoundEventCondition,
|
|
||||||
transform: SoundEventTransform,
|
|
||||||
):
|
|
||||||
self.condition = condition
|
|
||||||
self.transform = transform
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
) -> data.SoundEventAnnotation:
|
|
||||||
if not self.condition(sound_event_annotation):
|
|
||||||
return sound_event_annotation
|
|
||||||
|
|
||||||
return self.transform(sound_event_annotation)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: ApplyIfConfig):
|
|
||||||
transform = build_sound_event_transform(config.transform)
|
|
||||||
condition = build_sound_event_condition(config.condition)
|
|
||||||
return cls(condition=condition, transform=transform)
|
|
||||||
|
|
||||||
|
|
||||||
class ReplaceTagConfig(BaseConfig):
|
|
||||||
name: Literal["replace_tag"] = "replace_tag"
|
|
||||||
original: data.Tag
|
|
||||||
replacement: data.Tag
|
|
||||||
|
|
||||||
|
|
||||||
@_transforms.register(ReplaceTagConfig)
|
|
||||||
class ReplaceTag:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
original: data.Tag,
|
|
||||||
replacement: data.Tag,
|
|
||||||
):
|
|
||||||
self.original = original
|
|
||||||
self.replacement = replacement
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
) -> data.SoundEventAnnotation:
|
|
||||||
tags = []
|
|
||||||
|
|
||||||
for tag in sound_event_annotation.tags:
|
|
||||||
if tag == self.original:
|
|
||||||
tags.append(self.replacement)
|
|
||||||
else:
|
|
||||||
tags.append(tag)
|
|
||||||
|
|
||||||
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: ReplaceTagConfig):
|
|
||||||
return cls(original=config.original, replacement=config.replacement)
|
|
||||||
|
|
||||||
|
|
||||||
class MapTagValueConfig(BaseConfig):
|
|
||||||
name: Literal["map_tag_value"] = "map_tag_value"
|
|
||||||
tag_key: str
|
|
||||||
value_mapping: Dict[str, str]
|
|
||||||
target_key: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
@_transforms.register(MapTagValueConfig)
|
|
||||||
class MapTagValue:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
tag_key: str,
|
|
||||||
value_mapping: Dict[str, str],
|
|
||||||
target_key: Optional[str] = None,
|
|
||||||
):
|
|
||||||
self.tag_key = tag_key
|
|
||||||
self.value_mapping = value_mapping
|
|
||||||
self.target_key = target_key
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
) -> data.SoundEventAnnotation:
|
|
||||||
tags = []
|
|
||||||
|
|
||||||
for tag in sound_event_annotation.tags:
|
|
||||||
if tag.key != self.tag_key:
|
|
||||||
tags.append(tag)
|
|
||||||
continue
|
|
||||||
|
|
||||||
value = self.value_mapping.get(tag.value)
|
|
||||||
|
|
||||||
if value is None:
|
|
||||||
tags.append(tag)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if self.target_key is None:
|
|
||||||
tags.append(tag.model_copy(update=dict(value=value)))
|
|
||||||
else:
|
|
||||||
tags.append(
|
|
||||||
data.Tag(
|
|
||||||
key=self.target_key, # type: ignore
|
|
||||||
value=value,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: MapTagValueConfig):
|
|
||||||
return cls(
|
|
||||||
tag_key=config.tag_key,
|
|
||||||
value_mapping=config.value_mapping,
|
|
||||||
target_key=config.target_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ApplyAllConfig(BaseConfig):
|
|
||||||
name: Literal["apply_all"] = "apply_all"
|
|
||||||
steps: List["SoundEventTransformConfig"] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
@_transforms.register(ApplyAllConfig)
|
|
||||||
class ApplyAll:
|
|
||||||
def __init__(self, steps: List[SoundEventTransform]):
|
|
||||||
self.steps = steps
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
) -> data.SoundEventAnnotation:
|
|
||||||
for step in self.steps:
|
|
||||||
sound_event_annotation = step(sound_event_annotation)
|
|
||||||
|
|
||||||
return sound_event_annotation
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: ApplyAllConfig):
|
|
||||||
steps = [build_sound_event_transform(step) for step in config.steps]
|
|
||||||
return cls(steps)
|
|
||||||
|
|
||||||
|
|
||||||
SoundEventTransformConfig = Annotated[
|
|
||||||
Union[
|
|
||||||
SetFrequencyBoundConfig,
|
|
||||||
ReplaceTagConfig,
|
|
||||||
MapTagValueConfig,
|
|
||||||
ApplyIfConfig,
|
|
||||||
ApplyAllConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def build_sound_event_transform(
|
|
||||||
config: SoundEventTransformConfig,
|
|
||||||
) -> SoundEventTransform:
|
|
||||||
return _transforms.build(config)
|
|
||||||
|
|
||||||
|
|
||||||
def transform_clip_annotation(
|
|
||||||
clip_annotation: data.ClipAnnotation,
|
|
||||||
transform: SoundEventTransform,
|
|
||||||
) -> data.ClipAnnotation:
|
|
||||||
return clip_annotation.model_copy(
|
|
||||||
update=dict(
|
|
||||||
sound_events=[
|
|
||||||
transform(sound_event)
|
|
||||||
for sound_event in clip_annotation.sound_events
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@ -1,62 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
from soundevent.geometry import compute_bounds
|
|
||||||
|
|
||||||
from batdetect2.typing.evaluate import MatchEvaluation
|
|
||||||
|
|
||||||
|
|
||||||
def extract_matches_dataframe(matches: List[MatchEvaluation]) -> pd.DataFrame:
|
|
||||||
data = []
|
|
||||||
|
|
||||||
for match in matches:
|
|
||||||
gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
|
|
||||||
pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None
|
|
||||||
|
|
||||||
sound_event_annotation = match.sound_event_annotation
|
|
||||||
|
|
||||||
if sound_event_annotation is not None:
|
|
||||||
geometry = sound_event_annotation.sound_event.geometry
|
|
||||||
assert geometry is not None
|
|
||||||
gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
|
|
||||||
compute_bounds(geometry)
|
|
||||||
)
|
|
||||||
|
|
||||||
if match.pred_geometry is not None:
|
|
||||||
pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
|
|
||||||
compute_bounds(match.pred_geometry)
|
|
||||||
)
|
|
||||||
|
|
||||||
data.append(
|
|
||||||
{
|
|
||||||
("recording", "uuid"): match.clip.recording.uuid,
|
|
||||||
("clip", "uuid"): match.clip.uuid,
|
|
||||||
("clip", "start_time"): match.clip.start_time,
|
|
||||||
("clip", "end_time"): match.clip.end_time,
|
|
||||||
("gt", "uuid"): match.sound_event_annotation.uuid
|
|
||||||
if match.sound_event_annotation is not None
|
|
||||||
else None,
|
|
||||||
("gt", "class"): match.gt_class,
|
|
||||||
("gt", "det"): match.gt_det,
|
|
||||||
("gt", "start_time"): gt_start_time,
|
|
||||||
("gt", "end_time"): gt_end_time,
|
|
||||||
("gt", "low_freq"): gt_low_freq,
|
|
||||||
("gt", "high_freq"): gt_high_freq,
|
|
||||||
("pred", "score"): match.pred_score,
|
|
||||||
("pred", "class"): match.pred_class,
|
|
||||||
("pred", "class_score"): match.pred_class_score,
|
|
||||||
("pred", "start_time"): pred_start_time,
|
|
||||||
("pred", "end_time"): pred_end_time,
|
|
||||||
("pred", "low_freq"): pred_low_freq,
|
|
||||||
("pred", "high_freq"): pred_high_freq,
|
|
||||||
("match", "affinity"): match.affinity,
|
|
||||||
**{
|
|
||||||
("pred_class_score", key): value
|
|
||||||
for key, value in match.pred_class_scores.items()
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
df = pd.DataFrame(data)
|
|
||||||
df.columns = pd.MultiIndex.from_tuples(df.columns) # type: ignore
|
|
||||||
return df
|
|
||||||
@ -1,100 +0,0 @@
|
|||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.evaluate.dataframe import extract_matches_dataframe
|
|
||||||
from batdetect2.evaluate.match import match_all_predictions
|
|
||||||
from batdetect2.evaluate.metrics import (
|
|
||||||
ClassificationAccuracy,
|
|
||||||
ClassificationMeanAveragePrecision,
|
|
||||||
DetectionAveragePrecision,
|
|
||||||
)
|
|
||||||
from batdetect2.models import Model
|
|
||||||
from batdetect2.plotting.clips import build_audio_loader
|
|
||||||
from batdetect2.postprocess import get_raw_predictions
|
|
||||||
from batdetect2.preprocess import build_preprocessor
|
|
||||||
from batdetect2.targets import build_targets
|
|
||||||
from batdetect2.train.config import FullTrainingConfig
|
|
||||||
from batdetect2.train.dataset import ValidationDataset
|
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
|
||||||
from batdetect2.train.train import build_val_loader
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate(
|
|
||||||
model: Model,
|
|
||||||
test_annotations: List[data.ClipAnnotation],
|
|
||||||
config: Optional[FullTrainingConfig] = None,
|
|
||||||
num_workers: Optional[int] = None,
|
|
||||||
) -> Tuple[pd.DataFrame, dict]:
|
|
||||||
config = config or FullTrainingConfig()
|
|
||||||
|
|
||||||
audio_loader = build_audio_loader(config.preprocess.audio)
|
|
||||||
|
|
||||||
preprocessor = build_preprocessor(config.preprocess)
|
|
||||||
|
|
||||||
targets = build_targets(config.targets)
|
|
||||||
|
|
||||||
labeller = build_clip_labeler(
|
|
||||||
targets,
|
|
||||||
min_freq=preprocessor.min_freq,
|
|
||||||
max_freq=preprocessor.max_freq,
|
|
||||||
config=config.train.labels,
|
|
||||||
)
|
|
||||||
|
|
||||||
loader = build_val_loader(
|
|
||||||
test_annotations,
|
|
||||||
audio_loader=audio_loader,
|
|
||||||
labeller=labeller,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
config=config.train,
|
|
||||||
num_workers=num_workers,
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset: ValidationDataset = loader.dataset # type: ignore
|
|
||||||
|
|
||||||
clip_annotations = []
|
|
||||||
predictions = []
|
|
||||||
|
|
||||||
for batch in loader:
|
|
||||||
outputs = model.detector(batch.spec)
|
|
||||||
|
|
||||||
clip_annotations = [
|
|
||||||
dataset.clip_annotations[int(example_idx)]
|
|
||||||
for example_idx in batch.idx
|
|
||||||
]
|
|
||||||
|
|
||||||
predictions = get_raw_predictions(
|
|
||||||
outputs,
|
|
||||||
clips=[
|
|
||||||
clip_annotation.clip for clip_annotation in clip_annotations
|
|
||||||
],
|
|
||||||
targets=targets,
|
|
||||||
postprocessor=model.postprocessor,
|
|
||||||
)
|
|
||||||
|
|
||||||
clip_annotations.extend(clip_annotations)
|
|
||||||
predictions.extend(predictions)
|
|
||||||
|
|
||||||
matches = match_all_predictions(
|
|
||||||
clip_annotations,
|
|
||||||
predictions,
|
|
||||||
targets=targets,
|
|
||||||
config=config.evaluation.match,
|
|
||||||
)
|
|
||||||
|
|
||||||
df = extract_matches_dataframe(matches)
|
|
||||||
|
|
||||||
metrics = [
|
|
||||||
DetectionAveragePrecision(),
|
|
||||||
ClassificationMeanAveragePrecision(class_names=targets.class_names),
|
|
||||||
ClassificationAccuracy(class_names=targets.class_names),
|
|
||||||
]
|
|
||||||
|
|
||||||
results = {
|
|
||||||
name: value
|
|
||||||
for metric in metrics
|
|
||||||
for name, value in metric(matches).items()
|
|
||||||
}
|
|
||||||
|
|
||||||
return df, results
|
|
||||||
@ -1,5 +1,6 @@
|
|||||||
from collections.abc import Callable, Iterable, Mapping
|
from collections.abc import Callable, Iterable, Mapping
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from functools import partial
|
||||||
from typing import List, Literal, Optional, Protocol, Tuple
|
from typing import List, Literal, Optional, Protocol, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -8,6 +9,7 @@ from soundevent import data
|
|||||||
from soundevent.evaluation import compute_affinity
|
from soundevent.evaluation import compute_affinity
|
||||||
from soundevent.evaluation import match_geometries as optimal_match
|
from soundevent.evaluation import match_geometries as optimal_match
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
from torch.multiprocessing import Pool
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
@ -282,7 +284,7 @@ def match_sound_events_and_raw_predictions(
|
|||||||
config = config or MatchConfig()
|
config = config or MatchConfig()
|
||||||
|
|
||||||
target_sound_events = [
|
target_sound_events = [
|
||||||
sound_event_annotation
|
targets.transform(sound_event_annotation)
|
||||||
for sound_event_annotation in clip_annotation.sound_events
|
for sound_event_annotation in clip_annotation.sound_events
|
||||||
if targets.filter(sound_event_annotation)
|
if targets.filter(sound_event_annotation)
|
||||||
and sound_event_annotation.sound_event.geometry is not None
|
and sound_event_annotation.sound_event.geometry is not None
|
||||||
@ -428,19 +430,17 @@ def match_all_predictions(
|
|||||||
config: Optional[MatchConfig] = None,
|
config: Optional[MatchConfig] = None,
|
||||||
) -> List[MatchEvaluation]:
|
) -> List[MatchEvaluation]:
|
||||||
logger.info("Matching all annotations and predictions...")
|
logger.info("Matching all annotations and predictions...")
|
||||||
return [
|
with Pool() as p:
|
||||||
match
|
all_matches = p.starmap(
|
||||||
for clip_annotation, raw_predictions in zip(
|
partial(
|
||||||
clip_annotations,
|
match_sound_events_and_raw_predictions,
|
||||||
predictions,
|
targets=targets,
|
||||||
|
config=config,
|
||||||
|
),
|
||||||
|
zip(clip_annotations, predictions),
|
||||||
)
|
)
|
||||||
for match in match_sound_events_and_raw_predictions(
|
|
||||||
clip_annotation,
|
return [match for matches in all_matches for match in matches]
|
||||||
raw_predictions,
|
|
||||||
targets=targets,
|
|
||||||
config=config,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@ -29,6 +29,7 @@ provided here.
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from lightning import LightningModule
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
@ -67,10 +68,7 @@ from batdetect2.postprocess import PostprocessConfig, build_postprocessor
|
|||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.targets import TargetConfig, build_targets
|
from batdetect2.targets import TargetConfig, build_targets
|
||||||
from batdetect2.typing.models import DetectionModel
|
from batdetect2.typing.models import DetectionModel
|
||||||
from batdetect2.typing.postprocess import (
|
from batdetect2.typing.postprocess import DetectionsArray, PostprocessorProtocol
|
||||||
DetectionsTensor,
|
|
||||||
PostprocessorProtocol,
|
|
||||||
)
|
|
||||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
@ -104,16 +102,7 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseConfig):
|
class Model(LightningModule):
|
||||||
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
|
||||||
preprocess: PreprocessingConfig = Field(
|
|
||||||
default_factory=PreprocessingConfig
|
|
||||||
)
|
|
||||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
|
||||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
|
||||||
|
|
||||||
|
|
||||||
class Model(torch.nn.Module):
|
|
||||||
detector: DetectionModel
|
detector: DetectionModel
|
||||||
preprocessor: PreprocessorProtocol
|
preprocessor: PreprocessorProtocol
|
||||||
postprocessor: PostprocessorProtocol
|
postprocessor: PostprocessorProtocol
|
||||||
@ -125,39 +114,43 @@ class Model(torch.nn.Module):
|
|||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
postprocessor: PostprocessorProtocol,
|
postprocessor: PostprocessorProtocol,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
config: ModelConfig,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.detector = detector
|
self.detector = detector
|
||||||
self.preprocessor = preprocessor
|
self.preprocessor = preprocessor
|
||||||
self.postprocessor = postprocessor
|
self.postprocessor = postprocessor
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
self.config = config
|
self.save_hyperparameters()
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
|
def forward(self, wav: torch.Tensor) -> List[DetectionsArray]:
|
||||||
spec = self.preprocessor(wav)
|
spec = self.preprocessor(wav)
|
||||||
outputs = self.detector(spec)
|
outputs = self.detector(spec)
|
||||||
return self.postprocessor(outputs)
|
return self.postprocessor(outputs)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelConfig(BaseConfig):
|
||||||
|
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
||||||
|
preprocess: PreprocessingConfig = Field(
|
||||||
|
default_factory=PreprocessingConfig
|
||||||
|
)
|
||||||
|
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||||
|
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
|
|
||||||
|
|
||||||
def build_model(config: Optional[ModelConfig] = None):
|
def build_model(config: Optional[ModelConfig] = None):
|
||||||
config = config or ModelConfig()
|
config = config or ModelConfig()
|
||||||
|
|
||||||
targets = build_targets(config=config.targets)
|
targets = build_targets(config=config.targets)
|
||||||
|
|
||||||
preprocessor = build_preprocessor(config=config.preprocess)
|
preprocessor = build_preprocessor(config=config.preprocess)
|
||||||
|
|
||||||
postprocessor = build_postprocessor(
|
postprocessor = build_postprocessor(
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
config=config.postprocess,
|
config=config.postprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
detector = build_detector(
|
detector = build_detector(
|
||||||
num_classes=len(targets.class_names),
|
num_classes=len(targets.class_names),
|
||||||
config=config.model,
|
config=config.model,
|
||||||
)
|
)
|
||||||
return Model(
|
return Model(
|
||||||
config=config,
|
|
||||||
detector=detector,
|
detector=detector,
|
||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
|
|||||||
@ -56,7 +56,7 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class SelfAttentionConfig(BaseConfig):
|
class SelfAttentionConfig(BaseConfig):
|
||||||
name: Literal["SelfAttention"] = "SelfAttention"
|
block_type: Literal["SelfAttention"] = "SelfAttention"
|
||||||
attention_channels: int
|
attention_channels: int
|
||||||
temperature: float = 1
|
temperature: float = 1
|
||||||
|
|
||||||
@ -178,7 +178,7 @@ class SelfAttention(nn.Module):
|
|||||||
class ConvConfig(BaseConfig):
|
class ConvConfig(BaseConfig):
|
||||||
"""Configuration for a basic ConvBlock."""
|
"""Configuration for a basic ConvBlock."""
|
||||||
|
|
||||||
name: Literal["ConvBlock"] = "ConvBlock"
|
block_type: Literal["ConvBlock"] = "ConvBlock"
|
||||||
"""Discriminator field indicating the block type."""
|
"""Discriminator field indicating the block type."""
|
||||||
|
|
||||||
out_channels: int
|
out_channels: int
|
||||||
@ -300,7 +300,7 @@ class VerticalConv(nn.Module):
|
|||||||
class FreqCoordConvDownConfig(BaseConfig):
|
class FreqCoordConvDownConfig(BaseConfig):
|
||||||
"""Configuration for a FreqCoordConvDownBlock."""
|
"""Configuration for a FreqCoordConvDownBlock."""
|
||||||
|
|
||||||
name: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
|
block_type: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
|
||||||
"""Discriminator field indicating the block type."""
|
"""Discriminator field indicating the block type."""
|
||||||
|
|
||||||
out_channels: int
|
out_channels: int
|
||||||
@ -390,7 +390,7 @@ class FreqCoordConvDownBlock(nn.Module):
|
|||||||
class StandardConvDownConfig(BaseConfig):
|
class StandardConvDownConfig(BaseConfig):
|
||||||
"""Configuration for a StandardConvDownBlock."""
|
"""Configuration for a StandardConvDownBlock."""
|
||||||
|
|
||||||
name: Literal["StandardConvDown"] = "StandardConvDown"
|
block_type: Literal["StandardConvDown"] = "StandardConvDown"
|
||||||
"""Discriminator field indicating the block type."""
|
"""Discriminator field indicating the block type."""
|
||||||
|
|
||||||
out_channels: int
|
out_channels: int
|
||||||
@ -460,7 +460,7 @@ class StandardConvDownBlock(nn.Module):
|
|||||||
class FreqCoordConvUpConfig(BaseConfig):
|
class FreqCoordConvUpConfig(BaseConfig):
|
||||||
"""Configuration for a FreqCoordConvUpBlock."""
|
"""Configuration for a FreqCoordConvUpBlock."""
|
||||||
|
|
||||||
name: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
|
block_type: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
|
||||||
"""Discriminator field indicating the block type."""
|
"""Discriminator field indicating the block type."""
|
||||||
|
|
||||||
out_channels: int
|
out_channels: int
|
||||||
@ -569,7 +569,7 @@ class FreqCoordConvUpBlock(nn.Module):
|
|||||||
class StandardConvUpConfig(BaseConfig):
|
class StandardConvUpConfig(BaseConfig):
|
||||||
"""Configuration for a StandardConvUpBlock."""
|
"""Configuration for a StandardConvUpBlock."""
|
||||||
|
|
||||||
name: Literal["StandardConvUp"] = "StandardConvUp"
|
block_type: Literal["StandardConvUp"] = "StandardConvUp"
|
||||||
"""Discriminator field indicating the block type."""
|
"""Discriminator field indicating the block type."""
|
||||||
|
|
||||||
out_channels: int
|
out_channels: int
|
||||||
@ -664,13 +664,13 @@ LayerConfig = Annotated[
|
|||||||
SelfAttentionConfig,
|
SelfAttentionConfig,
|
||||||
"LayerGroupConfig",
|
"LayerGroupConfig",
|
||||||
],
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="block_type"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configuration models."""
|
"""Type alias for the discriminated union of block configuration models."""
|
||||||
|
|
||||||
|
|
||||||
class LayerGroupConfig(BaseConfig):
|
class LayerGroupConfig(BaseConfig):
|
||||||
name: Literal["LayerGroup"] = "LayerGroup"
|
block_type: Literal["LayerGroup"] = "LayerGroup"
|
||||||
layers: List[LayerConfig]
|
layers: List[LayerConfig]
|
||||||
|
|
||||||
|
|
||||||
@ -686,7 +686,7 @@ def build_layer_from_config(
|
|||||||
parameters derived from the config and the current pipeline state
|
parameters derived from the config and the current pipeline state
|
||||||
(`input_height`, `in_channels`).
|
(`input_height`, `in_channels`).
|
||||||
|
|
||||||
It uses the `name` field within the `config` object to determine
|
It uses the `block_type` field within the `config` object to determine
|
||||||
which block class to instantiate.
|
which block class to instantiate.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -698,7 +698,7 @@ def build_layer_from_config(
|
|||||||
config : LayerConfig
|
config : LayerConfig
|
||||||
A Pydantic configuration object for the desired block (e.g., an
|
A Pydantic configuration object for the desired block (e.g., an
|
||||||
instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
|
instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
|
||||||
by its `name` field.
|
by its `block_type` field.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -711,11 +711,11 @@ def build_layer_from_config(
|
|||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
NotImplementedError
|
NotImplementedError
|
||||||
If the `config.name` does not correspond to a known block type.
|
If the `config.block_type` does not correspond to a known block type.
|
||||||
ValueError
|
ValueError
|
||||||
If parameters derived from the config are invalid for the block.
|
If parameters derived from the config are invalid for the block.
|
||||||
"""
|
"""
|
||||||
if config.name == "ConvBlock":
|
if config.block_type == "ConvBlock":
|
||||||
return (
|
return (
|
||||||
ConvBlock(
|
ConvBlock(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
@ -727,7 +727,7 @@ def build_layer_from_config(
|
|||||||
input_height,
|
input_height,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.name == "FreqCoordConvDown":
|
if config.block_type == "FreqCoordConvDown":
|
||||||
return (
|
return (
|
||||||
FreqCoordConvDownBlock(
|
FreqCoordConvDownBlock(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
@ -740,7 +740,7 @@ def build_layer_from_config(
|
|||||||
input_height // 2,
|
input_height // 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.name == "StandardConvDown":
|
if config.block_type == "StandardConvDown":
|
||||||
return (
|
return (
|
||||||
StandardConvDownBlock(
|
StandardConvDownBlock(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
@ -752,7 +752,7 @@ def build_layer_from_config(
|
|||||||
input_height // 2,
|
input_height // 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.name == "FreqCoordConvUp":
|
if config.block_type == "FreqCoordConvUp":
|
||||||
return (
|
return (
|
||||||
FreqCoordConvUpBlock(
|
FreqCoordConvUpBlock(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
@ -765,7 +765,7 @@ def build_layer_from_config(
|
|||||||
input_height * 2,
|
input_height * 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.name == "StandardConvUp":
|
if config.block_type == "StandardConvUp":
|
||||||
return (
|
return (
|
||||||
StandardConvUpBlock(
|
StandardConvUpBlock(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
@ -777,7 +777,7 @@ def build_layer_from_config(
|
|||||||
input_height * 2,
|
input_height * 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.name == "SelfAttention":
|
if config.block_type == "SelfAttention":
|
||||||
return (
|
return (
|
||||||
SelfAttention(
|
SelfAttention(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
@ -788,7 +788,7 @@ def build_layer_from_config(
|
|||||||
input_height,
|
input_height,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.name == "LayerGroup":
|
if config.block_type == "LayerGroup":
|
||||||
current_channels = in_channels
|
current_channels = in_channels
|
||||||
current_height = input_height
|
current_height = input_height
|
||||||
|
|
||||||
@ -804,4 +804,4 @@ def build_layer_from_config(
|
|||||||
|
|
||||||
return nn.Sequential(*blocks), current_channels, current_height
|
return nn.Sequential(*blocks), current_channels, current_height
|
||||||
|
|
||||||
raise NotImplementedError(f"Unknown block type {config.name}")
|
raise NotImplementedError(f"Unknown block type {config.block_type}")
|
||||||
|
|||||||
@ -128,7 +128,7 @@ class Bottleneck(nn.Module):
|
|||||||
|
|
||||||
BottleneckLayerConfig = Annotated[
|
BottleneckLayerConfig = Annotated[
|
||||||
Union[SelfAttentionConfig,],
|
Union[SelfAttentionConfig,],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="block_type"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
||||||
|
|
||||||
|
|||||||
@ -47,7 +47,7 @@ DecoderLayerConfig = Annotated[
|
|||||||
StandardConvUpConfig,
|
StandardConvUpConfig,
|
||||||
LayerGroupConfig,
|
LayerGroupConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="block_type"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
||||||
|
|
||||||
@ -63,7 +63,7 @@ class DecoderConfig(BaseConfig):
|
|||||||
layers : List[DecoderLayerConfig]
|
layers : List[DecoderLayerConfig]
|
||||||
An ordered list of configuration objects, each defining one layer or
|
An ordered list of configuration objects, each defining one layer or
|
||||||
block in the decoder sequence. Each item must be a valid block
|
block in the decoder sequence. Each item must be a valid block
|
||||||
config including a `name` field and necessary parameters like
|
config including a `block_type` field and necessary parameters like
|
||||||
`out_channels`. Input channels for each layer are inferred sequentially.
|
`out_channels`. Input channels for each layer are inferred sequentially.
|
||||||
The list must contain at least one layer.
|
The list must contain at least one layer.
|
||||||
"""
|
"""
|
||||||
@ -249,9 +249,9 @@ def build_decoder(
|
|||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
If `in_channels` or `input_height` are not positive, or if the layer
|
If `in_channels` or `input_height` are not positive, or if the layer
|
||||||
configuration is invalid (e.g., empty list, unknown `name`).
|
configuration is invalid (e.g., empty list, unknown `block_type`).
|
||||||
NotImplementedError
|
NotImplementedError
|
||||||
If `build_layer_from_config` encounters an unknown `name`.
|
If `build_layer_from_config` encounters an unknown `block_type`.
|
||||||
"""
|
"""
|
||||||
config = config or DEFAULT_DECODER_CONFIG
|
config = config or DEFAULT_DECODER_CONFIG
|
||||||
|
|
||||||
|
|||||||
@ -49,7 +49,7 @@ EncoderLayerConfig = Annotated[
|
|||||||
StandardConvDownConfig,
|
StandardConvDownConfig,
|
||||||
LayerGroupConfig,
|
LayerGroupConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="block_type"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configs usable in Encoder."""
|
"""Type alias for the discriminated union of block configs usable in Encoder."""
|
||||||
|
|
||||||
@ -66,7 +66,7 @@ class EncoderConfig(BaseConfig):
|
|||||||
An ordered list of configuration objects, each defining one layer or
|
An ordered list of configuration objects, each defining one layer or
|
||||||
block in the encoder sequence. Each item must be a valid block config
|
block in the encoder sequence. Each item must be a valid block config
|
||||||
(e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
|
(e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
|
||||||
`StandardConvDownConfig`) including a `name` field and necessary
|
`StandardConvDownConfig`) including a `block_type` field and necessary
|
||||||
parameters like `out_channels`. Input channels for each layer are
|
parameters like `out_channels`. Input channels for each layer are
|
||||||
inferred sequentially. The list must contain at least one layer.
|
inferred sequentially. The list must contain at least one layer.
|
||||||
"""
|
"""
|
||||||
@ -287,9 +287,9 @@ def build_encoder(
|
|||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
If `in_channels` or `input_height` are not positive, or if the layer
|
If `in_channels` or `input_height` are not positive, or if the layer
|
||||||
configuration is invalid (e.g., empty list, unknown `name`).
|
configuration is invalid (e.g., empty list, unknown `block_type`).
|
||||||
NotImplementedError
|
NotImplementedError
|
||||||
If `build_layer_from_config` encounters an unknown `name`.
|
If `build_layer_from_config` encounters an unknown `block_type`.
|
||||||
"""
|
"""
|
||||||
if in_channels <= 0 or input_height <= 0:
|
if in_channels <= 0 or input_height <= 0:
|
||||||
raise ValueError("in_channels and input_height must be positive.")
|
raise ValueError("in_channels and input_height must be positive.")
|
||||||
|
|||||||
@ -1,14 +1,6 @@
|
|||||||
"""Computes spectrograms from audio waveforms with configurable parameters."""
|
"""Computes spectrograms from audio waveforms with configurable parameters."""
|
||||||
|
|
||||||
from typing import (
|
from typing import Annotated, Callable, List, Literal, Optional, Union
|
||||||
Annotated,
|
|
||||||
Callable,
|
|
||||||
List,
|
|
||||||
Literal,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -314,7 +306,7 @@ class SpectrogramConfig(BaseConfig):
|
|||||||
stft: STFTConfig = Field(default_factory=STFTConfig)
|
stft: STFTConfig = Field(default_factory=STFTConfig)
|
||||||
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
|
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
|
||||||
size: ResizeConfig = Field(default_factory=ResizeConfig)
|
size: ResizeConfig = Field(default_factory=ResizeConfig)
|
||||||
transforms: Sequence[SpectrogramTransform] = Field(
|
transforms: List[SpectrogramTransform] = Field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
PcenConfig(),
|
PcenConfig(),
|
||||||
SpectralMeanSubstractionConfig(),
|
SpectralMeanSubstractionConfig(),
|
||||||
|
|||||||
@ -1,26 +1,53 @@
|
|||||||
"""BatDetect2 Target Definition system."""
|
"""Main entry point for the BatDetect2 Target Definition subsystem.
|
||||||
|
|
||||||
|
This package (`batdetect2.targets`) provides the tools and configurations
|
||||||
|
necessary to define precisely what the BatDetect2 model should learn to detect,
|
||||||
|
classify, and localize from audio data. It involves several conceptual steps,
|
||||||
|
managed through configuration files and culminating in an executable pipeline:
|
||||||
|
|
||||||
|
1. **Terms (`.terms`)**: Defining vocabulary for annotation tags.
|
||||||
|
2. **Filtering (`.filtering`)**: Selecting relevant sound event annotations.
|
||||||
|
3. **Transformation (`.transform`)**: Modifying tags (standardization,
|
||||||
|
derivation).
|
||||||
|
4. **ROI Mapping (`.roi`)**: Defining how annotation geometry (ROIs) maps to
|
||||||
|
target position and size representations, and back.
|
||||||
|
5. **Class Definition (`.classes`)**: Mapping tags to target class names
|
||||||
|
(encoding) and mapping predicted names back to tags (decoding).
|
||||||
|
|
||||||
|
This module exposes the key components for users to configure and utilize this
|
||||||
|
target definition pipeline, primarily through the `TargetConfig` data structure
|
||||||
|
and the `Targets` class (implementing `TargetProtocol`), which encapsulates the
|
||||||
|
configured processing steps. The main way to create a functional `Targets`
|
||||||
|
object is via the `build_targets` or `load_targets` functions.
|
||||||
|
"""
|
||||||
|
|
||||||
from collections import Counter
|
|
||||||
from typing import Iterable, List, Optional, Tuple
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import Field, field_validator
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.data.conditions import (
|
|
||||||
SoundEventCondition,
|
|
||||||
build_sound_event_condition,
|
|
||||||
)
|
|
||||||
from batdetect2.targets.classes import (
|
from batdetect2.targets.classes import (
|
||||||
DEFAULT_CLASSES,
|
ClassesConfig,
|
||||||
DEFAULT_GENERIC_CLASS,
|
|
||||||
SoundEventDecoder,
|
SoundEventDecoder,
|
||||||
SoundEventEncoder,
|
SoundEventEncoder,
|
||||||
TargetClassConfig,
|
TargetClass,
|
||||||
|
build_generic_class_tags,
|
||||||
build_sound_event_decoder,
|
build_sound_event_decoder,
|
||||||
build_sound_event_encoder,
|
build_sound_event_encoder,
|
||||||
get_class_names_from_config,
|
get_class_names_from_config,
|
||||||
|
load_classes_config,
|
||||||
|
load_decoder_from_config,
|
||||||
|
load_encoder_from_config,
|
||||||
|
)
|
||||||
|
from batdetect2.targets.filtering import (
|
||||||
|
FilterConfig,
|
||||||
|
FilterRule,
|
||||||
|
SoundEventFilter,
|
||||||
|
build_sound_event_filter,
|
||||||
|
load_filter_config,
|
||||||
|
load_filter_from_config,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.rois import (
|
from batdetect2.targets.rois import (
|
||||||
AnchorBBoxMapperConfig,
|
AnchorBBoxMapperConfig,
|
||||||
@ -28,53 +55,114 @@ from batdetect2.targets.rois import (
|
|||||||
ROITargetMapper,
|
ROITargetMapper,
|
||||||
build_roi_mapper,
|
build_roi_mapper,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.terms import call_type, individual
|
from batdetect2.targets.terms import (
|
||||||
|
TagInfo,
|
||||||
|
TermInfo,
|
||||||
|
TermRegistry,
|
||||||
|
call_type,
|
||||||
|
default_term_registry,
|
||||||
|
get_tag_from_info,
|
||||||
|
get_term_from_key,
|
||||||
|
individual,
|
||||||
|
register_term,
|
||||||
|
)
|
||||||
|
from batdetect2.targets.transform import (
|
||||||
|
DerivationRegistry,
|
||||||
|
DeriveTagRule,
|
||||||
|
MapValueRule,
|
||||||
|
ReplaceRule,
|
||||||
|
SoundEventTransformation,
|
||||||
|
TransformConfig,
|
||||||
|
build_transformation_from_config,
|
||||||
|
default_derivation_registry,
|
||||||
|
get_derivation,
|
||||||
|
load_transformation_config,
|
||||||
|
load_transformation_from_config,
|
||||||
|
register_derivation,
|
||||||
|
)
|
||||||
from batdetect2.typing.targets import Position, Size, TargetProtocol
|
from batdetect2.typing.targets import Position, Size, TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"ClassesConfig",
|
||||||
"DEFAULT_TARGET_CONFIG",
|
"DEFAULT_TARGET_CONFIG",
|
||||||
|
"DeriveTagRule",
|
||||||
|
"FilterConfig",
|
||||||
|
"FilterRule",
|
||||||
|
"MapValueRule",
|
||||||
"AnchorBBoxMapperConfig",
|
"AnchorBBoxMapperConfig",
|
||||||
"ROITargetMapper",
|
"ROITargetMapper",
|
||||||
|
"ReplaceRule",
|
||||||
"SoundEventDecoder",
|
"SoundEventDecoder",
|
||||||
"SoundEventEncoder",
|
"SoundEventEncoder",
|
||||||
"TargetClassConfig",
|
"SoundEventFilter",
|
||||||
|
"SoundEventTransformation",
|
||||||
|
"TagInfo",
|
||||||
|
"TargetClass",
|
||||||
"TargetConfig",
|
"TargetConfig",
|
||||||
"Targets",
|
"Targets",
|
||||||
|
"TermInfo",
|
||||||
|
"TransformConfig",
|
||||||
|
"build_generic_class_tags",
|
||||||
"build_roi_mapper",
|
"build_roi_mapper",
|
||||||
"build_sound_event_decoder",
|
"build_sound_event_decoder",
|
||||||
"build_sound_event_encoder",
|
"build_sound_event_encoder",
|
||||||
|
"build_sound_event_filter",
|
||||||
|
"build_transformation_from_config",
|
||||||
"call_type",
|
"call_type",
|
||||||
"get_class_names_from_config",
|
"get_class_names_from_config",
|
||||||
|
"get_derivation",
|
||||||
|
"get_tag_from_info",
|
||||||
|
"get_term_from_key",
|
||||||
"individual",
|
"individual",
|
||||||
|
"load_classes_config",
|
||||||
|
"load_decoder_from_config",
|
||||||
|
"load_encoder_from_config",
|
||||||
|
"load_filter_config",
|
||||||
|
"load_filter_from_config",
|
||||||
"load_target_config",
|
"load_target_config",
|
||||||
|
"load_transformation_config",
|
||||||
|
"load_transformation_from_config",
|
||||||
|
"register_derivation",
|
||||||
|
"register_term",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class TargetConfig(BaseConfig):
|
class TargetConfig(BaseConfig):
|
||||||
detection_target: TargetClassConfig = Field(default=DEFAULT_GENERIC_CLASS)
|
"""Unified configuration for the entire target definition pipeline.
|
||||||
|
|
||||||
classification_targets: List[TargetClassConfig] = Field(
|
This model aggregates the configurations for semantic processing (filtering,
|
||||||
default_factory=lambda: DEFAULT_CLASSES
|
transformation, class definition) and geometric processing (ROI mapping).
|
||||||
|
It serves as the primary input for building a complete `Targets` object
|
||||||
|
via `build_targets` or `load_targets`.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
filtering : FilterConfig, optional
|
||||||
|
Configuration for filtering sound event annotations based on tags.
|
||||||
|
If None or omitted, no filtering is applied.
|
||||||
|
transforms : TransformConfig, optional
|
||||||
|
Configuration for transforming annotation tags
|
||||||
|
(mapping, derivation, etc.). If None or omitted, no tag transformations
|
||||||
|
are applied.
|
||||||
|
classes : ClassesConfig
|
||||||
|
Configuration defining the specific target classes, their tag matching
|
||||||
|
rules for encoding, their representative tags for decoding
|
||||||
|
(`output_tags`), and the definition of the generic class tags.
|
||||||
|
This section is mandatory.
|
||||||
|
roi : ROIConfig, optional
|
||||||
|
Configuration defining how geometric ROIs (e.g., bounding boxes) are
|
||||||
|
mapped to target representations (reference point, scaled size).
|
||||||
|
Controls `position`, `time_scale`, `frequency_scale`. If None or
|
||||||
|
omitted, default ROI mapping settings are used.
|
||||||
|
"""
|
||||||
|
|
||||||
|
filtering: FilterConfig = Field(default_factory=FilterConfig)
|
||||||
|
transforms: TransformConfig = Field(default_factory=TransformConfig)
|
||||||
|
classes: ClassesConfig = Field(
|
||||||
|
default_factory=lambda: DEFAULT_CLASSES_CONFIG
|
||||||
)
|
)
|
||||||
|
|
||||||
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
|
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
|
||||||
|
|
||||||
@field_validator("classification_targets")
|
|
||||||
def check_unique_class_names(cls, v: List[TargetClassConfig]):
|
|
||||||
"""Ensure all defined class names are unique."""
|
|
||||||
names = [c.name for c in v]
|
|
||||||
|
|
||||||
if len(names) != len(set(names)):
|
|
||||||
name_counts = Counter(names)
|
|
||||||
duplicates = [
|
|
||||||
name for name, count in name_counts.items() if count > 1
|
|
||||||
]
|
|
||||||
raise ValueError(
|
|
||||||
"Class names must be unique. Found duplicates: "
|
|
||||||
f"{', '.join(duplicates)}"
|
|
||||||
)
|
|
||||||
return v
|
|
||||||
|
|
||||||
|
|
||||||
def load_target_config(
|
def load_target_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
@ -150,7 +238,8 @@ class Targets(TargetProtocol):
|
|||||||
roi_mapper: ROITargetMapper,
|
roi_mapper: ROITargetMapper,
|
||||||
class_names: list[str],
|
class_names: list[str],
|
||||||
generic_class_tags: List[data.Tag],
|
generic_class_tags: List[data.Tag],
|
||||||
filter_fn: Optional[SoundEventCondition] = None,
|
filter_fn: Optional[SoundEventFilter] = None,
|
||||||
|
transform_fn: Optional[SoundEventTransformation] = None,
|
||||||
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
|
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
|
||||||
):
|
):
|
||||||
"""Initialize the Targets object.
|
"""Initialize the Targets object.
|
||||||
@ -183,6 +272,7 @@ class Targets(TargetProtocol):
|
|||||||
self._filter_fn = filter_fn
|
self._filter_fn = filter_fn
|
||||||
self._encode_fn = encode_fn
|
self._encode_fn = encode_fn
|
||||||
self._decode_fn = decode_fn
|
self._decode_fn = decode_fn
|
||||||
|
self._transform_fn = transform_fn
|
||||||
self._roi_mapper_overrides = roi_mapper_overrides or {}
|
self._roi_mapper_overrides = roi_mapper_overrides or {}
|
||||||
|
|
||||||
for class_name in self._roi_mapper_overrides:
|
for class_name in self._roi_mapper_overrides:
|
||||||
@ -254,6 +344,27 @@ class Targets(TargetProtocol):
|
|||||||
"""
|
"""
|
||||||
return self._decode_fn(class_label)
|
return self._decode_fn(class_label)
|
||||||
|
|
||||||
|
def transform(
|
||||||
|
self, sound_event: data.SoundEventAnnotation
|
||||||
|
) -> data.SoundEventAnnotation:
|
||||||
|
"""Apply the configured tag transformations to an annotation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event : data.SoundEventAnnotation
|
||||||
|
The annotation whose tags should be transformed.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
data.SoundEventAnnotation
|
||||||
|
A new annotation object with the transformed tags. If no
|
||||||
|
transformations were configured, the original annotation object is
|
||||||
|
returned.
|
||||||
|
"""
|
||||||
|
if self._transform_fn:
|
||||||
|
return self._transform_fn(sound_event)
|
||||||
|
return sound_event
|
||||||
|
|
||||||
def encode_roi(
|
def encode_roi(
|
||||||
self, sound_event: data.SoundEventAnnotation
|
self, sound_event: data.SoundEventAnnotation
|
||||||
) -> tuple[Position, Size]:
|
) -> tuple[Position, Size]:
|
||||||
@ -319,14 +430,113 @@ class Targets(TargetProtocol):
|
|||||||
return self._roi_mapper.decode(position, size)
|
return self._roi_mapper.decode(position, size)
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_CLASSES = [
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Myotis mystacinus")],
|
||||||
|
name="myomys",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Myotis alcathoe")],
|
||||||
|
name="myoalc",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Eptesicus serotinus")],
|
||||||
|
name="eptser",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Pipistrellus nathusii")],
|
||||||
|
name="pipnat",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Barbastellus barbastellus")],
|
||||||
|
name="barbar",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Myotis nattereri")],
|
||||||
|
name="myonat",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Myotis daubentonii")],
|
||||||
|
name="myodau",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Myotis brandtii")],
|
||||||
|
name="myobra",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Pipistrellus pipistrellus")],
|
||||||
|
name="pippip",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Myotis bechsteinii")],
|
||||||
|
name="myobec",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Pipistrellus pygmaeus")],
|
||||||
|
name="pippyg",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Rhinolophus hipposideros")],
|
||||||
|
name="rhihip",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Nyctalus leisleri")],
|
||||||
|
name="nyclei",
|
||||||
|
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
|
||||||
|
name="rhifer",
|
||||||
|
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Plecotus auritus")],
|
||||||
|
name="pleaur",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Nyctalus noctula")],
|
||||||
|
name="nycnoc",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Plecotus austriacus")],
|
||||||
|
name="pleaus",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_CLASSES_CONFIG: ClassesConfig = ClassesConfig(
|
||||||
|
classes=DEFAULT_CLASSES,
|
||||||
|
generic_class=[TagInfo(value="Bat")],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
||||||
classification_targets=DEFAULT_CLASSES,
|
filtering=FilterConfig(
|
||||||
detection_target=DEFAULT_GENERIC_CLASS,
|
rules=[
|
||||||
|
FilterRule(
|
||||||
|
match_type="all",
|
||||||
|
tags=[TagInfo(key="event", value="Echolocation")],
|
||||||
|
),
|
||||||
|
FilterRule(
|
||||||
|
match_type="exclude",
|
||||||
|
tags=[
|
||||||
|
TagInfo(key="event", value="Feeding"),
|
||||||
|
TagInfo(key="event", value="Unknown"),
|
||||||
|
TagInfo(key="event", value="Not Bat"),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
classes=DEFAULT_CLASSES_CONFIG,
|
||||||
roi=AnchorBBoxMapperConfig(),
|
roi=AnchorBBoxMapperConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_targets(config: Optional[TargetConfig] = None) -> Targets:
|
def build_targets(
|
||||||
|
config: Optional[TargetConfig] = None,
|
||||||
|
term_registry: TermRegistry = default_term_registry,
|
||||||
|
derivation_registry: DerivationRegistry = default_derivation_registry,
|
||||||
|
) -> Targets:
|
||||||
"""Build a Targets object from a loaded TargetConfig.
|
"""Build a Targets object from a loaded TargetConfig.
|
||||||
|
|
||||||
This factory function takes the unified `TargetConfig` and constructs all
|
This factory function takes the unified `TargetConfig` and constructs all
|
||||||
@ -340,6 +550,13 @@ def build_targets(config: Optional[TargetConfig] = None) -> Targets:
|
|||||||
----------
|
----------
|
||||||
config : TargetConfig
|
config : TargetConfig
|
||||||
The loaded and validated unified target configuration object.
|
The loaded and validated unified target configuration object.
|
||||||
|
term_registry : TermRegistry, optional
|
||||||
|
The TermRegistry instance to use for resolving term keys. Defaults
|
||||||
|
to the global `batdetect2.targets.terms.term_registry`.
|
||||||
|
derivation_registry : DerivationRegistry, optional
|
||||||
|
The DerivationRegistry instance to use for resolving derivation
|
||||||
|
function names. Defaults to the global
|
||||||
|
`batdetect2.targets.transform.derivation_registry`.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -360,18 +577,40 @@ def build_targets(config: Optional[TargetConfig] = None) -> Targets:
|
|||||||
lambda: config.to_yaml_string(),
|
lambda: config.to_yaml_string(),
|
||||||
)
|
)
|
||||||
|
|
||||||
filter_fn = build_sound_event_condition(config.detection_target.match_if)
|
filter_fn = (
|
||||||
encode_fn = build_sound_event_encoder(config.classification_targets)
|
build_sound_event_filter(
|
||||||
decode_fn = build_sound_event_decoder(config.classification_targets)
|
config.filtering,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
if config.filtering
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
encode_fn = build_sound_event_encoder(
|
||||||
|
config.classes,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
decode_fn = build_sound_event_decoder(
|
||||||
|
config.classes,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
transform_fn = (
|
||||||
|
build_transformation_from_config(
|
||||||
|
config.transforms,
|
||||||
|
term_registry=term_registry,
|
||||||
|
derivation_registry=derivation_registry,
|
||||||
|
)
|
||||||
|
if config.transforms
|
||||||
|
else None
|
||||||
|
)
|
||||||
roi_mapper = build_roi_mapper(config.roi)
|
roi_mapper = build_roi_mapper(config.roi)
|
||||||
class_names = get_class_names_from_config(config.classification_targets)
|
class_names = get_class_names_from_config(config.classes)
|
||||||
|
generic_class_tags = build_generic_class_tags(
|
||||||
generic_class_tags = config.detection_target.assign_tags
|
config.classes,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
roi_overrides = {
|
roi_overrides = {
|
||||||
class_config.name: build_roi_mapper(class_config.roi)
|
class_config.name: build_roi_mapper(class_config.roi)
|
||||||
for class_config in config.classification_targets
|
for class_config in config.classes.classes
|
||||||
if class_config.roi is not None
|
if class_config.roi is not None
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -382,6 +621,7 @@ def build_targets(config: Optional[TargetConfig] = None) -> Targets:
|
|||||||
class_names=class_names,
|
class_names=class_names,
|
||||||
roi_mapper=roi_mapper,
|
roi_mapper=roi_mapper,
|
||||||
generic_class_tags=generic_class_tags,
|
generic_class_tags=generic_class_tags,
|
||||||
|
transform_fn=transform_fn,
|
||||||
roi_mapper_overrides=roi_overrides,
|
roi_mapper_overrides=roi_overrides,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -389,6 +629,8 @@ def build_targets(config: Optional[TargetConfig] = None) -> Targets:
|
|||||||
def load_targets(
|
def load_targets(
|
||||||
config_path: data.PathLike,
|
config_path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
|
term_registry: TermRegistry = default_term_registry,
|
||||||
|
derivation_registry: DerivationRegistry = default_derivation_registry,
|
||||||
) -> Targets:
|
) -> Targets:
|
||||||
"""Load a Targets object directly from a configuration file.
|
"""Load a Targets object directly from a configuration file.
|
||||||
|
|
||||||
@ -403,6 +645,11 @@ def load_targets(
|
|||||||
field : str, optional
|
field : str, optional
|
||||||
Dot-separated path to a nested section within the file containing
|
Dot-separated path to a nested section within the file containing
|
||||||
the target configuration. If None, the entire file content is used.
|
the target configuration. If None, the entire file content is used.
|
||||||
|
term_registry : TermRegistry, optional
|
||||||
|
The TermRegistry instance to use. Defaults to the global default.
|
||||||
|
derivation_registry : DerivationRegistry, optional
|
||||||
|
The DerivationRegistry instance to use. Defaults to the global
|
||||||
|
default.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -423,7 +670,11 @@ def load_targets(
|
|||||||
config_path,
|
config_path,
|
||||||
field=field,
|
field=field,
|
||||||
)
|
)
|
||||||
return build_targets(config)
|
return build_targets(
|
||||||
|
config,
|
||||||
|
term_registry=term_registry,
|
||||||
|
derivation_registry=derivation_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def iterate_encoded_sound_events(
|
def iterate_encoded_sound_events(
|
||||||
@ -439,6 +690,8 @@ def iterate_encoded_sound_events(
|
|||||||
if geometry is None:
|
if geometry is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
sound_event = targets.transform(sound_event)
|
||||||
|
|
||||||
class_name = targets.encode_class(sound_event)
|
class_name = targets.encode_class(sound_event)
|
||||||
position, size = targets.encode_roi(sound_event)
|
position, size = targets.encode_roi(sound_event)
|
||||||
|
|
||||||
|
|||||||
@ -1,172 +1,253 @@
|
|||||||
from typing import Dict, List, Optional
|
from collections import Counter
|
||||||
|
from functools import partial
|
||||||
|
from typing import Callable, Dict, List, Literal, Optional, Set, Tuple
|
||||||
|
|
||||||
from pydantic import Field, PrivateAttr, computed_field, model_validator
|
from pydantic import Field, field_validator
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.data.conditions import (
|
from batdetect2.targets.rois import ROIMapperConfig
|
||||||
AllOfConfig,
|
from batdetect2.targets.terms import (
|
||||||
HasAllTagsConfig,
|
GENERIC_CLASS_KEY,
|
||||||
HasAnyTagConfig,
|
TagInfo,
|
||||||
HasTagConfig,
|
TermRegistry,
|
||||||
NotConfig,
|
default_term_registry,
|
||||||
SoundEventCondition,
|
get_tag_from_info,
|
||||||
SoundEventConditionConfig,
|
|
||||||
build_sound_event_condition,
|
|
||||||
)
|
)
|
||||||
from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig
|
|
||||||
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
|
from batdetect2.typing.targets import SoundEventDecoder, SoundEventEncoder
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"DEFAULT_SPECIES_LIST",
|
||||||
|
"build_generic_class_tags",
|
||||||
"build_sound_event_decoder",
|
"build_sound_event_decoder",
|
||||||
"build_sound_event_encoder",
|
"build_sound_event_encoder",
|
||||||
"get_class_names_from_config",
|
"get_class_names_from_config",
|
||||||
|
"load_classes_config",
|
||||||
|
"load_decoder_from_config",
|
||||||
|
"load_encoder_from_config",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class TargetClassConfig(BaseConfig):
|
DEFAULT_SPECIES_LIST = [
|
||||||
"""Defines a target class of sound events."""
|
"Barbastella barbastellus",
|
||||||
|
"Eptesicus serotinus",
|
||||||
|
"Myotis alcathoe",
|
||||||
|
"Myotis bechsteinii",
|
||||||
|
"Myotis brandtii",
|
||||||
|
"Myotis daubentonii",
|
||||||
|
"Myotis mystacinus",
|
||||||
|
"Myotis nattereri",
|
||||||
|
"Nyctalus leisleri",
|
||||||
|
"Nyctalus noctula",
|
||||||
|
"Pipistrellus nathusii",
|
||||||
|
"Pipistrellus pipistrellus",
|
||||||
|
"Pipistrellus pygmaeus",
|
||||||
|
"Plecotus auritus",
|
||||||
|
"Plecotus austriacus",
|
||||||
|
"Rhinolophus ferrumequinum",
|
||||||
|
"Rhinolophus hipposideros",
|
||||||
|
]
|
||||||
|
"""A default list of common bat species names found in the UK."""
|
||||||
|
|
||||||
|
|
||||||
|
class TargetClass(BaseConfig):
|
||||||
|
"""Defines criteria for encoding annotations and decoding predictions.
|
||||||
|
|
||||||
|
Each instance represents one potential output class for the classification
|
||||||
|
model. It specifies:
|
||||||
|
1. A unique `name` for the class.
|
||||||
|
2. The tag conditions (`tags` and `match_type`) an annotation must meet to
|
||||||
|
be assigned this class name during training data preparation (encoding).
|
||||||
|
3. An optional, alternative set of tags (`output_tags`) to be used when
|
||||||
|
converting a model's prediction of this class name back into annotation
|
||||||
|
tags (decoding).
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
The unique name assigned to this target class (e.g., 'pippip',
|
||||||
|
'myodau', 'noise'). This name is used as the label during model
|
||||||
|
training and is the expected output from the model's prediction.
|
||||||
|
Should be unique across all TargetClass definitions in a configuration.
|
||||||
|
tags : List[TagInfo]
|
||||||
|
A list of one or more tags (defined using `TagInfo`) used to identify
|
||||||
|
if an existing annotation belongs to this class during encoding (data
|
||||||
|
preparation for training). The `match_type` attribute determines how
|
||||||
|
these tags are evaluated.
|
||||||
|
match_type : Literal["all", "any"], default="all"
|
||||||
|
Determines how the `tags` list is evaluated during encoding:
|
||||||
|
- "all": The annotation must have *all* the tags listed to match.
|
||||||
|
- "any": The annotation must have *at least one* of the tags listed
|
||||||
|
to match.
|
||||||
|
output_tags: Optional[List[TagInfo]], default=None
|
||||||
|
An optional list of tags (defined using `TagInfo`) to be assigned to a
|
||||||
|
new annotation when the model predicts this class `name`. If `None`
|
||||||
|
(default), the tags listed in the `tags` field will be used for
|
||||||
|
decoding. If provided, this list overrides the `tags` field for the
|
||||||
|
purpose of decoding predictions back into meaningful annotation tags.
|
||||||
|
This allows, for example, training on broader categories but decoding
|
||||||
|
to more specific representative tags.
|
||||||
|
"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
|
tags: List[TagInfo] = Field(min_length=1)
|
||||||
condition_input: Optional[SoundEventConditionConfig] = Field(
|
match_type: Literal["all", "any"] = Field(default="all")
|
||||||
alias="match_if",
|
output_tags: Optional[List[TagInfo]] = None
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
tags: Optional[List[data.Tag]] = Field(default=None, exclude=True)
|
|
||||||
|
|
||||||
assign_tags: List[data.Tag] = Field(default_factory=list)
|
|
||||||
|
|
||||||
roi: Optional[ROIMapperConfig] = None
|
roi: Optional[ROIMapperConfig] = None
|
||||||
|
|
||||||
_match_if: SoundEventConditionConfig = PrivateAttr()
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
def _get_default_classes() -> List[TargetClass]:
|
||||||
def _process_tags(self) -> "TargetClassConfig":
|
"""Generate a list of default target classes.
|
||||||
if self.tags and self.condition_input:
|
|
||||||
raise ValueError("Use either 'tags' or 'match_if', not both.")
|
|
||||||
|
|
||||||
if self.condition_input is not None:
|
Returns
|
||||||
self._match_if = self.condition_input
|
-------
|
||||||
return self
|
List[TargetClass]
|
||||||
|
A list of TargetClass objects, one for each species in
|
||||||
|
DEFAULT_SPECIES_LIST. The class names are simplified versions of the
|
||||||
|
species names.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
TargetClass(
|
||||||
|
name=_get_default_class_name(value),
|
||||||
|
tags=[TagInfo(key=GENERIC_CLASS_KEY, value=value)],
|
||||||
|
)
|
||||||
|
for value in DEFAULT_SPECIES_LIST
|
||||||
|
]
|
||||||
|
|
||||||
if self.tags is None:
|
|
||||||
|
def _get_default_class_name(species: str) -> str:
|
||||||
|
"""Generate a default class name from a species name.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
species : str
|
||||||
|
The species name (e.g., "Myotis daubentonii").
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
A simplified class name (e.g., "myodau").
|
||||||
|
The genus and species names are converted to lowercase,
|
||||||
|
the first three letters of each are taken, and concatenated.
|
||||||
|
"""
|
||||||
|
genus, species = species.strip().split(" ")
|
||||||
|
return f"{genus.lower()[:3]}{species.lower()[:3]}"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_default_generic_class() -> List[TagInfo]:
|
||||||
|
"""Generate the default list of TagInfo objects for the generic class.
|
||||||
|
|
||||||
|
Provides a default set of tags used to represent the generic "Bat" category
|
||||||
|
when decoding predictions that didn't match a specific class.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[TagInfo]
|
||||||
|
A list containing default TagInfo objects, typically representing
|
||||||
|
`call_type: Echolocation` and `order: Chiroptera`.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
TagInfo(key="call_type", value="Echolocation"),
|
||||||
|
TagInfo(key="order", value="Chiroptera"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ClassesConfig(BaseConfig):
|
||||||
|
"""Configuration defining target classes and the generic fallback category.
|
||||||
|
|
||||||
|
Holds the ordered list of specific target class definitions (`TargetClass`)
|
||||||
|
and defines the tags representing the generic category for sounds that pass
|
||||||
|
filtering but do not match any specific class.
|
||||||
|
|
||||||
|
The order of `TargetClass` objects in the `classes` list defines the
|
||||||
|
priority for classification during encoding. The system checks annotations
|
||||||
|
against these definitions sequentially and assigns the name of the *first*
|
||||||
|
matching class.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
classes : List[TargetClass]
|
||||||
|
An ordered list of specific target class definitions. The order
|
||||||
|
determines matching priority (first match wins). Defaults to a
|
||||||
|
standard set of classes via `get_default_classes`.
|
||||||
|
generic_class : List[TagInfo]
|
||||||
|
A list of tags defining the "generic" or "unclassified but relevant"
|
||||||
|
category (e.g., representing a generic 'Bat' call that wasn't
|
||||||
|
assigned to a specific species). These tags are typically assigned
|
||||||
|
during decoding when a sound event was detected and passed filtering
|
||||||
|
but did not match any specific class rule defined in the `classes` list.
|
||||||
|
Defaults to a standard set of tags via `get_default_generic_class`.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If validation fails (e.g., non-unique class names in the `classes`
|
||||||
|
list).
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
- It is crucial that the `name` attribute of each `TargetClass` in the
|
||||||
|
`classes` list is unique. This configuration includes a validator to
|
||||||
|
enforce this uniqueness.
|
||||||
|
- The `generic_class` tags provide a baseline identity for relevant sounds
|
||||||
|
that don't fit into more specific defined categories.
|
||||||
|
"""
|
||||||
|
|
||||||
|
classes: List[TargetClass] = Field(default_factory=_get_default_classes)
|
||||||
|
|
||||||
|
generic_class: List[TagInfo] = Field(
|
||||||
|
default_factory=_get_default_generic_class
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("classes")
|
||||||
|
def check_unique_class_names(cls, v: List[TargetClass]):
|
||||||
|
"""Ensure all defined class names are unique."""
|
||||||
|
names = [c.name for c in v]
|
||||||
|
|
||||||
|
if len(names) != len(set(names)):
|
||||||
|
name_counts = Counter(names)
|
||||||
|
duplicates = [
|
||||||
|
name for name, count in name_counts.items() if count > 1
|
||||||
|
]
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Class '{self.name}' must have a 'tags' or 'match_if' rule."
|
"Class names must be unique. Found duplicates: "
|
||||||
|
f"{', '.join(duplicates)}"
|
||||||
)
|
)
|
||||||
|
return v
|
||||||
self._match_if = HasAllTagsConfig(tags=self.tags)
|
|
||||||
|
|
||||||
if not self.assign_tags:
|
|
||||||
self.assign_tags = self.tags
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
@computed_field
|
|
||||||
@property
|
|
||||||
def match_if(self) -> SoundEventConditionConfig:
|
|
||||||
return self._match_if
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_GENERIC_CLASS = TargetClassConfig(
|
def is_target_class(
|
||||||
name="bat",
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
match_if=AllOfConfig(
|
tags: Set[data.Tag],
|
||||||
conditions=[
|
match_all: bool = True,
|
||||||
HasTagConfig(tag=data.Tag(key="event", value="Echolocation")),
|
) -> bool:
|
||||||
NotConfig(
|
"""Check if a sound event annotation matches a set of required tags.
|
||||||
condition=HasAnyTagConfig(
|
|
||||||
tags=[
|
Parameters
|
||||||
data.Tag(key="event", value="Feeding"),
|
----------
|
||||||
data.Tag(key="event", value="Unknown"),
|
sound_event_annotation : data.SoundEventAnnotation
|
||||||
data.Tag(key="event", value="Not Bat"),
|
The annotation to check.
|
||||||
]
|
required_tags : Set[data.Tag]
|
||||||
)
|
A set of `soundevent.data.Tag` objects that define the class criteria.
|
||||||
),
|
match_all : bool, default=True
|
||||||
]
|
If True, checks if *all* `required_tags` are present in the
|
||||||
),
|
annotation's tags (subset check). If False, checks if *at least one*
|
||||||
assign_tags=[
|
of the `required_tags` is present (intersection check).
|
||||||
data.Tag(key="call_type", value="Echolocation"),
|
|
||||||
data.Tag(key="order", value="Chiroptera"),
|
Returns
|
||||||
],
|
-------
|
||||||
)
|
bool
|
||||||
|
True if the annotation meets the tag criteria, False otherwise.
|
||||||
|
"""
|
||||||
|
annotation_tags = set(sound_event_annotation.tags)
|
||||||
|
|
||||||
|
if match_all:
|
||||||
|
return tags <= annotation_tags
|
||||||
|
|
||||||
|
return bool(tags & annotation_tags)
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CLASSES = [
|
def get_class_names_from_config(config: ClassesConfig) -> List[str]:
|
||||||
TargetClassConfig(
|
|
||||||
name="myomys",
|
|
||||||
tags=[data.Tag(key="class", value="Myotis mystacinus")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="myoalc",
|
|
||||||
tags=[data.Tag(key="class", value="Myotis alcathoe")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="eptser",
|
|
||||||
tags=[data.Tag(key="class", value="Eptesicus serotinus")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="pipnat",
|
|
||||||
tags=[data.Tag(key="class", value="Pipistrellus nathusii")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="barbar",
|
|
||||||
tags=[data.Tag(key="class", value="Barbastellus barbastellus")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="myonat",
|
|
||||||
tags=[data.Tag(key="class", value="Myotis nattereri")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="myodau",
|
|
||||||
tags=[data.Tag(key="class", value="Myotis daubentonii")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="myobra",
|
|
||||||
tags=[data.Tag(key="class", value="Myotis brandtii")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="pippip",
|
|
||||||
tags=[data.Tag(key="class", value="Pipistrellus pipistrellus")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="myobec",
|
|
||||||
tags=[data.Tag(key="class", value="Myotis bechsteinii")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="pippyg",
|
|
||||||
tags=[data.Tag(key="class", value="Pipistrellus pygmaeus")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="rhihip",
|
|
||||||
tags=[data.Tag(key="class", value="Rhinolophus hipposideros")],
|
|
||||||
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="nyclei",
|
|
||||||
tags=[data.Tag(key="class", value="Nyctalus leisleri")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="rhifer",
|
|
||||||
tags=[data.Tag(key="class", value="Rhinolophus ferrumequinum")],
|
|
||||||
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="pleaur",
|
|
||||||
tags=[data.Tag(key="class", value="Plecotus auritus")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="nycnoc",
|
|
||||||
tags=[data.Tag(key="class", value="Nyctalus noctula")],
|
|
||||||
),
|
|
||||||
TargetClassConfig(
|
|
||||||
name="pleaus",
|
|
||||||
tags=[data.Tag(key="class", value="Plecotus austriacus")],
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def get_class_names_from_config(configs: List[TargetClassConfig]) -> List[str]:
|
|
||||||
"""Extract the list of class names from a ClassesConfig object.
|
"""Extract the list of class names from a ClassesConfig object.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -179,60 +260,340 @@ def get_class_names_from_config(configs: List[TargetClassConfig]) -> List[str]:
|
|||||||
List[str]
|
List[str]
|
||||||
An ordered list of unique class names defined in the configuration.
|
An ordered list of unique class names defined in the configuration.
|
||||||
"""
|
"""
|
||||||
return [class_info.name for class_info in configs]
|
return [class_info.name for class_info in config.classes]
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_with_multiple_classifiers(
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
classifiers: List[Tuple[str, Callable[[data.SoundEventAnnotation], bool]]],
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Encode an annotation by checking against a list of classifiers.
|
||||||
|
|
||||||
|
Internal helper function used by the `SoundEventEncoder`. It iterates
|
||||||
|
through the provided list of (class_name, classifier_function) pairs.
|
||||||
|
Returns the name associated with the first classifier function that
|
||||||
|
returns True for the given annotation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event_annotation : data.SoundEventAnnotation
|
||||||
|
The annotation to encode.
|
||||||
|
classifiers : List[Tuple[str, Callable[[data.SoundEventAnnotation], bool]]]
|
||||||
|
An ordered list where each tuple contains a class name and a function
|
||||||
|
that returns True if the annotation matches that class. The order
|
||||||
|
determines priority.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str or None
|
||||||
|
The name of the first matching class, or None if no classifier matches.
|
||||||
|
"""
|
||||||
|
for class_name, classifier in classifiers:
|
||||||
|
if classifier(sound_event_annotation):
|
||||||
|
return class_name
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def build_sound_event_encoder(
|
def build_sound_event_encoder(
|
||||||
configs: List[TargetClassConfig],
|
config: ClassesConfig,
|
||||||
|
term_registry: TermRegistry = default_term_registry,
|
||||||
) -> SoundEventEncoder:
|
) -> SoundEventEncoder:
|
||||||
"""Build a sound event encoder function from the classes configuration."""
|
"""Build a sound event encoder function from the classes configuration.
|
||||||
conditions = {
|
|
||||||
class_config.name: build_sound_event_condition(class_config.match_if)
|
|
||||||
for class_config in configs
|
|
||||||
}
|
|
||||||
|
|
||||||
return SoundEventClassifier(conditions)
|
The returned encoder function iterates through the class definitions in the
|
||||||
|
order specified in the config. It assigns an annotation the name of the
|
||||||
|
first class definition it matches.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : ClassesConfig
|
||||||
|
The loaded and validated classes configuration object.
|
||||||
|
term_registry : TermRegistry, optional
|
||||||
|
The TermRegistry instance used to look up term keys specified in the
|
||||||
|
`TagInfo` objects within the configuration. Defaults to the global
|
||||||
|
`batdetect2.targets.terms.registry`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
SoundEventEncoder
|
||||||
|
A callable function that takes a `SoundEventAnnotation` and returns
|
||||||
|
an optional string representing the matched class name, or None if no
|
||||||
|
class matches.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
KeyError
|
||||||
|
If a term key specified in the configuration is not found in the
|
||||||
|
provided `term_registry`.
|
||||||
|
"""
|
||||||
|
binary_classifiers = [
|
||||||
|
(
|
||||||
|
class_info.name,
|
||||||
|
partial(
|
||||||
|
is_target_class,
|
||||||
|
tags={
|
||||||
|
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||||
|
for tag_info in class_info.tags
|
||||||
|
},
|
||||||
|
match_all=class_info.match_type == "all",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for class_info in config.classes
|
||||||
|
]
|
||||||
|
|
||||||
|
return partial(
|
||||||
|
_encode_with_multiple_classifiers,
|
||||||
|
classifiers=binary_classifiers,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SoundEventClassifier:
|
def _decode_class(
|
||||||
def __init__(self, mapping: Dict[str, SoundEventCondition]):
|
name: str,
|
||||||
self.mapping = mapping
|
mapping: Dict[str, List[data.Tag]],
|
||||||
|
raise_on_error: bool = True,
|
||||||
|
) -> List[data.Tag]:
|
||||||
|
"""Decode a class name into a list of representative tags using a mapping.
|
||||||
|
|
||||||
def __call__(
|
Internal helper function used by the `SoundEventDecoder`. Looks up the
|
||||||
self, sound_event_annotation: data.SoundEventAnnotation
|
provided class `name` in the `mapping` dictionary.
|
||||||
) -> Optional[str]:
|
|
||||||
for name, condition in self.mapping.items():
|
Parameters
|
||||||
if condition(sound_event_annotation):
|
----------
|
||||||
return name
|
name : str
|
||||||
|
The class name to decode.
|
||||||
|
mapping : Dict[str, List[data.Tag]]
|
||||||
|
A dictionary mapping class names to lists of `soundevent.data.Tag`
|
||||||
|
objects.
|
||||||
|
raise_on_error : bool, default=True
|
||||||
|
If True, raises a ValueError if the `name` is not found in the
|
||||||
|
`mapping`. If False, returns an empty list if the `name` is not found.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[data.Tag]
|
||||||
|
The list of tags associated with the class name, or an empty list if
|
||||||
|
not found and `raise_on_error` is False.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If `name` is not found in `mapping` and `raise_on_error` is True.
|
||||||
|
"""
|
||||||
|
if name not in mapping and raise_on_error:
|
||||||
|
raise ValueError(f"Class {name} not found in mapping.")
|
||||||
|
|
||||||
|
if name not in mapping:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return mapping[name]
|
||||||
|
|
||||||
|
|
||||||
def build_sound_event_decoder(
|
def build_sound_event_decoder(
|
||||||
configs: List[TargetClassConfig],
|
config: ClassesConfig,
|
||||||
|
term_registry: TermRegistry = default_term_registry,
|
||||||
raise_on_unmapped: bool = False,
|
raise_on_unmapped: bool = False,
|
||||||
) -> SoundEventDecoder:
|
) -> SoundEventDecoder:
|
||||||
"""Build a sound event decoder function from the classes configuration."""
|
"""Build a sound event decoder function from the classes configuration.
|
||||||
mapping = {
|
|
||||||
class_config.name: class_config.assign_tags for class_config in configs
|
Creates a callable `SoundEventDecoder` that maps a class name string
|
||||||
}
|
back to a list of representative `soundevent.data.Tag` objects based on
|
||||||
return TagDecoder(mapping, raise_on_unknown=raise_on_unmapped)
|
the `ClassesConfig`. It uses the `output_tags` field if provided in a
|
||||||
|
`TargetClass`, otherwise falls back to the `tags` field.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : ClassesConfig
|
||||||
|
The loaded and validated classes configuration object.
|
||||||
|
term_registry : TermRegistry, optional
|
||||||
|
The TermRegistry instance used to look up term keys. Defaults to the
|
||||||
|
global `batdetect2.targets.terms.registry`.
|
||||||
|
raise_on_unmapped : bool, default=False
|
||||||
|
If True, the returned decoder function will raise a ValueError if asked
|
||||||
|
to decode a class name that is not in the configuration. If False, it
|
||||||
|
will return an empty list for unmapped names.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
SoundEventDecoder
|
||||||
|
A callable function that takes a class name string and returns a list
|
||||||
|
of `soundevent.data.Tag` objects.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
KeyError
|
||||||
|
If a term key specified in the configuration (`output_tags`, `tags`, or
|
||||||
|
`generic_class`) is not found in the provided `term_registry`.
|
||||||
|
"""
|
||||||
|
mapping = {}
|
||||||
|
for class_info in config.classes:
|
||||||
|
tags_to_use = (
|
||||||
|
class_info.output_tags
|
||||||
|
if class_info.output_tags is not None
|
||||||
|
else class_info.tags
|
||||||
|
)
|
||||||
|
mapping[class_info.name] = [
|
||||||
|
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||||
|
for tag_info in tags_to_use
|
||||||
|
]
|
||||||
|
|
||||||
|
return partial(
|
||||||
|
_decode_class,
|
||||||
|
mapping=mapping,
|
||||||
|
raise_on_error=raise_on_unmapped,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TagDecoder:
|
def build_generic_class_tags(
|
||||||
def __init__(
|
config: ClassesConfig,
|
||||||
self,
|
term_registry: TermRegistry = default_term_registry,
|
||||||
mapping: Dict[str, List[data.Tag]],
|
) -> List[data.Tag]:
|
||||||
raise_on_unknown: bool = True,
|
"""Extract and build the list of tags for the generic class from config.
|
||||||
):
|
|
||||||
self.mapping = mapping
|
|
||||||
self.raise_on_unknown = raise_on_unknown
|
|
||||||
|
|
||||||
def __call__(self, class_name: str) -> List[data.Tag]:
|
Converts the list of `TagInfo` objects defined in `config.generic_class`
|
||||||
tags = self.mapping.get(class_name)
|
into a list of `soundevent.data.Tag` objects using the term registry.
|
||||||
|
|
||||||
if tags is None:
|
Parameters
|
||||||
if self.raise_on_unknown:
|
----------
|
||||||
raise ValueError("Invalid class name")
|
config : ClassesConfig
|
||||||
|
The loaded classes configuration object.
|
||||||
|
term_registry : TermRegistry, optional
|
||||||
|
The TermRegistry instance for term lookups. Defaults to the global
|
||||||
|
`batdetect2.targets.terms.registry`.
|
||||||
|
|
||||||
tags = []
|
Returns
|
||||||
|
-------
|
||||||
|
List[data.Tag]
|
||||||
|
The list of fully constructed tags representing the generic class.
|
||||||
|
|
||||||
return tags
|
Raises
|
||||||
|
------
|
||||||
|
KeyError
|
||||||
|
If a term key specified in `config.generic_class` is not found in the
|
||||||
|
provided `term_registry`.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||||
|
for tag_info in config.generic_class
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def load_classes_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> ClassesConfig:
|
||||||
|
"""Load the target classes configuration from a file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : data.PathLike
|
||||||
|
Path to the configuration file (YAML).
|
||||||
|
field : str, optional
|
||||||
|
If the classes configuration is nested under a specific key in the
|
||||||
|
file, specify the key here. Defaults to None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ClassesConfig
|
||||||
|
The loaded and validated classes configuration object.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError
|
||||||
|
If the config file path does not exist.
|
||||||
|
pydantic.ValidationError
|
||||||
|
If the config file structure does not match the ClassesConfig schema
|
||||||
|
or if class names are not unique.
|
||||||
|
"""
|
||||||
|
return load_config(path, schema=ClassesConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
|
def load_encoder_from_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
term_registry: TermRegistry = default_term_registry,
|
||||||
|
) -> SoundEventEncoder:
|
||||||
|
"""Load a class encoder function directly from a configuration file.
|
||||||
|
|
||||||
|
This is a convenience function that combines loading the `ClassesConfig`
|
||||||
|
from a file and building the final `SoundEventEncoder` function.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : data.PathLike
|
||||||
|
Path to the configuration file (e.g., YAML).
|
||||||
|
field : str, optional
|
||||||
|
If the classes configuration is nested under a specific key in the
|
||||||
|
file, specify the key here. Defaults to None.
|
||||||
|
term_registry : TermRegistry, optional
|
||||||
|
The TermRegistry instance used for term lookups. Defaults to the
|
||||||
|
global `batdetect2.targets.terms.registry`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
SoundEventEncoder
|
||||||
|
The final encoder function ready to classify annotations.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError
|
||||||
|
If the config file path does not exist.
|
||||||
|
pydantic.ValidationError
|
||||||
|
If the config file structure does not match the ClassesConfig schema
|
||||||
|
or if class names are not unique.
|
||||||
|
KeyError
|
||||||
|
If a term key specified in the configuration is not found in the
|
||||||
|
provided `term_registry` during the build process.
|
||||||
|
"""
|
||||||
|
config = load_classes_config(path, field=field)
|
||||||
|
return build_sound_event_encoder(config, term_registry=term_registry)
|
||||||
|
|
||||||
|
|
||||||
|
def load_decoder_from_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
term_registry: TermRegistry = default_term_registry,
|
||||||
|
raise_on_unmapped: bool = False,
|
||||||
|
) -> SoundEventDecoder:
|
||||||
|
"""Load a class decoder function directly from a configuration file.
|
||||||
|
|
||||||
|
This is a convenience function that combines loading the `ClassesConfig`
|
||||||
|
from a file and building the final `SoundEventDecoder` function.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : data.PathLike
|
||||||
|
Path to the configuration file (e.g., YAML).
|
||||||
|
field : str, optional
|
||||||
|
If the classes configuration is nested under a specific key in the
|
||||||
|
file, specify the key here. Defaults to None.
|
||||||
|
term_registry : TermRegistry, optional
|
||||||
|
The TermRegistry instance used for term lookups. Defaults to the
|
||||||
|
global `batdetect2.targets.terms.registry`.
|
||||||
|
raise_on_unmapped : bool, default=False
|
||||||
|
If True, the returned decoder function will raise a ValueError if asked
|
||||||
|
to decode a class name that is not in the configuration. If False, it
|
||||||
|
will return an empty list for unmapped names.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
SoundEventDecoder
|
||||||
|
The final decoder function ready to convert class names back into tags.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError
|
||||||
|
If the config file path does not exist.
|
||||||
|
pydantic.ValidationError
|
||||||
|
If the config file structure does not match the ClassesConfig schema
|
||||||
|
or if class names are not unique.
|
||||||
|
KeyError
|
||||||
|
If a term key specified in the configuration is not found in the
|
||||||
|
provided `term_registry` during the build process.
|
||||||
|
"""
|
||||||
|
config = load_classes_config(path, field=field)
|
||||||
|
return build_sound_event_decoder(
|
||||||
|
config,
|
||||||
|
term_registry=term_registry,
|
||||||
|
raise_on_unmapped=raise_on_unmapped,
|
||||||
|
)
|
||||||
|
|||||||
307
src/batdetect2/targets/filtering.py
Normal file
307
src/batdetect2/targets/filtering.py
Normal file
@ -0,0 +1,307 @@
|
|||||||
|
import logging
|
||||||
|
from functools import partial
|
||||||
|
from typing import List, Literal, Optional, Set
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.targets.terms import (
|
||||||
|
TagInfo,
|
||||||
|
TermRegistry,
|
||||||
|
default_term_registry,
|
||||||
|
get_tag_from_info,
|
||||||
|
)
|
||||||
|
from batdetect2.typing.targets import SoundEventFilter
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"FilterConfig",
|
||||||
|
"FilterRule",
|
||||||
|
"build_sound_event_filter",
|
||||||
|
"build_filter_from_rule",
|
||||||
|
"load_filter_config",
|
||||||
|
"load_filter_from_config",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FilterRule(BaseConfig):
|
||||||
|
"""Defines a single rule for filtering sound event annotations.
|
||||||
|
|
||||||
|
Based on the `match_type`, this rule checks if the tags associated with a
|
||||||
|
sound event annotation meet certain criteria relative to the `tags` list
|
||||||
|
defined in this rule.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
match_type : Literal["any", "all", "exclude", "equal"]
|
||||||
|
Determines how the `tags` list is used:
|
||||||
|
- "any": Pass if the annotation has at least one tag from the list.
|
||||||
|
- "all": Pass if the annotation has all tags from the list (it can
|
||||||
|
have others too).
|
||||||
|
- "exclude": Pass if the annotation has none of the tags from the list.
|
||||||
|
- "equal": Pass if the annotation's tags are exactly the same set as
|
||||||
|
provided in the list.
|
||||||
|
tags : List[TagInfo]
|
||||||
|
A list of tags (defined using TagInfo for configuration) that this
|
||||||
|
rule operates on.
|
||||||
|
"""
|
||||||
|
|
||||||
|
match_type: Literal["any", "all", "exclude", "equal"]
|
||||||
|
tags: List[TagInfo]
|
||||||
|
|
||||||
|
|
||||||
|
def has_any_tag(
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
tags: Set[data.Tag],
|
||||||
|
) -> bool:
|
||||||
|
"""Check if the annotation has at least one of the specified tags.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event_annotation : data.SoundEventAnnotation
|
||||||
|
The annotation to check.
|
||||||
|
tags : Set[data.Tag]
|
||||||
|
The set of tags to look for.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the annotation has one or more tags from the specified set,
|
||||||
|
False otherwise.
|
||||||
|
"""
|
||||||
|
sound_event_tags = set(sound_event_annotation.tags)
|
||||||
|
return bool(tags & sound_event_tags)
|
||||||
|
|
||||||
|
|
||||||
|
def contains_tags(
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
tags: Set[data.Tag],
|
||||||
|
) -> bool:
|
||||||
|
"""Check if the annotation contains all of the specified tags.
|
||||||
|
|
||||||
|
The annotation may have additional tags beyond those specified.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event_annotation : data.SoundEventAnnotation
|
||||||
|
The annotation to check.
|
||||||
|
tags : Set[data.Tag]
|
||||||
|
The set of tags that must all be present in the annotation.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the annotation's tags are a superset of the specified tags,
|
||||||
|
False otherwise.
|
||||||
|
"""
|
||||||
|
sound_event_tags = set(sound_event_annotation.tags)
|
||||||
|
return tags <= sound_event_tags
|
||||||
|
|
||||||
|
|
||||||
|
def does_not_have_tags(
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
tags: Set[data.Tag],
|
||||||
|
):
|
||||||
|
"""Check if the annotation has none of the specified tags.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event_annotation : data.SoundEventAnnotation
|
||||||
|
The annotation to check.
|
||||||
|
tags : Set[data.Tag]
|
||||||
|
The set of tags that must *not* be present in the annotation.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the annotation has zero tags in common with the specified set,
|
||||||
|
False otherwise.
|
||||||
|
"""
|
||||||
|
return not has_any_tag(sound_event_annotation, tags)
|
||||||
|
|
||||||
|
|
||||||
|
def equal_tags(
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
tags: Set[data.Tag],
|
||||||
|
) -> bool:
|
||||||
|
"""Check if the annotation's tags are exactly equal to the specified set.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event_annotation : data.SoundEventAnnotation
|
||||||
|
The annotation to check.
|
||||||
|
tags : Set[data.Tag]
|
||||||
|
The exact set of tags the annotation must have.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the annotation's tags set is identical to the specified set,
|
||||||
|
False otherwise.
|
||||||
|
"""
|
||||||
|
sound_event_tags = set(sound_event_annotation.tags)
|
||||||
|
return tags == sound_event_tags
|
||||||
|
|
||||||
|
|
||||||
|
def build_filter_from_rule(
|
||||||
|
rule: FilterRule,
|
||||||
|
term_registry: TermRegistry = default_term_registry,
|
||||||
|
) -> SoundEventFilter:
|
||||||
|
"""Creates a callable filter function from a single FilterRule.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
rule : FilterRule
|
||||||
|
The filter rule configuration object.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
SoundEventFilter
|
||||||
|
A function that takes a SoundEventAnnotation and returns True if it
|
||||||
|
passes the rule, False otherwise.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the rule contains an invalid `match_type`.
|
||||||
|
"""
|
||||||
|
tag_set = {
|
||||||
|
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||||
|
for tag_info in rule.tags
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.match_type == "any":
|
||||||
|
return partial(has_any_tag, tags=tag_set)
|
||||||
|
|
||||||
|
if rule.match_type == "all":
|
||||||
|
return partial(contains_tags, tags=tag_set)
|
||||||
|
|
||||||
|
if rule.match_type == "exclude":
|
||||||
|
return partial(does_not_have_tags, tags=tag_set)
|
||||||
|
|
||||||
|
if rule.match_type == "equal":
|
||||||
|
return partial(equal_tags, tags=tag_set)
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid match type {rule.match_type}. Valid types "
|
||||||
|
"are: 'any', 'all', 'exclude' and 'equal'"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _passes_all_filters(
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
filters: List[SoundEventFilter],
|
||||||
|
) -> bool:
|
||||||
|
"""Check if the annotation passes all provided filters.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event_annotation : data.SoundEventAnnotation
|
||||||
|
The annotation to check.
|
||||||
|
filters : List[SoundEventFilter]
|
||||||
|
A list of filter functions to apply.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the annotation passes all filters, False otherwise.
|
||||||
|
"""
|
||||||
|
for filter_fn in filters:
|
||||||
|
if not filter_fn(sound_event_annotation):
|
||||||
|
logging.debug(
|
||||||
|
f"Sound event annotation {sound_event_annotation.uuid} "
|
||||||
|
f"excluded due to rule {filter_fn}",
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class FilterConfig(BaseConfig):
|
||||||
|
"""Configuration model for defining a list of filter rules.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
rules : List[FilterRule]
|
||||||
|
A list of FilterRule objects. An annotation must pass all rules in
|
||||||
|
this list to be considered valid by the filter built from this config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
rules: List[FilterRule] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def build_sound_event_filter(
|
||||||
|
config: FilterConfig,
|
||||||
|
term_registry: TermRegistry = default_term_registry,
|
||||||
|
) -> SoundEventFilter:
|
||||||
|
"""Builds a merged filter function from a FilterConfig object.
|
||||||
|
|
||||||
|
Creates individual filter functions for each rule in the configuration
|
||||||
|
and merges them using AND logic.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : FilterConfig
|
||||||
|
The configuration object containing the list of filter rules.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
SoundEventFilter
|
||||||
|
A single callable filter function that applies all defined rules.
|
||||||
|
"""
|
||||||
|
filters = [
|
||||||
|
build_filter_from_rule(rule, term_registry=term_registry)
|
||||||
|
for rule in config.rules
|
||||||
|
]
|
||||||
|
return partial(_passes_all_filters, filters=filters)
|
||||||
|
|
||||||
|
|
||||||
|
def load_filter_config(
|
||||||
|
path: data.PathLike, field: Optional[str] = None
|
||||||
|
) -> FilterConfig:
|
||||||
|
"""Loads the filter configuration from a file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : data.PathLike
|
||||||
|
Path to the configuration file (YAML).
|
||||||
|
field : Optional[str], optional
|
||||||
|
If the filter configuration is nested under a specific key in the
|
||||||
|
file, specify the key here. Defaults to None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
FilterConfig
|
||||||
|
The loaded and validated filter configuration object.
|
||||||
|
"""
|
||||||
|
return load_config(path, schema=FilterConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
|
def load_filter_from_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
term_registry: TermRegistry = default_term_registry,
|
||||||
|
) -> SoundEventFilter:
|
||||||
|
"""Loads filter configuration from a file and builds the filter function.
|
||||||
|
|
||||||
|
This is a convenience function that combines loading the configuration
|
||||||
|
and building the final callable filter function.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : data.PathLike
|
||||||
|
Path to the configuration file (YAML).
|
||||||
|
field : Optional[str], optional
|
||||||
|
If the filter configuration is nested under a specific key in the
|
||||||
|
file, specify the key here. Defaults to None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
SoundEventFilter
|
||||||
|
The final merged filter function ready to be used.
|
||||||
|
"""
|
||||||
|
config = load_filter_config(path=path, field=field)
|
||||||
|
return build_sound_event_filter(config, term_registry=term_registry)
|
||||||
@ -20,7 +20,7 @@ selecting and configuring the desired mapper. This module separates the
|
|||||||
*geometric* aspect of target definition from *semantic* classification.
|
*geometric* aspect of target definition from *semantic* classification.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, Literal, Optional, Tuple, Union
|
from typing import Annotated, List, Literal, Optional, Protocol, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -30,7 +30,7 @@ from batdetect2.configs import BaseConfig
|
|||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.preprocess.audio import build_audio_loader
|
from batdetect2.preprocess.audio import build_audio_loader
|
||||||
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
||||||
from batdetect2.typing.targets import Position, ROITargetMapper, Size
|
from batdetect2.typing.targets import Position, Size
|
||||||
from batdetect2.utils.arrays import spec_to_xarray
|
from batdetect2.utils.arrays import spec_to_xarray
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -83,6 +83,73 @@ DEFAULT_ANCHOR = "bottom-left"
|
|||||||
"""Default reference position within the geometry ('bottom-left' corner)."""
|
"""Default reference position within the geometry ('bottom-left' corner)."""
|
||||||
|
|
||||||
|
|
||||||
|
class ROITargetMapper(Protocol):
|
||||||
|
"""Protocol defining the interface for ROI-to-target mapping.
|
||||||
|
|
||||||
|
Specifies the `encode` and `decode` methods required for converting a
|
||||||
|
`soundevent.data.SoundEvent` into a target representation (a reference
|
||||||
|
position and a size vector) and for recovering an approximate ROI from that
|
||||||
|
representation.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
dimension_names : List[str]
|
||||||
|
A list containing the names of the dimensions in the `Size` array
|
||||||
|
returned by `encode` and expected by `decode`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dimension_names: List[str]
|
||||||
|
|
||||||
|
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
|
||||||
|
"""Encode a SoundEvent's geometry into a position and size.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event : data.SoundEvent
|
||||||
|
The input sound event, which must have a geometry attribute.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tuple[Position, Size]
|
||||||
|
A tuple containing:
|
||||||
|
- The reference position as (time, frequency) coordinates.
|
||||||
|
- A NumPy array with the calculated size dimensions.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the sound event does not have a geometry.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def decode(self, position: Position, size: Size) -> data.Geometry:
|
||||||
|
"""Decode a position and size back into a geometric ROI.
|
||||||
|
|
||||||
|
Performs the inverse mapping: takes a reference position and size
|
||||||
|
dimensions and reconstructs a geometric representation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
position : Position
|
||||||
|
The reference position (time, frequency).
|
||||||
|
size : Size
|
||||||
|
NumPy array containing the size dimensions, matching the order
|
||||||
|
and meaning specified by `dimension_names`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
soundevent.data.Geometry
|
||||||
|
The reconstructed geometry, typically a `BoundingBox`.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the `size` array has an unexpected shape or if reconstruction
|
||||||
|
fails.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class AnchorBBoxMapperConfig(BaseConfig):
|
class AnchorBBoxMapperConfig(BaseConfig):
|
||||||
"""Configuration for `AnchorBBoxMapper`.
|
"""Configuration for `AnchorBBoxMapper`.
|
||||||
|
|
||||||
@ -408,10 +475,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
|
|||||||
|
|
||||||
|
|
||||||
ROIMapperConfig = Annotated[
|
ROIMapperConfig = Annotated[
|
||||||
Union[
|
Union[AnchorBBoxMapperConfig, PeakEnergyBBoxMapperConfig],
|
||||||
AnchorBBoxMapperConfig,
|
|
||||||
PeakEnergyBBoxMapperConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""A discriminated union of all supported ROI mapper configurations.
|
"""A discriminated union of all supported ROI mapper configurations.
|
||||||
@ -489,7 +553,7 @@ def _build_bounding_box(
|
|||||||
) -> data.BoundingBox:
|
) -> data.BoundingBox:
|
||||||
"""Construct a BoundingBox from a reference point, size, and position type.
|
"""Construct a BoundingBox from a reference point, size, and position type.
|
||||||
|
|
||||||
Internal helper for `BBoxEncoder.decode`. Calculates the box
|
Internal helper for `BBoxEncoder.recover_roi`. Calculates the box
|
||||||
coordinates [start_time, low_freq, end_time, high_freq] based on where
|
coordinates [start_time, low_freq, end_time, high_freq] based on where
|
||||||
the input `pos` (time, freq) is located relative to the box (e.g.,
|
the input `pos` (time, freq) is located relative to the box (e.g.,
|
||||||
center, corner).
|
center, corner).
|
||||||
|
|||||||
@ -1,11 +1,34 @@
|
|||||||
"""Manages the vocabulary for defining training targets."""
|
"""Manages the vocabulary (Terms and Tags) for defining training targets.
|
||||||
|
|
||||||
|
This module provides the necessary tools to declare, register, and manage the
|
||||||
|
set of `soundevent.data.Term` objects used throughout the `batdetect2.targets`
|
||||||
|
sub-package. It establishes a consistent vocabulary for filtering,
|
||||||
|
transforming, and classifying sound events based on their annotations (Tags).
|
||||||
|
|
||||||
|
The core component is the `TermRegistry`, which maps unique string keys
|
||||||
|
(aliases) to specific `Term` definitions. This allows users to refer to complex
|
||||||
|
terms using simple, consistent keys in configuration files and code.
|
||||||
|
|
||||||
|
Terms can be pre-defined, loaded from the `soundevent.terms` library, defined
|
||||||
|
programmatically, or loaded from external configuration files (e.g., YAML).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from inspect import getmembers
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from soundevent import data, terms
|
from soundevent import data, terms
|
||||||
|
|
||||||
|
from batdetect2.configs import load_config
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"call_type",
|
"call_type",
|
||||||
"individual",
|
"individual",
|
||||||
"data_source",
|
"data_source",
|
||||||
|
"get_tag_from_info",
|
||||||
|
"TermInfo",
|
||||||
|
"TagInfo",
|
||||||
]
|
]
|
||||||
|
|
||||||
# The default key used to reference the 'generic_class' term.
|
# The default key used to reference the 'generic_class' term.
|
||||||
@ -73,3 +96,430 @@ terms.register_term_set(
|
|||||||
),
|
),
|
||||||
override_existing=True,
|
override_existing=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TermRegistry(Mapping[str, data.Term]):
|
||||||
|
"""Manages a registry mapping unique keys to Term definitions.
|
||||||
|
|
||||||
|
This class acts as the central repository for the vocabulary of terms
|
||||||
|
used within the target definition process. It allows registering terms
|
||||||
|
with simple string keys and retrieving them consistently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, terms: Optional[Dict[str, data.Term]] = None):
|
||||||
|
"""Initializes the TermRegistry.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
terms : dict[str, soundevent.data.Term], optional
|
||||||
|
An optional dictionary of initial key-to-Term mappings
|
||||||
|
to populate the registry with. Defaults to an empty registry.
|
||||||
|
"""
|
||||||
|
self._terms: Dict[str, data.Term] = terms or {}
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> data.Term:
|
||||||
|
return self._terms[key]
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self._terms)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self._terms)
|
||||||
|
|
||||||
|
def add_term(self, key: str, term: data.Term) -> None:
|
||||||
|
"""Adds a Term object to the registry with the specified key.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
key : str
|
||||||
|
The unique string key to associate with the term.
|
||||||
|
term : soundevent.data.Term
|
||||||
|
The soundevent.data.Term object to register.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
KeyError
|
||||||
|
If a term with the provided key already exists in the
|
||||||
|
registry.
|
||||||
|
"""
|
||||||
|
if key in self._terms:
|
||||||
|
raise KeyError("A term with the provided key already exists.")
|
||||||
|
|
||||||
|
self._terms[key] = term
|
||||||
|
|
||||||
|
def get_term(self, key: str) -> data.Term:
|
||||||
|
"""Retrieves a registered term by its unique key.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
key : str
|
||||||
|
The unique string key of the term to retrieve.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
soundevent.data.Term
|
||||||
|
The corresponding soundevent.data.Term object.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
KeyError
|
||||||
|
If no term with the specified key is found, with a
|
||||||
|
helpful message suggesting listing available keys.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return self._terms[key]
|
||||||
|
except KeyError as err:
|
||||||
|
raise KeyError(
|
||||||
|
"No term found for key "
|
||||||
|
f"'{key}'. Ensure it is registered or loaded. "
|
||||||
|
f"Available keys: {', '.join(self.get_keys())}"
|
||||||
|
) from err
|
||||||
|
|
||||||
|
def add_custom_term(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
uri: Optional[str] = None,
|
||||||
|
label: Optional[str] = None,
|
||||||
|
definition: Optional[str] = None,
|
||||||
|
) -> data.Term:
|
||||||
|
"""Creates a new Term from attributes and adds it to the registry.
|
||||||
|
|
||||||
|
This is useful for defining terms directly in code or when loading
|
||||||
|
from configuration files where only attributes are provided.
|
||||||
|
|
||||||
|
If optional fields (`name`, `label`, `definition`) are not provided,
|
||||||
|
reasonable defaults are used (`key` for name/label, "Unknown" for
|
||||||
|
definition).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
key : str
|
||||||
|
The unique string key for the new term.
|
||||||
|
name : str, optional
|
||||||
|
The name for the new term (defaults to `key`).
|
||||||
|
uri : str, optional
|
||||||
|
The URI for the new term (optional).
|
||||||
|
label : str, optional
|
||||||
|
The display label for the new term (defaults to `key`).
|
||||||
|
definition : str, optional
|
||||||
|
The definition for the new term (defaults to "Unknown").
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
soundevent.data.Term
|
||||||
|
The newly created and registered soundevent.data.Term object.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
KeyError
|
||||||
|
If a term with the provided key already exists.
|
||||||
|
"""
|
||||||
|
term = data.Term(
|
||||||
|
name=name or key,
|
||||||
|
label=label or key,
|
||||||
|
uri=uri,
|
||||||
|
definition=definition or "Unknown",
|
||||||
|
)
|
||||||
|
self.add_term(key, term)
|
||||||
|
return term
|
||||||
|
|
||||||
|
def get_keys(self) -> List[str]:
|
||||||
|
"""Returns a list of all keys currently registered.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[str]
|
||||||
|
A list of strings representing the keys of all registered terms.
|
||||||
|
"""
|
||||||
|
return list(self._terms.keys())
|
||||||
|
|
||||||
|
def get_terms(self) -> List[data.Term]:
|
||||||
|
"""Returns a list of all registered terms.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[soundevent.data.Term]
|
||||||
|
A list containing all registered Term objects.
|
||||||
|
"""
|
||||||
|
return list(self._terms.values())
|
||||||
|
|
||||||
|
def remove_key(self, key: str) -> None:
|
||||||
|
del self._terms[key]
|
||||||
|
|
||||||
|
|
||||||
|
default_term_registry = TermRegistry(
|
||||||
|
terms=dict(
|
||||||
|
[
|
||||||
|
*getmembers(terms, lambda x: isinstance(x, data.Term)),
|
||||||
|
("event", call_type),
|
||||||
|
("species", terms.scientific_name),
|
||||||
|
("individual", individual),
|
||||||
|
("data_source", data_source),
|
||||||
|
(GENERIC_CLASS_KEY, generic_class),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
"""The default, globally accessible TermRegistry instance.
|
||||||
|
|
||||||
|
It is pre-populated with standard terms from `soundevent.terms` and common
|
||||||
|
terms defined in this module (`call_type`, `individual`, `generic_class`).
|
||||||
|
Functions in this module use this registry by default unless another instance
|
||||||
|
is explicitly passed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_term_from_key(
|
||||||
|
key: str,
|
||||||
|
term_registry: Optional[TermRegistry] = None,
|
||||||
|
) -> data.Term:
|
||||||
|
"""Convenience function to retrieve a term by key from a registry.
|
||||||
|
|
||||||
|
Uses the global default registry unless a specific `term_registry`
|
||||||
|
instance is provided.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
key : str
|
||||||
|
The unique key of the term to retrieve.
|
||||||
|
term_registry : TermRegistry, optional
|
||||||
|
The TermRegistry instance to search in. Defaults to the global
|
||||||
|
`registry`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
soundevent.data.Term
|
||||||
|
The corresponding soundevent.data.Term object.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
KeyError
|
||||||
|
If the key is not found in the specified registry.
|
||||||
|
"""
|
||||||
|
term = terms.get_term(key)
|
||||||
|
|
||||||
|
if term:
|
||||||
|
return term
|
||||||
|
|
||||||
|
term_registry = term_registry or default_term_registry
|
||||||
|
return term_registry.get_term(key)
|
||||||
|
|
||||||
|
|
||||||
|
def get_term_keys(
|
||||||
|
term_registry: TermRegistry = default_term_registry,
|
||||||
|
) -> List[str]:
|
||||||
|
"""Convenience function to get all registered keys from a registry.
|
||||||
|
|
||||||
|
Uses the global default registry unless a specific `term_registry`
|
||||||
|
instance is provided.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
term_registry : TermRegistry, optional
|
||||||
|
The TermRegistry instance to query. Defaults to the global `registry`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[str]
|
||||||
|
A list of strings representing the keys of all registered terms.
|
||||||
|
"""
|
||||||
|
return term_registry.get_keys()
|
||||||
|
|
||||||
|
|
||||||
|
def get_terms(
|
||||||
|
term_registry: TermRegistry = default_term_registry,
|
||||||
|
) -> List[data.Term]:
|
||||||
|
"""Convenience function to get all registered terms from a registry.
|
||||||
|
|
||||||
|
Uses the global default registry unless a specific `term_registry`
|
||||||
|
instance is provided.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
term_registry : TermRegistry, optional
|
||||||
|
The TermRegistry instance to query. Defaults to the global `registry`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[soundevent.data.Term]
|
||||||
|
A list containing all registered Term objects.
|
||||||
|
"""
|
||||||
|
return term_registry.get_terms()
|
||||||
|
|
||||||
|
|
||||||
|
class TagInfo(BaseModel):
|
||||||
|
"""Represents information needed to define a specific Tag.
|
||||||
|
|
||||||
|
This model is typically used in configuration files (e.g., YAML) to
|
||||||
|
specify tags used for filtering, target class definition, or associating
|
||||||
|
tags with output classes. It links a tag value to a term definition
|
||||||
|
via the term's registry key.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
value : str
|
||||||
|
The value of the tag (e.g., "Myotis myotis", "Echolocation").
|
||||||
|
key : str, default="class"
|
||||||
|
The key (alias) of the term associated with this tag, as
|
||||||
|
registered in the TermRegistry. Defaults to "class", implying
|
||||||
|
it represents a classification target label by default.
|
||||||
|
"""
|
||||||
|
|
||||||
|
value: str
|
||||||
|
key: str = GENERIC_CLASS_KEY
|
||||||
|
|
||||||
|
|
||||||
|
def get_tag_from_info(
|
||||||
|
tag_info: TagInfo,
|
||||||
|
term_registry: Optional[TermRegistry] = None,
|
||||||
|
) -> data.Tag:
|
||||||
|
"""Creates a soundevent.data.Tag object from TagInfo data.
|
||||||
|
|
||||||
|
Looks up the term using the key in the provided `tag_info` from the
|
||||||
|
specified registry and constructs a Tag object.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tag_info : TagInfo
|
||||||
|
The TagInfo object containing the value and term key.
|
||||||
|
term_registry : TermRegistry, optional
|
||||||
|
The TermRegistry instance to use for term lookup. Defaults to the
|
||||||
|
global `registry`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
soundevent.data.Tag
|
||||||
|
A soundevent.data.Tag object corresponding to the input info.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
KeyError
|
||||||
|
If the term key specified in `tag_info.key` is not found
|
||||||
|
in the registry.
|
||||||
|
"""
|
||||||
|
term_registry = term_registry or default_term_registry
|
||||||
|
term = get_term_from_key(tag_info.key, term_registry=term_registry)
|
||||||
|
return data.Tag(term=term, value=tag_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TermInfo(BaseModel):
|
||||||
|
"""Represents the definition of a Term within a configuration file.
|
||||||
|
|
||||||
|
This model allows users to define custom terms directly in configuration
|
||||||
|
files (e.g., YAML) which can then be loaded into the TermRegistry.
|
||||||
|
It mirrors the parameters of `TermRegistry.add_custom_term`.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
key : str
|
||||||
|
The unique key (alias) that will be used to register and
|
||||||
|
reference this term.
|
||||||
|
label : str, optional
|
||||||
|
The optional display label for the term. Defaults to `key`
|
||||||
|
if not provided during registration.
|
||||||
|
name : str, optional
|
||||||
|
The optional formal name for the term. Defaults to `key`
|
||||||
|
if not provided during registration.
|
||||||
|
uri : str, optional
|
||||||
|
The optional URI identifying the term (e.g., from a standard
|
||||||
|
vocabulary).
|
||||||
|
definition : str, optional
|
||||||
|
The optional textual definition of the term. Defaults to
|
||||||
|
"Unknown" if not provided during registration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
key: str
|
||||||
|
label: Optional[str] = None
|
||||||
|
name: Optional[str] = None
|
||||||
|
uri: Optional[str] = None
|
||||||
|
definition: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TermConfig(BaseModel):
|
||||||
|
"""Pydantic schema for loading a list of term definitions from config.
|
||||||
|
|
||||||
|
This model typically corresponds to a section in a configuration file
|
||||||
|
(e.g., YAML) containing a list of term definitions to be registered.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
terms : list[TermInfo]
|
||||||
|
A list of TermInfo objects, each defining a term to be
|
||||||
|
registered. Defaults to an empty list.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
Example YAML structure:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
terms:
|
||||||
|
- key: species
|
||||||
|
uri: dwc:scientificName
|
||||||
|
label: Scientific Name
|
||||||
|
- key: my_custom_term
|
||||||
|
name: My Custom Term
|
||||||
|
definition: Describes a specific project attribute.
|
||||||
|
# ... more TermInfo definitions
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
terms: List[TermInfo] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def load_terms_from_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
term_registry: TermRegistry = default_term_registry,
|
||||||
|
) -> Dict[str, data.Term]:
|
||||||
|
"""Loads term definitions from a configuration file and registers them.
|
||||||
|
|
||||||
|
Parses a configuration file (e.g., YAML) using the TermConfig schema,
|
||||||
|
extracts the list of TermInfo definitions, and adds each one as a
|
||||||
|
custom term to the specified TermRegistry instance.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : data.PathLike
|
||||||
|
The path to the configuration file.
|
||||||
|
field : str, optional
|
||||||
|
Optional key indicating a specific section within the config
|
||||||
|
file where the 'terms' list is located. If None, expects the
|
||||||
|
list directly at the top level or within a structure matching
|
||||||
|
TermConfig schema.
|
||||||
|
term_registry : TermRegistry, optional
|
||||||
|
The TermRegistry instance to add the loaded terms to. Defaults to
|
||||||
|
the global `registry`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict[str, soundevent.data.Term]
|
||||||
|
A dictionary mapping the keys of the newly added terms to their
|
||||||
|
corresponding Term objects.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError
|
||||||
|
If the config file path does not exist.
|
||||||
|
pydantic.ValidationError
|
||||||
|
If the config file structure does not match the TermConfig schema.
|
||||||
|
KeyError
|
||||||
|
If a term key loaded from the config conflicts with a key
|
||||||
|
already present in the registry.
|
||||||
|
"""
|
||||||
|
data = load_config(path, schema=TermConfig, field=field)
|
||||||
|
return {
|
||||||
|
info.key: term_registry.add_custom_term(
|
||||||
|
info.key,
|
||||||
|
name=info.name,
|
||||||
|
uri=info.uri,
|
||||||
|
label=info.label,
|
||||||
|
definition=info.definition,
|
||||||
|
)
|
||||||
|
for info in data.terms
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def register_term(
|
||||||
|
key: str, term: data.Term, registry: TermRegistry = default_term_registry
|
||||||
|
) -> None:
|
||||||
|
registry.add_term(key, term)
|
||||||
|
|||||||
708
src/batdetect2/targets/transform.py
Normal file
708
src/batdetect2/targets/transform.py
Normal file
@ -0,0 +1,708 @@
|
|||||||
|
import importlib
|
||||||
|
from functools import partial
|
||||||
|
from typing import (
|
||||||
|
Annotated,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.targets.terms import (
|
||||||
|
TagInfo,
|
||||||
|
TermRegistry,
|
||||||
|
get_tag_from_info,
|
||||||
|
get_term_from_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DerivationRegistry",
|
||||||
|
"DeriveTagRule",
|
||||||
|
"MapValueRule",
|
||||||
|
"ReplaceRule",
|
||||||
|
"SoundEventTransformation",
|
||||||
|
"TransformConfig",
|
||||||
|
"build_transform_from_rule",
|
||||||
|
"build_transformation_from_config",
|
||||||
|
"default_derivation_registry",
|
||||||
|
"get_derivation",
|
||||||
|
"load_transformation_config",
|
||||||
|
"load_transformation_from_config",
|
||||||
|
"register_derivation",
|
||||||
|
]
|
||||||
|
|
||||||
|
SoundEventTransformation = Callable[
|
||||||
|
[data.SoundEventAnnotation], data.SoundEventAnnotation
|
||||||
|
]
|
||||||
|
"""Type alias for a sound event transformation function.
|
||||||
|
|
||||||
|
A function that accepts a sound event annotation object and returns a
|
||||||
|
(potentially) modified sound event annotation object. Transformations
|
||||||
|
should generally return a copy of the annotation rather than modifying
|
||||||
|
it in place.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
Derivation = Callable[[str], str]
|
||||||
|
"""Type alias for a derivation function.
|
||||||
|
|
||||||
|
A function that accepts a single string (typically a tag value) and returns
|
||||||
|
a new string (the derived value).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class MapValueRule(BaseConfig):
|
||||||
|
"""Configuration for mapping specific values of a source term.
|
||||||
|
|
||||||
|
This rule replaces tags matching a specific term and one of the
|
||||||
|
original values with a new tag (potentially having a different term)
|
||||||
|
containing the corresponding replacement value. Useful for standardizing
|
||||||
|
or grouping tag values.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
rule_type : Literal["map_value"]
|
||||||
|
Discriminator field identifying this rule type.
|
||||||
|
source_term_key : str
|
||||||
|
The key (registered in `TermRegistry`) of the term whose tags' values
|
||||||
|
should be checked against the `value_mapping`.
|
||||||
|
value_mapping : Dict[str, str]
|
||||||
|
A dictionary mapping original string values to replacement string
|
||||||
|
values. Only tags whose value is a key in this dictionary will be
|
||||||
|
affected.
|
||||||
|
target_term_key : str, optional
|
||||||
|
The key (registered in `TermRegistry`) for the term of the *output*
|
||||||
|
tag. If None (default), the output tag uses the same term as the
|
||||||
|
source (`source_term_key`). If provided, the term of the affected
|
||||||
|
tag is changed to this target term upon replacement.
|
||||||
|
"""
|
||||||
|
|
||||||
|
rule_type: Literal["map_value"] = "map_value"
|
||||||
|
source_term_key: str
|
||||||
|
value_mapping: Dict[str, str]
|
||||||
|
target_term_key: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
def map_value_transform(
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
source_term: data.Term,
|
||||||
|
target_term: data.Term,
|
||||||
|
mapping: Dict[str, str],
|
||||||
|
) -> data.SoundEventAnnotation:
|
||||||
|
"""Apply a value mapping transformation to an annotation's tags.
|
||||||
|
|
||||||
|
Iterates through the annotation's tags. If a tag matches the `source_term`
|
||||||
|
and its value is found in the `mapping`, it is replaced by a new tag with
|
||||||
|
the `target_term` and the mapped value. Other tags are kept unchanged.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event_annotation : data.SoundEventAnnotation
|
||||||
|
The annotation to transform.
|
||||||
|
source_term : data.Term
|
||||||
|
The term of tags whose values should be mapped.
|
||||||
|
target_term : data.Term
|
||||||
|
The term to use for the newly created tags after mapping.
|
||||||
|
mapping : Dict[str, str]
|
||||||
|
The dictionary mapping original values to new values.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
data.SoundEventAnnotation
|
||||||
|
A new annotation object with the transformed tags.
|
||||||
|
"""
|
||||||
|
tags = []
|
||||||
|
|
||||||
|
for tag in sound_event_annotation.tags:
|
||||||
|
if tag.term != source_term or tag.value not in mapping:
|
||||||
|
tags.append(tag)
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_value = mapping[tag.value]
|
||||||
|
tags.append(data.Tag(term=target_term, value=new_value))
|
||||||
|
|
||||||
|
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
||||||
|
|
||||||
|
|
||||||
|
class DeriveTagRule(BaseConfig):
|
||||||
|
"""Configuration for deriving a new tag from an existing tag's value.
|
||||||
|
|
||||||
|
This rule applies a specified function (`derivation_function`) to the
|
||||||
|
value of tags matching the `source_term_key`. It then adds a new tag
|
||||||
|
with the `target_term_key` and the derived value.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
rule_type : Literal["derive_tag"]
|
||||||
|
Discriminator field identifying this rule type.
|
||||||
|
source_term_key : str
|
||||||
|
The key (registered in `TermRegistry`) of the term whose tag values
|
||||||
|
will be used as input to the derivation function.
|
||||||
|
derivation_function : str
|
||||||
|
The name/key identifying the derivation function to use. This can be
|
||||||
|
a key registered in the `DerivationRegistry` or, if
|
||||||
|
`import_derivation` is True, a full Python path like
|
||||||
|
`'my_module.my_submodule.my_function'`.
|
||||||
|
target_term_key : str, optional
|
||||||
|
The key (registered in `TermRegistry`) for the term of the new tag
|
||||||
|
that will be created with the derived value. If None (default), the
|
||||||
|
derived tag uses the same term as the source (`source_term_key`),
|
||||||
|
effectively performing an in-place value transformation.
|
||||||
|
import_derivation : bool, default=False
|
||||||
|
If True, treat `derivation_function` as a Python import path and
|
||||||
|
attempt to dynamically import it if not found in the registry.
|
||||||
|
Requires the function to be accessible in the Python environment.
|
||||||
|
keep_source : bool, default=True
|
||||||
|
If True, the original source tag (whose value was used for derivation)
|
||||||
|
is kept in the annotation's tag list alongside the newly derived tag.
|
||||||
|
If False, the original source tag is removed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
rule_type: Literal["derive_tag"] = "derive_tag"
|
||||||
|
source_term_key: str
|
||||||
|
derivation_function: str
|
||||||
|
target_term_key: Optional[str] = None
|
||||||
|
import_derivation: bool = False
|
||||||
|
keep_source: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
def derivation_tag_transform(
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
source_term: data.Term,
|
||||||
|
target_term: data.Term,
|
||||||
|
derivation: Derivation,
|
||||||
|
keep_source: bool = True,
|
||||||
|
) -> data.SoundEventAnnotation:
|
||||||
|
"""Apply a derivation transformation to an annotation's tags.
|
||||||
|
|
||||||
|
Iterates through the annotation's tags. For each tag matching the
|
||||||
|
`source_term`, its value is passed to the `derivation` function.
|
||||||
|
A new tag is created with the `target_term` and the derived value,
|
||||||
|
and added to the output tag list. The original source tag is kept
|
||||||
|
or discarded based on `keep_source`. Other tags are kept unchanged.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event_annotation : data.SoundEventAnnotation
|
||||||
|
The annotation to transform.
|
||||||
|
source_term : data.Term
|
||||||
|
The term of tags whose values serve as input for the derivation.
|
||||||
|
target_term : data.Term
|
||||||
|
The term to use for the newly created derived tags.
|
||||||
|
derivation : Derivation
|
||||||
|
The function to apply to the source tag's value.
|
||||||
|
keep_source : bool, default=True
|
||||||
|
Whether to keep the original source tag in the output.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
data.SoundEventAnnotation
|
||||||
|
A new annotation object with the transformed tags (including derived
|
||||||
|
ones).
|
||||||
|
"""
|
||||||
|
tags = []
|
||||||
|
|
||||||
|
for tag in sound_event_annotation.tags:
|
||||||
|
if tag.term != source_term:
|
||||||
|
tags.append(tag)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if keep_source:
|
||||||
|
tags.append(tag)
|
||||||
|
|
||||||
|
new_value = derivation(tag.value)
|
||||||
|
tags.append(data.Tag(term=target_term, value=new_value))
|
||||||
|
|
||||||
|
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
||||||
|
|
||||||
|
|
||||||
|
class ReplaceRule(BaseConfig):
|
||||||
|
"""Configuration for exactly replacing one specific tag with another.
|
||||||
|
|
||||||
|
This rule looks for an exact match of the `original` tag (both term and
|
||||||
|
value) and replaces it with the specified `replacement` tag.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
rule_type : Literal["replace"]
|
||||||
|
Discriminator field identifying this rule type.
|
||||||
|
original : TagInfo
|
||||||
|
The exact tag to search for, defined using its value and term key.
|
||||||
|
replacement : TagInfo
|
||||||
|
The tag to substitute in place of the original tag, defined using
|
||||||
|
its value and term key.
|
||||||
|
"""
|
||||||
|
|
||||||
|
rule_type: Literal["replace"] = "replace"
|
||||||
|
original: TagInfo
|
||||||
|
replacement: TagInfo
|
||||||
|
|
||||||
|
|
||||||
|
def replace_tag_transform(
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
source: data.Tag,
|
||||||
|
target: data.Tag,
|
||||||
|
) -> data.SoundEventAnnotation:
|
||||||
|
"""Apply an exact tag replacement transformation.
|
||||||
|
|
||||||
|
Iterates through the annotation's tags. If a tag exactly matches the
|
||||||
|
`source` tag, it is replaced by the `target` tag. Other tags are kept
|
||||||
|
unchanged.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event_annotation : data.SoundEventAnnotation
|
||||||
|
The annotation to transform.
|
||||||
|
source : data.Tag
|
||||||
|
The exact tag to find and replace.
|
||||||
|
target : data.Tag
|
||||||
|
The tag to replace the source tag with.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
data.SoundEventAnnotation
|
||||||
|
A new annotation object with the replaced tag (if found).
|
||||||
|
"""
|
||||||
|
tags = []
|
||||||
|
|
||||||
|
for tag in sound_event_annotation.tags:
|
||||||
|
if tag == source:
|
||||||
|
tags.append(target)
|
||||||
|
else:
|
||||||
|
tags.append(tag)
|
||||||
|
|
||||||
|
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformConfig(BaseConfig):
|
||||||
|
"""Configuration model for defining a sequence of transformation rules.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
rules : List[Union[ReplaceRule, MapValueRule, DeriveTagRule]]
|
||||||
|
A list of transformation rules to apply. The rules are applied
|
||||||
|
sequentially in the order they appear in the list. The output of
|
||||||
|
one rule becomes the input for the next. The `rule_type` field
|
||||||
|
discriminates between the different rule models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
rules: List[
|
||||||
|
Annotated[
|
||||||
|
Union[ReplaceRule, MapValueRule, DeriveTagRule],
|
||||||
|
Field(discriminator="rule_type"),
|
||||||
|
]
|
||||||
|
] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DerivationRegistry(Mapping[str, Derivation]):
|
||||||
|
"""A registry for managing named derivation functions.
|
||||||
|
|
||||||
|
Derivation functions are callables that take a string value and return
|
||||||
|
a transformed string value, used by `DeriveTagRule`. This registry
|
||||||
|
allows functions to be registered with a key and retrieved later.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize an empty DerivationRegistry."""
|
||||||
|
self._derivations: Dict[str, Derivation] = {}
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> Derivation:
|
||||||
|
"""Retrieve a derivation function by key."""
|
||||||
|
return self._derivations[key]
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Return the number of registered derivation functions."""
|
||||||
|
return len(self._derivations)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
"""Return an iterator over the keys of registered functions."""
|
||||||
|
return iter(self._derivations)
|
||||||
|
|
||||||
|
def register(self, key: str, derivation: Derivation) -> None:
|
||||||
|
"""Register a derivation function with a unique key.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
key : str
|
||||||
|
The unique key to associate with the derivation function.
|
||||||
|
derivation : Derivation
|
||||||
|
The callable derivation function (takes str, returns str).
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
KeyError
|
||||||
|
If a derivation function with the same key is already registered.
|
||||||
|
"""
|
||||||
|
if key in self._derivations:
|
||||||
|
raise KeyError(
|
||||||
|
f"A derivation with the provided key {key} already exists"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._derivations[key] = derivation
|
||||||
|
|
||||||
|
def get_derivation(self, key: str) -> Derivation:
|
||||||
|
"""Retrieve a derivation function by its registered key.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
key : str
|
||||||
|
The key of the derivation function to retrieve.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Derivation
|
||||||
|
The requested derivation function.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
KeyError
|
||||||
|
If no derivation function with the specified key is registered.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return self._derivations[key]
|
||||||
|
except KeyError as err:
|
||||||
|
raise KeyError(
|
||||||
|
f"No derivation with key {key} is registered."
|
||||||
|
) from err
|
||||||
|
|
||||||
|
def get_keys(self) -> List[str]:
|
||||||
|
"""Get a list of all registered derivation function keys.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[str]
|
||||||
|
The keys of all registered functions.
|
||||||
|
"""
|
||||||
|
return list(self._derivations.keys())
|
||||||
|
|
||||||
|
def get_derivations(self) -> List[Derivation]:
|
||||||
|
"""Get a list of all registered derivation functions.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[Derivation]
|
||||||
|
The registered derivation function objects.
|
||||||
|
"""
|
||||||
|
return list(self._derivations.values())
|
||||||
|
|
||||||
|
|
||||||
|
default_derivation_registry = DerivationRegistry()
|
||||||
|
"""Global instance of the DerivationRegistry.
|
||||||
|
|
||||||
|
Register custom derivation functions here to make them available by key
|
||||||
|
in `DeriveTagRule` configuration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_derivation(
|
||||||
|
key: str,
|
||||||
|
import_derivation: bool = False,
|
||||||
|
registry: Optional[DerivationRegistry] = None,
|
||||||
|
):
|
||||||
|
"""Retrieve a derivation function by key, optionally importing it.
|
||||||
|
|
||||||
|
First attempts to find the function in the provided `registry`.
|
||||||
|
If not found and `import_derivation` is True, attempts to dynamically
|
||||||
|
import the function using the `key` as a full Python path
|
||||||
|
(e.g., 'my_module.submodule.my_func').
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
key : str
|
||||||
|
The key or Python path of the derivation function.
|
||||||
|
import_derivation : bool, default=False
|
||||||
|
If True, attempt dynamic import if key is not in the registry.
|
||||||
|
registry : DerivationRegistry, optional
|
||||||
|
The registry instance to check first. Defaults to the global
|
||||||
|
`derivation_registry`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Derivation
|
||||||
|
The requested derivation function.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
KeyError
|
||||||
|
If the key is not found in the registry and either
|
||||||
|
`import_derivation` is False or the dynamic import fails.
|
||||||
|
ImportError
|
||||||
|
If dynamic import fails specifically due to module not found.
|
||||||
|
AttributeError
|
||||||
|
If dynamic import fails because the function name isn't in the module.
|
||||||
|
"""
|
||||||
|
registry = registry or default_derivation_registry
|
||||||
|
|
||||||
|
if not import_derivation or key in registry:
|
||||||
|
return registry.get_derivation(key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
module_path, func_name = key.rsplit(".", 1)
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
func = getattr(module, func_name)
|
||||||
|
return func
|
||||||
|
except ImportError as err:
|
||||||
|
raise KeyError(
|
||||||
|
f"Unable to load derivation '{key}'. Check the path and ensure "
|
||||||
|
"it points to a valid callable function in an importable module."
|
||||||
|
) from err
|
||||||
|
|
||||||
|
|
||||||
|
TranformationRule = Annotated[
|
||||||
|
Union[ReplaceRule, MapValueRule, DeriveTagRule],
|
||||||
|
Field(discriminator="rule_type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_transform_from_rule(
|
||||||
|
rule: TranformationRule,
|
||||||
|
derivation_registry: Optional[DerivationRegistry] = None,
|
||||||
|
term_registry: Optional[TermRegistry] = None,
|
||||||
|
) -> SoundEventTransformation:
|
||||||
|
"""Build a specific SoundEventTransformation function from a rule config.
|
||||||
|
|
||||||
|
Selects the appropriate transformation logic based on the rule's
|
||||||
|
`rule_type`, fetches necessary terms and derivation functions, and
|
||||||
|
returns a partially applied function ready to transform an annotation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
rule : Union[ReplaceRule, MapValueRule, DeriveTagRule]
|
||||||
|
The configuration object for a single transformation rule.
|
||||||
|
registry : DerivationRegistry, optional
|
||||||
|
The derivation registry to use for `DeriveTagRule`. Defaults to the
|
||||||
|
global `derivation_registry`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
SoundEventTransformation
|
||||||
|
A callable that applies the specified rule to a SoundEventAnnotation.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
KeyError
|
||||||
|
If required term keys or derivation keys are not found.
|
||||||
|
ValueError
|
||||||
|
If the rule has an unknown `rule_type`.
|
||||||
|
ImportError, AttributeError, TypeError
|
||||||
|
If dynamic import of a derivation function fails.
|
||||||
|
"""
|
||||||
|
if rule.rule_type == "replace":
|
||||||
|
source = get_tag_from_info(
|
||||||
|
rule.original,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
target = get_tag_from_info(
|
||||||
|
rule.replacement,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
return partial(replace_tag_transform, source=source, target=target)
|
||||||
|
|
||||||
|
if rule.rule_type == "derive_tag":
|
||||||
|
source_term = get_term_from_key(
|
||||||
|
rule.source_term_key,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
target_term = (
|
||||||
|
get_term_from_key(
|
||||||
|
rule.target_term_key,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
if rule.target_term_key
|
||||||
|
else source_term
|
||||||
|
)
|
||||||
|
derivation = get_derivation(
|
||||||
|
key=rule.derivation_function,
|
||||||
|
import_derivation=rule.import_derivation,
|
||||||
|
registry=derivation_registry,
|
||||||
|
)
|
||||||
|
return partial(
|
||||||
|
derivation_tag_transform,
|
||||||
|
source_term=source_term,
|
||||||
|
target_term=target_term,
|
||||||
|
derivation=derivation,
|
||||||
|
keep_source=rule.keep_source,
|
||||||
|
)
|
||||||
|
|
||||||
|
if rule.rule_type == "map_value":
|
||||||
|
source_term = get_term_from_key(
|
||||||
|
rule.source_term_key,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
target_term = (
|
||||||
|
get_term_from_key(
|
||||||
|
rule.target_term_key,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
if rule.target_term_key
|
||||||
|
else source_term
|
||||||
|
)
|
||||||
|
return partial(
|
||||||
|
map_value_transform,
|
||||||
|
source_term=source_term,
|
||||||
|
target_term=target_term,
|
||||||
|
mapping=rule.value_mapping,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle unknown rule type
|
||||||
|
valid_options = ["replace", "derive_tag", "map_value"]
|
||||||
|
# Should be caught by Pydantic validation, but good practice
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid transform rule type '{getattr(rule, 'rule_type', 'N/A')}'. "
|
||||||
|
f"Valid options are: {valid_options}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_transformation_from_config(
|
||||||
|
config: TransformConfig,
|
||||||
|
derivation_registry: Optional[DerivationRegistry] = None,
|
||||||
|
term_registry: Optional[TermRegistry] = None,
|
||||||
|
) -> SoundEventTransformation:
|
||||||
|
"""Build a composite transformation function from a TransformConfig.
|
||||||
|
|
||||||
|
Creates a sequence of individual transformation functions based on the
|
||||||
|
rules defined in the configuration. Returns a single function that
|
||||||
|
applies these transformations sequentially to an annotation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : TransformConfig
|
||||||
|
The configuration object containing the list of transformation rules.
|
||||||
|
derivation_reg : DerivationRegistry, optional
|
||||||
|
The derivation registry to use when building `DeriveTagRule`
|
||||||
|
transformations. Defaults to the global `derivation_registry`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
SoundEventTransformation
|
||||||
|
A single function that applies all configured transformations in order.
|
||||||
|
"""
|
||||||
|
|
||||||
|
transforms = [
|
||||||
|
build_transform_from_rule(
|
||||||
|
rule,
|
||||||
|
derivation_registry=derivation_registry,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
for rule in config.rules
|
||||||
|
]
|
||||||
|
|
||||||
|
return partial(apply_sequence_of_transforms, transforms=transforms)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_sequence_of_transforms(
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
transforms: list[SoundEventTransformation],
|
||||||
|
) -> data.SoundEventAnnotation:
|
||||||
|
for transform in transforms:
|
||||||
|
sound_event_annotation = transform(sound_event_annotation)
|
||||||
|
return sound_event_annotation
|
||||||
|
|
||||||
|
|
||||||
|
def load_transformation_config(
|
||||||
|
path: data.PathLike, field: Optional[str] = None
|
||||||
|
) -> TransformConfig:
|
||||||
|
"""Load the transformation configuration from a file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : data.PathLike
|
||||||
|
Path to the configuration file (YAML).
|
||||||
|
field : str, optional
|
||||||
|
If the transformation configuration is nested under a specific key
|
||||||
|
in the file, specify the key here. Defaults to None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
TransformConfig
|
||||||
|
The loaded and validated transformation configuration object.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError
|
||||||
|
If the config file path does not exist.
|
||||||
|
pydantic.ValidationError
|
||||||
|
If the config file structure does not match the TransformConfig schema.
|
||||||
|
"""
|
||||||
|
return load_config(path=path, schema=TransformConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
|
def load_transformation_from_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
derivation_registry: Optional[DerivationRegistry] = None,
|
||||||
|
term_registry: Optional[TermRegistry] = None,
|
||||||
|
) -> SoundEventTransformation:
|
||||||
|
"""Load transformation config from a file and build the final function.
|
||||||
|
|
||||||
|
This is a convenience function that combines loading the configuration
|
||||||
|
and building the final callable transformation function that applies
|
||||||
|
all rules sequentially.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : data.PathLike
|
||||||
|
Path to the configuration file (YAML).
|
||||||
|
field : str, optional
|
||||||
|
If the transformation configuration is nested under a specific key
|
||||||
|
in the file, specify the key here. Defaults to None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
SoundEventTransformation
|
||||||
|
The final composite transformation function ready to be used.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError
|
||||||
|
If the config file path does not exist.
|
||||||
|
pydantic.ValidationError
|
||||||
|
If the config file structure does not match the TransformConfig schema.
|
||||||
|
KeyError
|
||||||
|
If required term keys or derivation keys specified in the config
|
||||||
|
are not found during the build process.
|
||||||
|
ImportError, AttributeError, TypeError
|
||||||
|
If dynamic import of a derivation function specified in the config
|
||||||
|
fails.
|
||||||
|
"""
|
||||||
|
config = load_transformation_config(path=path, field=field)
|
||||||
|
return build_transformation_from_config(
|
||||||
|
config,
|
||||||
|
derivation_registry=derivation_registry,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register_derivation(
|
||||||
|
key: str,
|
||||||
|
derivation: Derivation,
|
||||||
|
derivation_registry: Optional[DerivationRegistry] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Register a new derivation function in the global registry.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
key : str
|
||||||
|
The unique key to associate with the derivation function.
|
||||||
|
derivation : Derivation
|
||||||
|
The callable derivation function (takes str, returns str).
|
||||||
|
derivation_registry : DerivationRegistry, optional
|
||||||
|
The registry instance to register the derivation function with.
|
||||||
|
Defaults to the global `derivation_registry`.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
KeyError
|
||||||
|
If a derivation function with the same key is already registered.
|
||||||
|
"""
|
||||||
|
derivation_registry = derivation_registry or default_derivation_registry
|
||||||
|
derivation_registry.register(key, derivation)
|
||||||
@ -33,6 +33,10 @@ from batdetect2.train.losses import (
|
|||||||
SizeLossConfig,
|
SizeLossConfig,
|
||||||
build_loss,
|
build_loss,
|
||||||
)
|
)
|
||||||
|
from batdetect2.train.preprocess import (
|
||||||
|
generate_train_example,
|
||||||
|
preprocess_annotations,
|
||||||
|
)
|
||||||
from batdetect2.train.train import (
|
from batdetect2.train.train import (
|
||||||
build_train_dataset,
|
build_train_dataset,
|
||||||
build_train_loader,
|
build_train_loader,
|
||||||
@ -70,12 +74,14 @@ __all__ = [
|
|||||||
"build_trainer",
|
"build_trainer",
|
||||||
"build_val_dataset",
|
"build_val_dataset",
|
||||||
"build_val_loader",
|
"build_val_loader",
|
||||||
|
"generate_train_example",
|
||||||
"load_full_training_config",
|
"load_full_training_config",
|
||||||
"load_label_config",
|
"load_label_config",
|
||||||
"load_train_config",
|
"load_train_config",
|
||||||
"mask_frequency",
|
"mask_frequency",
|
||||||
"mask_time",
|
"mask_time",
|
||||||
"mix_audio",
|
"mix_audio",
|
||||||
|
"preprocess_annotations",
|
||||||
"scale_volume",
|
"scale_volume",
|
||||||
"select_subclip",
|
"select_subclip",
|
||||||
"train",
|
"train",
|
||||||
|
|||||||
@ -44,7 +44,7 @@ AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]
|
|||||||
class MixAugmentationConfig(BaseConfig):
|
class MixAugmentationConfig(BaseConfig):
|
||||||
"""Configuration for MixUp augmentation (mixing two examples)."""
|
"""Configuration for MixUp augmentation (mixing two examples)."""
|
||||||
|
|
||||||
name: Literal["mix_audio"] = "mix_audio"
|
augmentation_type: Literal["mix_audio"] = "mix_audio"
|
||||||
|
|
||||||
probability: float = 0.2
|
probability: float = 0.2
|
||||||
"""Probability of applying this augmentation to an example."""
|
"""Probability of applying this augmentation to an example."""
|
||||||
@ -140,7 +140,7 @@ def combine_clip_annotations(
|
|||||||
class EchoAugmentationConfig(BaseConfig):
|
class EchoAugmentationConfig(BaseConfig):
|
||||||
"""Configuration for adding synthetic echo/reverb."""
|
"""Configuration for adding synthetic echo/reverb."""
|
||||||
|
|
||||||
name: Literal["add_echo"] = "add_echo"
|
augmentation_type: Literal["add_echo"] = "add_echo"
|
||||||
probability: float = 0.2
|
probability: float = 0.2
|
||||||
max_delay: float = 0.005
|
max_delay: float = 0.005
|
||||||
min_weight: float = 0.0
|
min_weight: float = 0.0
|
||||||
@ -187,7 +187,7 @@ def add_echo(
|
|||||||
class VolumeAugmentationConfig(BaseConfig):
|
class VolumeAugmentationConfig(BaseConfig):
|
||||||
"""Configuration for random volume scaling of the spectrogram."""
|
"""Configuration for random volume scaling of the spectrogram."""
|
||||||
|
|
||||||
name: Literal["scale_volume"] = "scale_volume"
|
augmentation_type: Literal["scale_volume"] = "scale_volume"
|
||||||
probability: float = 0.2
|
probability: float = 0.2
|
||||||
min_scaling: float = 0.0
|
min_scaling: float = 0.0
|
||||||
max_scaling: float = 2.0
|
max_scaling: float = 2.0
|
||||||
@ -214,7 +214,7 @@ def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor:
|
|||||||
|
|
||||||
|
|
||||||
class WarpAugmentationConfig(BaseConfig):
|
class WarpAugmentationConfig(BaseConfig):
|
||||||
name: Literal["warp"] = "warp"
|
augmentation_type: Literal["warp"] = "warp"
|
||||||
probability: float = 0.2
|
probability: float = 0.2
|
||||||
delta: float = 0.04
|
delta: float = 0.04
|
||||||
|
|
||||||
@ -296,7 +296,7 @@ def warp_spectrogram(
|
|||||||
|
|
||||||
|
|
||||||
class TimeMaskAugmentationConfig(BaseConfig):
|
class TimeMaskAugmentationConfig(BaseConfig):
|
||||||
name: Literal["mask_time"] = "mask_time"
|
augmentation_type: Literal["mask_time"] = "mask_time"
|
||||||
probability: float = 0.2
|
probability: float = 0.2
|
||||||
max_perc: float = 0.05
|
max_perc: float = 0.05
|
||||||
max_masks: int = 3
|
max_masks: int = 3
|
||||||
@ -353,7 +353,7 @@ def mask_time(
|
|||||||
|
|
||||||
|
|
||||||
class FrequencyMaskAugmentationConfig(BaseConfig):
|
class FrequencyMaskAugmentationConfig(BaseConfig):
|
||||||
name: Literal["mask_freq"] = "mask_freq"
|
augmentation_type: Literal["mask_freq"] = "mask_freq"
|
||||||
probability: float = 0.2
|
probability: float = 0.2
|
||||||
max_perc: float = 0.10
|
max_perc: float = 0.10
|
||||||
max_masks: int = 3
|
max_masks: int = 3
|
||||||
@ -414,7 +414,7 @@ AudioAugmentationConfig = Annotated[
|
|||||||
MixAugmentationConfig,
|
MixAugmentationConfig,
|
||||||
EchoAugmentationConfig,
|
EchoAugmentationConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="augmentation_type"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -425,7 +425,7 @@ SpectrogramAugmentationConfig = Annotated[
|
|||||||
FrequencyMaskAugmentationConfig,
|
FrequencyMaskAugmentationConfig,
|
||||||
TimeMaskAugmentationConfig,
|
TimeMaskAugmentationConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="augmentation_type"),
|
||||||
]
|
]
|
||||||
|
|
||||||
AugmentationConfig = Annotated[
|
AugmentationConfig = Annotated[
|
||||||
@ -437,7 +437,7 @@ AugmentationConfig = Annotated[
|
|||||||
FrequencyMaskAugmentationConfig,
|
FrequencyMaskAugmentationConfig,
|
||||||
TimeMaskAugmentationConfig,
|
TimeMaskAugmentationConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="augmentation_type"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of individual augmentation config."""
|
"""Type alias for the discriminated union of individual augmentation config."""
|
||||||
|
|
||||||
@ -485,7 +485,7 @@ def build_augmentation_from_config(
|
|||||||
audio_source: Optional[AudioSource] = None,
|
audio_source: Optional[AudioSource] = None,
|
||||||
) -> Optional[Augmentation]:
|
) -> Optional[Augmentation]:
|
||||||
"""Factory function to build a single augmentation from its config."""
|
"""Factory function to build a single augmentation from its config."""
|
||||||
if config.name == "mix_audio":
|
if config.augmentation_type == "mix_audio":
|
||||||
if audio_source is None:
|
if audio_source is None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Mix audio augmentation ('mix_audio') requires an "
|
"Mix audio augmentation ('mix_audio') requires an "
|
||||||
@ -500,31 +500,31 @@ def build_augmentation_from_config(
|
|||||||
max_weight=config.max_weight,
|
max_weight=config.max_weight,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.name == "add_echo":
|
if config.augmentation_type == "add_echo":
|
||||||
return AddEcho(
|
return AddEcho(
|
||||||
max_delay=int(config.max_delay * samplerate),
|
max_delay=int(config.max_delay * samplerate),
|
||||||
min_weight=config.min_weight,
|
min_weight=config.min_weight,
|
||||||
max_weight=config.max_weight,
|
max_weight=config.max_weight,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.name == "scale_volume":
|
if config.augmentation_type == "scale_volume":
|
||||||
return ScaleVolume(
|
return ScaleVolume(
|
||||||
max_scaling=config.max_scaling,
|
max_scaling=config.max_scaling,
|
||||||
min_scaling=config.min_scaling,
|
min_scaling=config.min_scaling,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.name == "warp":
|
if config.augmentation_type == "warp":
|
||||||
return WarpSpectrogram(
|
return WarpSpectrogram(
|
||||||
delta=config.delta,
|
delta=config.delta,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.name == "mask_time":
|
if config.augmentation_type == "mask_time":
|
||||||
return MaskTime(
|
return MaskTime(
|
||||||
max_perc=config.max_perc,
|
max_perc=config.max_perc,
|
||||||
max_masks=config.max_masks,
|
max_masks=config.max_masks,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.name == "mask_freq":
|
if config.augmentation_type == "mask_freq":
|
||||||
return MaskFrequency(
|
return MaskFrequency(
|
||||||
max_perc=config.max_perc,
|
max_perc=config.max_perc,
|
||||||
max_masks=config.max_masks,
|
max_masks=config.max_masks,
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from soundevent import data
|
|||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.evaluate import EvaluationConfig
|
from batdetect2.evaluate import EvaluationConfig
|
||||||
from batdetect2.models import ModelConfig
|
from batdetect2.models import ModelConfig
|
||||||
|
from batdetect2.targets import TargetConfig
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
DEFAULT_AUGMENTATION_CONFIG,
|
DEFAULT_AUGMENTATION_CONFIG,
|
||||||
AugmentationsConfig,
|
AugmentationsConfig,
|
||||||
@ -74,6 +75,7 @@ class TrainingConfig(BaseConfig):
|
|||||||
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
||||||
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
||||||
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||||
|
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
labels: LabelConfig = Field(default_factory=LabelConfig)
|
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,14 +1,9 @@
|
|||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import lightning as L
|
import lightning as L
|
||||||
import torch
|
import torch
|
||||||
from soundevent.data import PathLike
|
|
||||||
from torch.optim.adam import Adam
|
from torch.optim.adam import Adam
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
|
||||||
from batdetect2.models import Model, build_model
|
from batdetect2.models import Model
|
||||||
from batdetect2.train.config import FullTrainingConfig
|
|
||||||
from batdetect2.train.losses import build_loss
|
|
||||||
from batdetect2.typing import ModelOutput, TrainExample
|
from batdetect2.typing import ModelOutput, TrainExample
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -21,28 +16,22 @@ class TrainingModule(L.LightningModule):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: FullTrainingConfig,
|
model: Model,
|
||||||
|
loss: torch.nn.Module,
|
||||||
learning_rate: float = 0.001,
|
learning_rate: float = 0.001,
|
||||||
t_max: int = 100,
|
t_max: int = 100,
|
||||||
model: Optional[Model] = None,
|
|
||||||
loss: Optional[torch.nn.Module] = None,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.save_hyperparameters(logger=False)
|
|
||||||
|
|
||||||
self.config = config
|
|
||||||
self.learning_rate = learning_rate
|
self.learning_rate = learning_rate
|
||||||
self.t_max = t_max
|
self.t_max = t_max
|
||||||
|
|
||||||
if loss is None:
|
|
||||||
loss = build_loss(self.config.train.loss)
|
|
||||||
|
|
||||||
if model is None:
|
|
||||||
model = build_model(self.config)
|
|
||||||
|
|
||||||
self.loss = loss
|
self.loss = loss
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.save_hyperparameters(logger=False)
|
||||||
|
|
||||||
|
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||||
|
return self.model(spec)
|
||||||
|
|
||||||
def training_step(self, batch: TrainExample):
|
def training_step(self, batch: TrainExample):
|
||||||
outputs = self.model.detector(batch.spec)
|
outputs = self.model.detector(batch.spec)
|
||||||
@ -70,10 +59,3 @@ class TrainingModule(L.LightningModule):
|
|||||||
optimizer = Adam(self.parameters(), lr=self.learning_rate)
|
optimizer = Adam(self.parameters(), lr=self.learning_rate)
|
||||||
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
|
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
|
||||||
return [optimizer], [scheduler]
|
return [optimizer], [scheduler]
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_checkpoint(
|
|
||||||
path: PathLike,
|
|
||||||
) -> Tuple[Model, FullTrainingConfig]:
|
|
||||||
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
|
||||||
return module.model, module.config
|
|
||||||
|
|||||||
@ -5,11 +5,10 @@ import numpy as np
|
|||||||
from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
|
from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
|
|
||||||
DEFAULT_LOGS_DIR: str = "outputs"
|
DEFAULT_LOGS_DIR: str = "logs"
|
||||||
|
|
||||||
|
|
||||||
class DVCLiveConfig(BaseConfig):
|
class DVCLiveConfig(BaseConfig):
|
||||||
@ -32,7 +31,7 @@ class CSVLoggerConfig(BaseConfig):
|
|||||||
class TensorBoardLoggerConfig(BaseConfig):
|
class TensorBoardLoggerConfig(BaseConfig):
|
||||||
logger_type: Literal["tensorboard"] = "tensorboard"
|
logger_type: Literal["tensorboard"] = "tensorboard"
|
||||||
save_dir: str = DEFAULT_LOGS_DIR
|
save_dir: str = DEFAULT_LOGS_DIR
|
||||||
name: Optional[str] = "logs"
|
name: Optional[str] = "default"
|
||||||
version: Optional[str] = None
|
version: Optional[str] = None
|
||||||
log_graph: bool = False
|
log_graph: bool = False
|
||||||
|
|
||||||
@ -58,10 +57,7 @@ LoggerConfig = Annotated[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def create_dvclive_logger(
|
def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
|
||||||
config: DVCLiveConfig,
|
|
||||||
log_dir: Optional[data.PathLike] = None,
|
|
||||||
) -> Logger:
|
|
||||||
try:
|
try:
|
||||||
from dvclive.lightning import DVCLiveLogger # type: ignore
|
from dvclive.lightning import DVCLiveLogger # type: ignore
|
||||||
except ImportError as error:
|
except ImportError as error:
|
||||||
@ -72,7 +68,7 @@ def create_dvclive_logger(
|
|||||||
) from error
|
) from error
|
||||||
|
|
||||||
return DVCLiveLogger(
|
return DVCLiveLogger(
|
||||||
dir=log_dir if log_dir is not None else config.dir,
|
dir=config.dir,
|
||||||
run_name=config.run_name,
|
run_name=config.run_name,
|
||||||
prefix=config.prefix,
|
prefix=config.prefix,
|
||||||
log_model=config.log_model,
|
log_model=config.log_model,
|
||||||
@ -80,38 +76,29 @@ def create_dvclive_logger(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_csv_logger(
|
def create_csv_logger(config: CSVLoggerConfig) -> Logger:
|
||||||
config: CSVLoggerConfig,
|
|
||||||
log_dir: Optional[data.PathLike] = None,
|
|
||||||
) -> Logger:
|
|
||||||
from lightning.pytorch.loggers import CSVLogger
|
from lightning.pytorch.loggers import CSVLogger
|
||||||
|
|
||||||
return CSVLogger(
|
return CSVLogger(
|
||||||
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
|
save_dir=config.save_dir,
|
||||||
name=config.name,
|
name=config.name,
|
||||||
version=config.version,
|
version=config.version,
|
||||||
flush_logs_every_n_steps=config.flush_logs_every_n_steps,
|
flush_logs_every_n_steps=config.flush_logs_every_n_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_tensorboard_logger(
|
def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
|
||||||
config: TensorBoardLoggerConfig,
|
|
||||||
log_dir: Optional[data.PathLike] = None,
|
|
||||||
) -> Logger:
|
|
||||||
from lightning.pytorch.loggers import TensorBoardLogger
|
from lightning.pytorch.loggers import TensorBoardLogger
|
||||||
|
|
||||||
return TensorBoardLogger(
|
return TensorBoardLogger(
|
||||||
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
|
save_dir=config.save_dir,
|
||||||
name=config.name,
|
name=config.name,
|
||||||
version=config.version,
|
version=config.version,
|
||||||
log_graph=config.log_graph,
|
log_graph=config.log_graph,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_mlflow_logger(
|
def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
|
||||||
config: MLFlowLoggerConfig,
|
|
||||||
log_dir: Optional[data.PathLike] = None,
|
|
||||||
) -> Logger:
|
|
||||||
try:
|
try:
|
||||||
from lightning.pytorch.loggers import MLFlowLogger
|
from lightning.pytorch.loggers import MLFlowLogger
|
||||||
except ImportError as error:
|
except ImportError as error:
|
||||||
@ -124,7 +111,7 @@ def create_mlflow_logger(
|
|||||||
return MLFlowLogger(
|
return MLFlowLogger(
|
||||||
experiment_name=config.experiment_name,
|
experiment_name=config.experiment_name,
|
||||||
run_name=config.run_name,
|
run_name=config.run_name,
|
||||||
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
|
save_dir=config.save_dir,
|
||||||
tracking_uri=config.tracking_uri,
|
tracking_uri=config.tracking_uri,
|
||||||
tags=config.tags,
|
tags=config.tags,
|
||||||
log_model=config.log_model,
|
log_model=config.log_model,
|
||||||
@ -139,10 +126,7 @@ LOGGER_FACTORY = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_logger(
|
def build_logger(config: LoggerConfig) -> Logger:
|
||||||
config: LoggerConfig,
|
|
||||||
log_dir: Optional[data.PathLike] = None,
|
|
||||||
) -> Logger:
|
|
||||||
"""
|
"""
|
||||||
Creates a logger instance from a validated Pydantic config object.
|
Creates a logger instance from a validated Pydantic config object.
|
||||||
"""
|
"""
|
||||||
@ -157,7 +141,7 @@ def build_logger(
|
|||||||
|
|
||||||
creation_func = LOGGER_FACTORY[logger_type]
|
creation_func = LOGGER_FACTORY[logger_type]
|
||||||
|
|
||||||
return creation_func(config, log_dir=log_dir)
|
return creation_func(config)
|
||||||
|
|
||||||
|
|
||||||
def get_image_plotter(logger: Logger):
|
def get_image_plotter(logger: Logger):
|
||||||
|
|||||||
243
src/batdetect2/train/preprocess.py
Normal file
243
src/batdetect2/train/preprocess.py
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
"""Preprocesses datasets for BatDetect2 model training."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, List, Optional, Sequence, TypedDict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.utils.data
|
||||||
|
from loguru import logger
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.data.datasets import Dataset
|
||||||
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
|
from batdetect2.preprocess.audio import build_audio_loader
|
||||||
|
from batdetect2.targets import TargetConfig, build_targets
|
||||||
|
from batdetect2.train.labels import LabelConfig, build_clip_labeler
|
||||||
|
from batdetect2.typing import ClipLabeller, PreprocessorProtocol
|
||||||
|
from batdetect2.typing.preprocess import AudioLoader
|
||||||
|
from batdetect2.typing.train import PreprocessedExample
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"preprocess_annotations",
|
||||||
|
"generate_train_example",
|
||||||
|
"preprocess_dataset",
|
||||||
|
"TrainPreprocessConfig",
|
||||||
|
"load_train_preprocessing_config",
|
||||||
|
"save_preprocessed_example",
|
||||||
|
"load_preprocessed_example",
|
||||||
|
]
|
||||||
|
|
||||||
|
FilenameFn = Callable[[data.ClipAnnotation], str]
|
||||||
|
"""Type alias for a function that generates an output filename."""
|
||||||
|
|
||||||
|
|
||||||
|
class TrainPreprocessConfig(BaseConfig):
|
||||||
|
preprocess: PreprocessingConfig = Field(
|
||||||
|
default_factory=PreprocessingConfig
|
||||||
|
)
|
||||||
|
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
|
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||||
|
|
||||||
|
|
||||||
|
def load_train_preprocessing_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> TrainPreprocessConfig:
|
||||||
|
return load_config(path=path, schema=TrainPreprocessConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_dataset(
|
||||||
|
dataset: Dataset,
|
||||||
|
config: TrainPreprocessConfig,
|
||||||
|
output: Path,
|
||||||
|
max_workers: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
targets = build_targets(config=config.targets)
|
||||||
|
preprocessor = build_preprocessor(config=config.preprocess)
|
||||||
|
labeller = build_clip_labeler(
|
||||||
|
targets,
|
||||||
|
min_freq=preprocessor.min_freq,
|
||||||
|
max_freq=preprocessor.max_freq,
|
||||||
|
config=config.labels,
|
||||||
|
)
|
||||||
|
audio_loader = build_audio_loader(config=config.preprocess.audio)
|
||||||
|
|
||||||
|
if not output.exists():
|
||||||
|
logger.debug("Creating directory {directory}", directory=output)
|
||||||
|
output.mkdir(parents=True)
|
||||||
|
|
||||||
|
preprocess_annotations(
|
||||||
|
dataset,
|
||||||
|
output_dir=output,
|
||||||
|
audio_loader=audio_loader,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
labeller=labeller,
|
||||||
|
max_workers=max_workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Example(TypedDict):
|
||||||
|
audio: torch.Tensor
|
||||||
|
spectrogram: torch.Tensor
|
||||||
|
detection_heatmap: torch.Tensor
|
||||||
|
class_heatmap: torch.Tensor
|
||||||
|
size_heatmap: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def generate_train_example(
|
||||||
|
clip_annotation: data.ClipAnnotation,
|
||||||
|
audio_loader: AudioLoader,
|
||||||
|
preprocessor: PreprocessorProtocol,
|
||||||
|
labeller: ClipLabeller,
|
||||||
|
) -> PreprocessedExample:
|
||||||
|
"""Generate a complete training example for one annotation."""
|
||||||
|
wave = torch.tensor(
|
||||||
|
audio_loader.load_clip(clip_annotation.clip)
|
||||||
|
).unsqueeze(0)
|
||||||
|
spectrogram = preprocessor(wave.unsqueeze(0)).squeeze(0)
|
||||||
|
heatmaps = labeller(clip_annotation, spectrogram)
|
||||||
|
return PreprocessedExample(
|
||||||
|
audio=wave,
|
||||||
|
spectrogram=spectrogram,
|
||||||
|
detection_heatmap=heatmaps.detection,
|
||||||
|
class_heatmap=heatmaps.classes,
|
||||||
|
size_heatmap=heatmaps.size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PreprocessingDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
clips: Dataset,
|
||||||
|
audio_loader: AudioLoader,
|
||||||
|
preprocessor: PreprocessorProtocol,
|
||||||
|
labeller: ClipLabeller,
|
||||||
|
filename_fn: FilenameFn,
|
||||||
|
output_dir: Path,
|
||||||
|
force: bool = False,
|
||||||
|
):
|
||||||
|
self.clips = clips
|
||||||
|
self.audio_loader = audio_loader
|
||||||
|
self.preprocessor = preprocessor
|
||||||
|
self.labeller = labeller
|
||||||
|
self.filename_fn = filename_fn
|
||||||
|
self.output_dir = output_dir
|
||||||
|
self.force = force
|
||||||
|
|
||||||
|
def __getitem__(self, idx) -> int:
|
||||||
|
clip_annotation = self.clips[idx]
|
||||||
|
|
||||||
|
filename = self.filename_fn(clip_annotation)
|
||||||
|
|
||||||
|
path = self.output_dir / filename
|
||||||
|
|
||||||
|
if path.exists() and not self.force:
|
||||||
|
return idx
|
||||||
|
|
||||||
|
if not path.parent.exists():
|
||||||
|
path.parent.mkdir()
|
||||||
|
|
||||||
|
example = generate_train_example(
|
||||||
|
clip_annotation,
|
||||||
|
audio_loader=self.audio_loader,
|
||||||
|
preprocessor=self.preprocessor,
|
||||||
|
labeller=self.labeller,
|
||||||
|
)
|
||||||
|
|
||||||
|
save_preprocessed_example(example, clip_annotation, path)
|
||||||
|
|
||||||
|
return idx
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.clips)
|
||||||
|
|
||||||
|
|
||||||
|
def save_preprocessed_example(
|
||||||
|
example: PreprocessedExample,
|
||||||
|
clip_annotation: data.ClipAnnotation,
|
||||||
|
path: data.PathLike,
|
||||||
|
) -> None:
|
||||||
|
np.savez_compressed(
|
||||||
|
path,
|
||||||
|
audio=example.audio.numpy(),
|
||||||
|
spectrogram=example.spectrogram.numpy(),
|
||||||
|
detection_heatmap=example.detection_heatmap.numpy(),
|
||||||
|
class_heatmap=example.class_heatmap.numpy(),
|
||||||
|
size_heatmap=example.size_heatmap.numpy(),
|
||||||
|
clip_annotation=clip_annotation,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
|
||||||
|
item = np.load(path, mmap_mode="r+")
|
||||||
|
return PreprocessedExample(
|
||||||
|
audio=torch.tensor(item["audio"]),
|
||||||
|
spectrogram=torch.tensor(item["spectrogram"]),
|
||||||
|
size_heatmap=torch.tensor(item["size_heatmap"]),
|
||||||
|
detection_heatmap=torch.tensor(item["detection_heatmap"]),
|
||||||
|
class_heatmap=torch.tensor(item["class_heatmap"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def list_preprocessed_files(
|
||||||
|
directory: data.PathLike, extension: str = ".npz"
|
||||||
|
) -> List[Path]:
|
||||||
|
return list(Path(directory).glob(f"*{extension}"))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
|
||||||
|
"""Generate a default output filename based on the annotation UUID."""
|
||||||
|
return f"{clip_annotation.uuid}"
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_annotations(
|
||||||
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
|
output_dir: data.PathLike,
|
||||||
|
preprocessor: PreprocessorProtocol,
|
||||||
|
audio_loader: AudioLoader,
|
||||||
|
labeller: ClipLabeller,
|
||||||
|
filename_fn: FilenameFn = _get_filename,
|
||||||
|
max_workers: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Preprocess a sequence of ClipAnnotations and save results to disk."""
|
||||||
|
output_dir = Path(output_dir)
|
||||||
|
|
||||||
|
if not output_dir.is_dir():
|
||||||
|
logger.info(
|
||||||
|
"Creating output directory: {output_dir}", output_dir=output_dir
|
||||||
|
)
|
||||||
|
output_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Starting preprocessing of {num_annotations} annotations with {max_workers} workers.",
|
||||||
|
num_annotations=len(clip_annotations),
|
||||||
|
max_workers=max_workers or "all available",
|
||||||
|
)
|
||||||
|
|
||||||
|
if max_workers is None:
|
||||||
|
max_workers = os.cpu_count() or 0
|
||||||
|
|
||||||
|
dataset = PreprocessingDataset(
|
||||||
|
clips=list(clip_annotations),
|
||||||
|
audio_loader=audio_loader,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
labeller=labeller,
|
||||||
|
output_dir=Path(output_dir),
|
||||||
|
filename_fn=filename_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
loader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=max_workers,
|
||||||
|
prefetch_factor=16,
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in tqdm(loader, total=len(dataset)):
|
||||||
|
pass
|
||||||
@ -14,9 +14,9 @@ from batdetect2.evaluate.metrics import (
|
|||||||
ClassificationMeanAveragePrecision,
|
ClassificationMeanAveragePrecision,
|
||||||
DetectionAveragePrecision,
|
DetectionAveragePrecision,
|
||||||
)
|
)
|
||||||
|
from batdetect2.models import Model, build_model
|
||||||
from batdetect2.plotting.clips import AudioLoader, build_audio_loader
|
from batdetect2.plotting.clips import AudioLoader, build_audio_loader
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import build_targets
|
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
RandomAudioSource,
|
RandomAudioSource,
|
||||||
build_augmentations,
|
build_augmentations,
|
||||||
@ -28,6 +28,7 @@ from batdetect2.train.dataset import TrainingDataset, ValidationDataset
|
|||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
from batdetect2.train.logging import build_logger
|
from batdetect2.train.logging import build_logger
|
||||||
|
from batdetect2.train.losses import build_loss
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
PreprocessorProtocol,
|
PreprocessorProtocol,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
@ -53,21 +54,19 @@ def train(
|
|||||||
model_path: Optional[data.PathLike] = None,
|
model_path: Optional[data.PathLike] = None,
|
||||||
train_workers: Optional[int] = None,
|
train_workers: Optional[int] = None,
|
||||||
val_workers: Optional[int] = None,
|
val_workers: Optional[int] = None,
|
||||||
checkpoint_dir: Optional[data.PathLike] = None,
|
|
||||||
log_dir: Optional[data.PathLike] = None,
|
|
||||||
):
|
):
|
||||||
config = config or FullTrainingConfig()
|
config = config or FullTrainingConfig()
|
||||||
|
|
||||||
targets = build_targets(config.targets)
|
model = build_model(config=config)
|
||||||
|
|
||||||
preprocessor = build_preprocessor(config.preprocess)
|
trainer = build_trainer(config, targets=model.targets)
|
||||||
|
|
||||||
audio_loader = build_audio_loader(config=config.preprocess.audio)
|
audio_loader = build_audio_loader(config=config.preprocess.audio)
|
||||||
|
|
||||||
labeller = build_clip_labeler(
|
labeller = build_clip_labeler(
|
||||||
targets,
|
model.targets,
|
||||||
min_freq=preprocessor.min_freq,
|
min_freq=model.preprocessor.min_freq,
|
||||||
max_freq=preprocessor.max_freq,
|
max_freq=model.preprocessor.max_freq,
|
||||||
config=config.train.labels,
|
config=config.train.labels,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -75,7 +74,7 @@ def train(
|
|||||||
train_annotations,
|
train_annotations,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
labeller=labeller,
|
labeller=labeller,
|
||||||
preprocessor=preprocessor,
|
preprocessor=build_preprocessor(config.preprocess),
|
||||||
config=config.train,
|
config=config.train,
|
||||||
num_workers=train_workers,
|
num_workers=train_workers,
|
||||||
)
|
)
|
||||||
@ -85,7 +84,7 @@ def train(
|
|||||||
val_annotations,
|
val_annotations,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
labeller=labeller,
|
labeller=labeller,
|
||||||
preprocessor=preprocessor,
|
preprocessor=build_preprocessor(config.preprocess),
|
||||||
config=config.train,
|
config=config.train,
|
||||||
num_workers=val_workers,
|
num_workers=val_workers,
|
||||||
)
|
)
|
||||||
@ -98,17 +97,11 @@ def train(
|
|||||||
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
||||||
else:
|
else:
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
|
model,
|
||||||
config,
|
config,
|
||||||
t_max=config.train.t_max * len(train_dataloader),
|
batches_per_epoch=len(train_dataloader),
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = build_trainer(
|
|
||||||
config,
|
|
||||||
targets=targets,
|
|
||||||
checkpoint_dir=checkpoint_dir,
|
|
||||||
log_dir=log_dir,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Starting main training loop...")
|
logger.info("Starting main training loop...")
|
||||||
trainer.fit(
|
trainer.fit(
|
||||||
module,
|
module,
|
||||||
@ -119,14 +112,16 @@ def train(
|
|||||||
|
|
||||||
|
|
||||||
def build_training_module(
|
def build_training_module(
|
||||||
config: Optional[FullTrainingConfig] = None,
|
model: Model,
|
||||||
t_max: int = 200,
|
config: FullTrainingConfig,
|
||||||
|
batches_per_epoch: int,
|
||||||
) -> TrainingModule:
|
) -> TrainingModule:
|
||||||
config = config or FullTrainingConfig()
|
loss = build_loss(config=config.train.loss)
|
||||||
return TrainingModule(
|
return TrainingModule(
|
||||||
config=config,
|
model=model,
|
||||||
|
loss=loss,
|
||||||
learning_rate=config.train.learning_rate,
|
learning_rate=config.train.learning_rate,
|
||||||
t_max=t_max,
|
t_max=config.train.t_max * batches_per_epoch,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -134,14 +129,10 @@ def build_trainer_callbacks(
|
|||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
config: EvaluationConfig,
|
config: EvaluationConfig,
|
||||||
checkpoint_dir: Optional[data.PathLike] = None,
|
|
||||||
) -> List[Callback]:
|
) -> List[Callback]:
|
||||||
if checkpoint_dir is None:
|
|
||||||
checkpoint_dir = "outputs/checkpoints"
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ModelCheckpoint(
|
ModelCheckpoint(
|
||||||
dirpath=str(checkpoint_dir),
|
dirpath="outputs/checkpoints",
|
||||||
save_top_k=1,
|
save_top_k=1,
|
||||||
monitor="total_loss/val",
|
monitor="total_loss/val",
|
||||||
),
|
),
|
||||||
@ -162,22 +153,15 @@ def build_trainer_callbacks(
|
|||||||
def build_trainer(
|
def build_trainer(
|
||||||
conf: FullTrainingConfig,
|
conf: FullTrainingConfig,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
checkpoint_dir: Optional[data.PathLike] = None,
|
|
||||||
log_dir: Optional[data.PathLike] = None,
|
|
||||||
) -> Trainer:
|
) -> Trainer:
|
||||||
trainer_conf = conf.train.trainer
|
trainer_conf = conf.train.trainer
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Building trainer with config: \n{config}",
|
"Building trainer with config: \n{config}",
|
||||||
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
||||||
)
|
)
|
||||||
train_logger = build_logger(conf.train.logger, log_dir=log_dir)
|
train_logger = build_logger(conf.train.logger)
|
||||||
|
|
||||||
train_logger.log_hyperparams(
|
train_logger.log_hyperparams(conf.model_dump(mode="json"))
|
||||||
conf.model_dump(
|
|
||||||
mode="json",
|
|
||||||
exclude_none=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return Trainer(
|
return Trainer(
|
||||||
**trainer_conf.model_dump(exclude_none=True),
|
**trainer_conf.model_dump(exclude_none=True),
|
||||||
@ -186,7 +170,6 @@ def build_trainer(
|
|||||||
targets,
|
targets,
|
||||||
config=conf.evaluation,
|
config=conf.evaluation,
|
||||||
preprocessor=build_preprocessor(conf.preprocess),
|
preprocessor=build_preprocessor(conf.preprocess),
|
||||||
checkpoint_dir=checkpoint_dir,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -67,7 +67,8 @@ class TargetProtocol(Protocol):
|
|||||||
This protocol outlines the standard attributes and methods for an object
|
This protocol outlines the standard attributes and methods for an object
|
||||||
that encapsulates the complete, configured process for handling sound event
|
that encapsulates the complete, configured process for handling sound event
|
||||||
annotations (both tags and geometry). It defines how to:
|
annotations (both tags and geometry). It defines how to:
|
||||||
- Select relevant annotations.
|
- Filter relevant annotations.
|
||||||
|
- Transform annotation tags.
|
||||||
- Encode an annotation into a specific target class name.
|
- Encode an annotation into a specific target class name.
|
||||||
- Decode a class name back into representative tags.
|
- Decode a class name back into representative tags.
|
||||||
- Extract a target reference position from an annotation's geometry (ROI).
|
- Extract a target reference position from an annotation's geometry (ROI).
|
||||||
@ -120,6 +121,26 @@ class TargetProtocol(Protocol):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def transform(
|
||||||
|
self,
|
||||||
|
sound_event: data.SoundEventAnnotation,
|
||||||
|
) -> data.SoundEventAnnotation:
|
||||||
|
"""Apply tag transformations to an annotation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event : data.SoundEventAnnotation
|
||||||
|
The annotation whose tags should be transformed.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
data.SoundEventAnnotation
|
||||||
|
A new annotation object with the transformed tags. Implementations
|
||||||
|
should return the original annotation object if no transformations
|
||||||
|
were configured.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
def encode_class(
|
def encode_class(
|
||||||
self,
|
self,
|
||||||
sound_event: data.SoundEventAnnotation,
|
sound_event: data.SoundEventAnnotation,
|
||||||
@ -227,70 +248,3 @@ class TargetProtocol(Protocol):
|
|||||||
if reconstruction fails based on the configured position type.
|
if reconstruction fails based on the configured position type.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class ROITargetMapper(Protocol):
|
|
||||||
"""Protocol defining the interface for ROI-to-target mapping.
|
|
||||||
|
|
||||||
Specifies the `encode` and `decode` methods required for converting a
|
|
||||||
`soundevent.data.SoundEvent` into a target representation (a reference
|
|
||||||
position and a size vector) and for recovering an approximate ROI from that
|
|
||||||
representation.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
dimension_names : List[str]
|
|
||||||
A list containing the names of the dimensions in the `Size` array
|
|
||||||
returned by `encode` and expected by `decode`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
dimension_names: List[str]
|
|
||||||
|
|
||||||
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
|
|
||||||
"""Encode a SoundEvent's geometry into a position and size.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event : data.SoundEvent
|
|
||||||
The input sound event, which must have a geometry attribute.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Tuple[Position, Size]
|
|
||||||
A tuple containing:
|
|
||||||
- The reference position as (time, frequency) coordinates.
|
|
||||||
- A NumPy array with the calculated size dimensions.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If the sound event does not have a geometry.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def decode(self, position: Position, size: Size) -> data.Geometry:
|
|
||||||
"""Decode a position and size back into a geometric ROI.
|
|
||||||
|
|
||||||
Performs the inverse mapping: takes a reference position and size
|
|
||||||
dimensions and reconstructs a geometric representation.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
position : Position
|
|
||||||
The reference position (time, frequency).
|
|
||||||
size : Size
|
|
||||||
NumPy array containing the size dimensions, matching the order
|
|
||||||
and meaning specified by `dimension_names`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
soundevent.data.Geometry
|
|
||||||
The reconstructed geometry, typically a `BoundingBox`.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If the `size` array has an unexpected shape or if reconstruction
|
|
||||||
fails.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|||||||
@ -15,10 +15,13 @@ from batdetect2.preprocess import build_preprocessor
|
|||||||
from batdetect2.preprocess.audio import build_audio_loader
|
from batdetect2.preprocess.audio import build_audio_loader
|
||||||
from batdetect2.targets import (
|
from batdetect2.targets import (
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
|
TermRegistry,
|
||||||
build_targets,
|
build_targets,
|
||||||
call_type,
|
call_type,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.classes import TargetClassConfig
|
from batdetect2.targets.classes import ClassesConfig, TargetClass
|
||||||
|
from batdetect2.targets.filtering import FilterConfig, FilterRule
|
||||||
|
from batdetect2.targets.terms import TagInfo
|
||||||
from batdetect2.train.clips import build_clipper
|
from batdetect2.train.clips import build_clipper
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
@ -352,6 +355,18 @@ def create_annotation_project():
|
|||||||
return factory
|
return factory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_term_registry() -> TermRegistry:
|
||||||
|
"""Fixture for a sample TermRegistry."""
|
||||||
|
registry = TermRegistry()
|
||||||
|
registry.add_custom_term("class")
|
||||||
|
registry.add_custom_term("order")
|
||||||
|
registry.add_custom_term("species")
|
||||||
|
registry.add_custom_term("call_type")
|
||||||
|
registry.add_custom_term("quality")
|
||||||
|
return registry
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_preprocessor() -> PreprocessorProtocol:
|
def sample_preprocessor() -> PreprocessorProtocol:
|
||||||
return build_preprocessor()
|
return build_preprocessor()
|
||||||
@ -363,45 +378,56 @@ def sample_audio_loader() -> AudioLoader:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def bat_tag() -> data.Tag:
|
def bat_tag() -> TagInfo:
|
||||||
return data.Tag(key="class", value="bat")
|
return TagInfo(key="class", value="bat")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def noise_tag() -> data.Tag:
|
def noise_tag() -> TagInfo:
|
||||||
return data.Tag(key="class", value="noise")
|
return TagInfo(key="class", value="noise")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def myomyo_tag() -> data.Tag:
|
def myomyo_tag() -> TagInfo:
|
||||||
return data.Tag(key="species", value="Myotis myotis")
|
return TagInfo(key="species", value="Myotis myotis")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def pippip_tag() -> data.Tag:
|
def pippip_tag() -> TagInfo:
|
||||||
return data.Tag(key="species", value="Pipistrellus pipistrellus")
|
return TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_target_config(
|
def sample_target_config(
|
||||||
bat_tag: data.Tag,
|
sample_term_registry: TermRegistry,
|
||||||
myomyo_tag: data.Tag,
|
bat_tag: TagInfo,
|
||||||
pippip_tag: data.Tag,
|
noise_tag: TagInfo,
|
||||||
|
myomyo_tag: TagInfo,
|
||||||
|
pippip_tag: TagInfo,
|
||||||
) -> TargetConfig:
|
) -> TargetConfig:
|
||||||
return TargetConfig(
|
return TargetConfig(
|
||||||
detection_target=TargetClassConfig(name="bat", tags=[bat_tag]),
|
filtering=FilterConfig(
|
||||||
classification_targets=[
|
rules=[FilterRule(match_type="exclude", tags=[noise_tag])]
|
||||||
TargetClassConfig(name="pippip", tags=[pippip_tag]),
|
),
|
||||||
TargetClassConfig(name="myomyo", tags=[myomyo_tag]),
|
classes=ClassesConfig(
|
||||||
],
|
classes=[
|
||||||
|
TargetClass(name="pippip", tags=[pippip_tag]),
|
||||||
|
TargetClass(name="myomyo", tags=[myomyo_tag]),
|
||||||
|
],
|
||||||
|
generic_class=[bat_tag],
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_targets(
|
def sample_targets(
|
||||||
sample_target_config: TargetConfig,
|
sample_target_config: TargetConfig,
|
||||||
|
sample_term_registry: TermRegistry,
|
||||||
) -> TargetProtocol:
|
) -> TargetProtocol:
|
||||||
return build_targets(sample_target_config)
|
return build_targets(
|
||||||
|
sample_target_config,
|
||||||
|
term_registry=sample_term_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -417,8 +443,10 @@ def sample_labeller(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_clipper() -> ClipperProtocol:
|
def sample_clipper(
|
||||||
return build_clipper()
|
sample_preprocessor: PreprocessorProtocol,
|
||||||
|
) -> ClipperProtocol:
|
||||||
|
return build_clipper(preprocessor=sample_preprocessor)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@ -1,516 +0,0 @@
|
|||||||
import textwrap
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import yaml
|
|
||||||
from pydantic import TypeAdapter
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.data.conditions import (
|
|
||||||
SoundEventConditionConfig,
|
|
||||||
build_sound_event_condition,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_condition_from_str(content):
|
|
||||||
content = textwrap.dedent(content)
|
|
||||||
content = yaml.safe_load(content)
|
|
||||||
config = TypeAdapter(SoundEventConditionConfig).validate_python(content)
|
|
||||||
return build_sound_event_condition(config)
|
|
||||||
|
|
||||||
|
|
||||||
def test_has_tag(sound_event: data.SoundEvent):
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: has_tag
|
|
||||||
tag:
|
|
||||||
key: species
|
|
||||||
value: Myotis myotis
|
|
||||||
""")
|
|
||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
|
||||||
)
|
|
||||||
assert condition(sound_event_annotation)
|
|
||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
|
||||||
)
|
|
||||||
assert not condition(sound_event_annotation)
|
|
||||||
|
|
||||||
|
|
||||||
def test_has_all_tags(sound_event: data.SoundEvent):
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: has_all_tags
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Myotis myotis
|
|
||||||
- key: event
|
|
||||||
value: Echolocation
|
|
||||||
""")
|
|
||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
|
||||||
)
|
|
||||||
assert not condition(sound_event_annotation)
|
|
||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[
|
|
||||||
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
|
|
||||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
|
||||||
],
|
|
||||||
)
|
|
||||||
assert not condition(sound_event_annotation)
|
|
||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[
|
|
||||||
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
|
||||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
|
||||||
],
|
|
||||||
)
|
|
||||||
assert condition(sound_event_annotation)
|
|
||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[
|
|
||||||
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
|
||||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
|
||||||
data.Tag(key="sex", value="Female"), # type: ignore
|
|
||||||
],
|
|
||||||
)
|
|
||||||
assert condition(sound_event_annotation)
|
|
||||||
|
|
||||||
|
|
||||||
def test_has_any_tags(sound_event: data.SoundEvent):
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: has_any_tag
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Myotis myotis
|
|
||||||
- key: event
|
|
||||||
value: Echolocation
|
|
||||||
""")
|
|
||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
|
||||||
)
|
|
||||||
assert condition(sound_event_annotation)
|
|
||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[
|
|
||||||
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
|
|
||||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
|
||||||
],
|
|
||||||
)
|
|
||||||
assert condition(sound_event_annotation)
|
|
||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[
|
|
||||||
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
|
||||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
|
||||||
],
|
|
||||||
)
|
|
||||||
assert condition(sound_event_annotation)
|
|
||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[
|
|
||||||
data.Tag(key="species", value="Eptesicus fuscus"), # type: ignore
|
|
||||||
data.Tag(key="event", value="Social"), # type: ignore
|
|
||||||
],
|
|
||||||
)
|
|
||||||
assert not condition(sound_event_annotation)
|
|
||||||
|
|
||||||
|
|
||||||
def test_not(sound_event: data.SoundEvent):
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: not
|
|
||||||
condition:
|
|
||||||
name: has_tag
|
|
||||||
tag:
|
|
||||||
key: species
|
|
||||||
value: Myotis myotis
|
|
||||||
""")
|
|
||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
|
||||||
)
|
|
||||||
assert not condition(sound_event_annotation)
|
|
||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
|
||||||
)
|
|
||||||
assert condition(sound_event_annotation)
|
|
||||||
|
|
||||||
sound_event_annotation = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[
|
|
||||||
data.Tag(key="species", value="Myotis myotis"), # type: ignore
|
|
||||||
data.Tag(key="event", value="Echolocation"), # type: ignore
|
|
||||||
],
|
|
||||||
)
|
|
||||||
assert not condition(sound_event_annotation)
|
|
||||||
|
|
||||||
|
|
||||||
def test_duration(recording: data.Recording):
|
|
||||||
se1 = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(
|
|
||||||
recording=recording, geometry=data.TimeInterval(coordinates=[0, 1])
|
|
||||||
),
|
|
||||||
)
|
|
||||||
se2 = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(
|
|
||||||
recording=recording, geometry=data.TimeInterval(coordinates=[0, 2])
|
|
||||||
),
|
|
||||||
)
|
|
||||||
se3 = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(
|
|
||||||
recording=recording, geometry=data.TimeInterval(coordinates=[0, 3])
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: duration
|
|
||||||
operator: lt
|
|
||||||
seconds: 2
|
|
||||||
""")
|
|
||||||
assert condition(se1)
|
|
||||||
assert not condition(se2)
|
|
||||||
assert not condition(se3)
|
|
||||||
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: duration
|
|
||||||
operator: lte
|
|
||||||
seconds: 2
|
|
||||||
""")
|
|
||||||
|
|
||||||
assert condition(se1)
|
|
||||||
assert condition(se2)
|
|
||||||
assert not condition(se3)
|
|
||||||
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: duration
|
|
||||||
operator: gt
|
|
||||||
seconds: 2
|
|
||||||
""")
|
|
||||||
|
|
||||||
assert not condition(se1)
|
|
||||||
assert not condition(se2)
|
|
||||||
assert condition(se3)
|
|
||||||
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: duration
|
|
||||||
operator: gte
|
|
||||||
seconds: 2
|
|
||||||
""")
|
|
||||||
|
|
||||||
assert not condition(se1)
|
|
||||||
assert condition(se2)
|
|
||||||
assert condition(se3)
|
|
||||||
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: duration
|
|
||||||
operator: eq
|
|
||||||
seconds: 2
|
|
||||||
""")
|
|
||||||
|
|
||||||
assert not condition(se1)
|
|
||||||
assert condition(se2)
|
|
||||||
assert not condition(se3)
|
|
||||||
|
|
||||||
|
|
||||||
def test_frequency(recording: data.Recording):
|
|
||||||
se12 = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(
|
|
||||||
recording=recording,
|
|
||||||
geometry=data.BoundingBox(coordinates=[0, 100, 1, 200]),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
se13 = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(
|
|
||||||
recording=recording,
|
|
||||||
geometry=data.BoundingBox(coordinates=[0, 100, 2, 300]),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
se14 = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(
|
|
||||||
recording=recording,
|
|
||||||
geometry=data.BoundingBox(coordinates=[0, 100, 3, 400]),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
se24 = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(
|
|
||||||
recording=recording,
|
|
||||||
geometry=data.BoundingBox(coordinates=[0, 200, 3, 400]),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
se34 = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(
|
|
||||||
recording=recording,
|
|
||||||
geometry=data.BoundingBox(coordinates=[0, 300, 3, 400]),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: frequency
|
|
||||||
boundary: high
|
|
||||||
operator: lt
|
|
||||||
hertz: 300
|
|
||||||
""")
|
|
||||||
assert condition(se12)
|
|
||||||
assert not condition(se13)
|
|
||||||
assert not condition(se14)
|
|
||||||
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: frequency
|
|
||||||
boundary: high
|
|
||||||
operator: lte
|
|
||||||
hertz: 300
|
|
||||||
""")
|
|
||||||
|
|
||||||
assert condition(se12)
|
|
||||||
assert condition(se13)
|
|
||||||
assert not condition(se14)
|
|
||||||
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: frequency
|
|
||||||
boundary: high
|
|
||||||
operator: gt
|
|
||||||
hertz: 300
|
|
||||||
""")
|
|
||||||
|
|
||||||
assert not condition(se12)
|
|
||||||
assert not condition(se13)
|
|
||||||
assert condition(se14)
|
|
||||||
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: frequency
|
|
||||||
boundary: high
|
|
||||||
operator: gte
|
|
||||||
hertz: 300
|
|
||||||
""")
|
|
||||||
|
|
||||||
assert not condition(se12)
|
|
||||||
assert condition(se13)
|
|
||||||
assert condition(se14)
|
|
||||||
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: frequency
|
|
||||||
boundary: high
|
|
||||||
operator: eq
|
|
||||||
hertz: 300
|
|
||||||
""")
|
|
||||||
|
|
||||||
assert not condition(se12)
|
|
||||||
assert condition(se13)
|
|
||||||
assert not condition(se14)
|
|
||||||
|
|
||||||
# LOW
|
|
||||||
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: frequency
|
|
||||||
boundary: low
|
|
||||||
operator: lt
|
|
||||||
hertz: 200
|
|
||||||
""")
|
|
||||||
assert condition(se14)
|
|
||||||
assert not condition(se24)
|
|
||||||
assert not condition(se34)
|
|
||||||
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: frequency
|
|
||||||
boundary: low
|
|
||||||
operator: lte
|
|
||||||
hertz: 200
|
|
||||||
""")
|
|
||||||
|
|
||||||
assert condition(se14)
|
|
||||||
assert condition(se24)
|
|
||||||
assert not condition(se34)
|
|
||||||
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: frequency
|
|
||||||
boundary: low
|
|
||||||
operator: gt
|
|
||||||
hertz: 200
|
|
||||||
""")
|
|
||||||
|
|
||||||
assert not condition(se14)
|
|
||||||
assert not condition(se24)
|
|
||||||
assert condition(se34)
|
|
||||||
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: frequency
|
|
||||||
boundary: low
|
|
||||||
operator: gte
|
|
||||||
hertz: 200
|
|
||||||
""")
|
|
||||||
|
|
||||||
assert not condition(se14)
|
|
||||||
assert condition(se24)
|
|
||||||
assert condition(se34)
|
|
||||||
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: frequency
|
|
||||||
boundary: low
|
|
||||||
operator: eq
|
|
||||||
hertz: 200
|
|
||||||
""")
|
|
||||||
|
|
||||||
assert not condition(se14)
|
|
||||||
assert condition(se24)
|
|
||||||
assert not condition(se34)
|
|
||||||
|
|
||||||
|
|
||||||
def test_frequency_is_false_for_temporal_geometries(recording: data.Recording):
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: frequency
|
|
||||||
boundary: low
|
|
||||||
operator: eq
|
|
||||||
hertz: 200
|
|
||||||
""")
|
|
||||||
se = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(
|
|
||||||
geometry=data.TimeInterval(coordinates=[0, 3]),
|
|
||||||
recording=recording,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
assert not condition(se)
|
|
||||||
|
|
||||||
se = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(
|
|
||||||
geometry=data.TimeStamp(coordinates=3),
|
|
||||||
recording=recording,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
assert not condition(se)
|
|
||||||
|
|
||||||
|
|
||||||
def test_has_tags_fails_if_empty():
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
build_condition_from_str("""
|
|
||||||
name: has_tags
|
|
||||||
tags: []
|
|
||||||
""")
|
|
||||||
|
|
||||||
|
|
||||||
def test_frequency_is_false_if_no_geometry(recording: data.Recording):
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: frequency
|
|
||||||
boundary: low
|
|
||||||
operator: eq
|
|
||||||
hertz: 200
|
|
||||||
""")
|
|
||||||
se = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(geometry=None, recording=recording)
|
|
||||||
)
|
|
||||||
assert not condition(se)
|
|
||||||
|
|
||||||
|
|
||||||
def test_duration_is_false_if_no_geometry(recording: data.Recording):
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: duration
|
|
||||||
operator: eq
|
|
||||||
seconds: 1
|
|
||||||
""")
|
|
||||||
se = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(geometry=None, recording=recording)
|
|
||||||
)
|
|
||||||
assert not condition(se)
|
|
||||||
|
|
||||||
|
|
||||||
def test_all_of(recording: data.Recording):
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: all_of
|
|
||||||
conditions:
|
|
||||||
- name: has_tag
|
|
||||||
tag:
|
|
||||||
key: species
|
|
||||||
value: Myotis myotis
|
|
||||||
- name: duration
|
|
||||||
operator: lt
|
|
||||||
seconds: 1
|
|
||||||
""")
|
|
||||||
se = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(
|
|
||||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
|
||||||
recording=recording,
|
|
||||||
),
|
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
|
||||||
)
|
|
||||||
assert condition(se)
|
|
||||||
|
|
||||||
se = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(
|
|
||||||
geometry=data.TimeInterval(coordinates=[0, 2]),
|
|
||||||
recording=recording,
|
|
||||||
),
|
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
|
||||||
)
|
|
||||||
assert not condition(se)
|
|
||||||
|
|
||||||
se = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(
|
|
||||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
|
||||||
recording=recording,
|
|
||||||
),
|
|
||||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
|
||||||
)
|
|
||||||
assert not condition(se)
|
|
||||||
|
|
||||||
|
|
||||||
def test_any_of(recording: data.Recording):
|
|
||||||
condition = build_condition_from_str("""
|
|
||||||
name: any_of
|
|
||||||
conditions:
|
|
||||||
- name: has_tag
|
|
||||||
tag:
|
|
||||||
key: species
|
|
||||||
value: Myotis myotis
|
|
||||||
- name: duration
|
|
||||||
operator: lt
|
|
||||||
seconds: 1
|
|
||||||
""")
|
|
||||||
se = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(
|
|
||||||
geometry=data.TimeInterval(coordinates=[0, 2]),
|
|
||||||
recording=recording,
|
|
||||||
),
|
|
||||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
|
||||||
)
|
|
||||||
assert not condition(se)
|
|
||||||
|
|
||||||
se = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(
|
|
||||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
|
||||||
recording=recording,
|
|
||||||
),
|
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
|
||||||
)
|
|
||||||
assert condition(se)
|
|
||||||
|
|
||||||
se = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(
|
|
||||||
geometry=data.TimeInterval(coordinates=[0, 2]),
|
|
||||||
recording=recording,
|
|
||||||
),
|
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
|
||||||
)
|
|
||||||
assert condition(se)
|
|
||||||
|
|
||||||
se = data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(
|
|
||||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
|
||||||
recording=recording,
|
|
||||||
),
|
|
||||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")], # type: ignore
|
|
||||||
)
|
|
||||||
assert condition(se)
|
|
||||||
@ -3,15 +3,26 @@ from typing import Callable
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.terms import get_term
|
from soundevent.terms import get_term
|
||||||
|
|
||||||
from batdetect2.targets.classes import (
|
from batdetect2.targets.classes import (
|
||||||
TargetClassConfig,
|
DEFAULT_SPECIES_LIST,
|
||||||
|
ClassesConfig,
|
||||||
|
TargetClass,
|
||||||
|
_get_default_class_name,
|
||||||
|
_get_default_classes,
|
||||||
|
build_generic_class_tags,
|
||||||
build_sound_event_decoder,
|
build_sound_event_decoder,
|
||||||
build_sound_event_encoder,
|
build_sound_event_encoder,
|
||||||
get_class_names_from_config,
|
get_class_names_from_config,
|
||||||
|
is_target_class,
|
||||||
|
load_classes_config,
|
||||||
|
load_decoder_from_config,
|
||||||
|
load_encoder_from_config,
|
||||||
)
|
)
|
||||||
|
from batdetect2.targets.terms import TagInfo
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -22,8 +33,8 @@ def sample_annotation(
|
|||||||
return data.SoundEventAnnotation(
|
return data.SoundEventAnnotation(
|
||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(key="species", value="Pipistrellus pipistrellus"),
|
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||||
data.Tag(key="quality", value="Good"),
|
data.Tag(key="quality", value="Good"), # type: ignore
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -40,71 +51,291 @@ def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
|
|||||||
return factory
|
return factory
|
||||||
|
|
||||||
|
|
||||||
def test_get_class_names_from_config():
|
def test_target_class_creation():
|
||||||
target_class1 = TargetClassConfig(
|
target_class = TargetClass(
|
||||||
name="pippip",
|
name="pippip",
|
||||||
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
|
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
||||||
)
|
)
|
||||||
target_class2 = TargetClassConfig(
|
assert target_class.name == "pippip"
|
||||||
|
assert target_class.tags[0].key == "species"
|
||||||
|
assert target_class.tags[0].value == "Pipistrellus pipistrellus"
|
||||||
|
assert target_class.match_type == "all"
|
||||||
|
|
||||||
|
|
||||||
|
def test_classes_config_creation():
|
||||||
|
target_class = TargetClass(
|
||||||
|
name="pippip",
|
||||||
|
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
||||||
|
)
|
||||||
|
config = ClassesConfig(classes=[target_class])
|
||||||
|
assert len(config.classes) == 1
|
||||||
|
assert config.classes[0].name == "pippip"
|
||||||
|
|
||||||
|
|
||||||
|
def test_classes_config_unique_names():
|
||||||
|
target_class1 = TargetClass(
|
||||||
|
name="pippip",
|
||||||
|
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
||||||
|
)
|
||||||
|
target_class2 = TargetClass(
|
||||||
name="myodau",
|
name="myodau",
|
||||||
tags=[data.Tag(key="species", value="Myotis daubentonii")],
|
tags=[TagInfo(key="species", value="Myotis daubentonii")],
|
||||||
)
|
)
|
||||||
names = get_class_names_from_config([target_class1, target_class2])
|
ClassesConfig(classes=[target_class1, target_class2]) # No error
|
||||||
|
|
||||||
|
|
||||||
|
def test_classes_config_non_unique_names():
|
||||||
|
target_class1 = TargetClass(
|
||||||
|
name="pippip",
|
||||||
|
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
||||||
|
)
|
||||||
|
target_class2 = TargetClass(
|
||||||
|
name="pippip",
|
||||||
|
tags=[TagInfo(key="species", value="Myotis daubentonii")],
|
||||||
|
)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
ClassesConfig(classes=[target_class1, target_class2])
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_classes_config_valid(create_temp_yaml: Callable[[str], Path]):
|
||||||
|
yaml_content = """
|
||||||
|
classes:
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
"""
|
||||||
|
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||||
|
config = load_classes_config(temp_yaml_path)
|
||||||
|
assert len(config.classes) == 1
|
||||||
|
assert config.classes[0].name == "pippip"
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_classes_config_invalid(create_temp_yaml: Callable[[str], Path]):
|
||||||
|
yaml_content = """
|
||||||
|
classes:
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Myotis daubentonii
|
||||||
|
"""
|
||||||
|
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
load_classes_config(temp_yaml_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_target_class_match_all(
|
||||||
|
sample_annotation: data.SoundEventAnnotation,
|
||||||
|
):
|
||||||
|
tags = {
|
||||||
|
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||||
|
data.Tag(key="quality", value="Good"), # type: ignore
|
||||||
|
}
|
||||||
|
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
||||||
|
|
||||||
|
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
|
||||||
|
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
||||||
|
|
||||||
|
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
|
||||||
|
assert is_target_class(sample_annotation, tags, match_all=True) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_target_class_match_any(
|
||||||
|
sample_annotation: data.SoundEventAnnotation,
|
||||||
|
):
|
||||||
|
tags = {
|
||||||
|
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||||
|
data.Tag(key="quality", value="Good"), # type: ignore
|
||||||
|
}
|
||||||
|
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
||||||
|
|
||||||
|
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
|
||||||
|
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
||||||
|
|
||||||
|
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
|
||||||
|
assert is_target_class(sample_annotation, tags, match_all=False) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_class_names_from_config():
|
||||||
|
target_class1 = TargetClass(
|
||||||
|
name="pippip",
|
||||||
|
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
||||||
|
)
|
||||||
|
target_class2 = TargetClass(
|
||||||
|
name="myodau",
|
||||||
|
tags=[TagInfo(key="species", value="Myotis daubentonii")],
|
||||||
|
)
|
||||||
|
config = ClassesConfig(classes=[target_class1, target_class2])
|
||||||
|
names = get_class_names_from_config(config)
|
||||||
assert names == ["pippip", "myodau"]
|
assert names == ["pippip", "myodau"]
|
||||||
|
|
||||||
|
|
||||||
def test_build_encoder_from_config(
|
def test_build_encoder_from_config(
|
||||||
sample_annotation: data.SoundEventAnnotation,
|
sample_annotation: data.SoundEventAnnotation,
|
||||||
):
|
):
|
||||||
classes = [
|
config = ClassesConfig(
|
||||||
TargetClassConfig(
|
classes=[
|
||||||
name="pippip",
|
TargetClass(
|
||||||
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
|
name="pippip",
|
||||||
)
|
tags=[
|
||||||
]
|
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||||
encoder = build_sound_event_encoder(classes)
|
],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
encoder = build_sound_event_encoder(config)
|
||||||
result = encoder(sample_annotation)
|
result = encoder(sample_annotation)
|
||||||
assert result == "pippip"
|
assert result == "pippip"
|
||||||
|
|
||||||
classes = []
|
config = ClassesConfig(classes=[])
|
||||||
encoder = build_sound_event_encoder(classes)
|
encoder = build_sound_event_encoder(config)
|
||||||
result = encoder(sample_annotation)
|
result = encoder(sample_annotation)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_encoder_from_config_valid(
|
||||||
|
sample_annotation: data.SoundEventAnnotation,
|
||||||
|
create_temp_yaml: Callable[[str], Path],
|
||||||
|
):
|
||||||
|
yaml_content = """
|
||||||
|
classes:
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
"""
|
||||||
|
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||||
|
encoder = load_encoder_from_config(temp_yaml_path)
|
||||||
|
# We cannot directly compare the function, so we test it.
|
||||||
|
result = encoder(sample_annotation) # type: ignore
|
||||||
|
assert result == "pippip"
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_encoder_from_config_invalid(
|
||||||
|
create_temp_yaml: Callable[[str], Path],
|
||||||
|
):
|
||||||
|
yaml_content = """
|
||||||
|
classes:
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: invalid_key
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
"""
|
||||||
|
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
load_encoder_from_config(temp_yaml_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_default_class_name():
|
||||||
|
assert _get_default_class_name("Myotis daubentonii") == "myodau"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_default_classes():
|
||||||
|
default_classes = _get_default_classes()
|
||||||
|
assert len(default_classes) == len(DEFAULT_SPECIES_LIST)
|
||||||
|
first_class = default_classes[0]
|
||||||
|
assert isinstance(first_class, TargetClass)
|
||||||
|
assert first_class.name == _get_default_class_name(DEFAULT_SPECIES_LIST[0])
|
||||||
|
assert first_class.tags[0].key == "class"
|
||||||
|
assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0]
|
||||||
|
|
||||||
|
|
||||||
def test_build_decoder_from_config():
|
def test_build_decoder_from_config():
|
||||||
classes = [
|
config = ClassesConfig(
|
||||||
TargetClassConfig(
|
classes=[
|
||||||
name="pippip",
|
TargetClass(
|
||||||
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
|
name="pippip",
|
||||||
assign_tags=[data.Tag(key="call_type", value="Echolocation")],
|
tags=[
|
||||||
)
|
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||||
]
|
],
|
||||||
decoder = build_sound_event_decoder(classes)
|
output_tags=[TagInfo(key="call_type", value="Echolocation")],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
||||||
|
)
|
||||||
|
decoder = build_sound_event_decoder(config)
|
||||||
tags = decoder("pippip")
|
tags = decoder("pippip")
|
||||||
assert len(tags) == 1
|
assert len(tags) == 1
|
||||||
assert tags[0].term == get_term("event")
|
assert tags[0].term == get_term("event")
|
||||||
assert tags[0].value == "Echolocation"
|
assert tags[0].value == "Echolocation"
|
||||||
|
|
||||||
# Test when output_tags is None, should fall back to tags
|
# Test when output_tags is None, should fall back to tags
|
||||||
classes = [
|
config = ClassesConfig(
|
||||||
TargetClassConfig(
|
classes=[
|
||||||
name="pippip",
|
TargetClass(
|
||||||
tags=[data.Tag(key="species", value="Pipistrellus pipistrellus")],
|
name="pippip",
|
||||||
)
|
tags=[
|
||||||
]
|
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||||
decoder = build_sound_event_decoder(classes)
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
||||||
|
)
|
||||||
|
decoder = build_sound_event_decoder(config)
|
||||||
tags = decoder("pippip")
|
tags = decoder("pippip")
|
||||||
assert len(tags) == 1
|
assert len(tags) == 1
|
||||||
assert tags[0].term == get_term("species")
|
assert tags[0].term == get_term("species")
|
||||||
assert tags[0].value == "Pipistrellus pipistrellus"
|
assert tags[0].value == "Pipistrellus pipistrellus"
|
||||||
|
|
||||||
# Test raise_on_unmapped=True
|
# Test raise_on_unmapped=True
|
||||||
decoder = build_sound_event_decoder(classes, raise_on_unmapped=True)
|
decoder = build_sound_event_decoder(config, raise_on_unmapped=True)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
decoder("unknown_class")
|
decoder("unknown_class")
|
||||||
|
|
||||||
# Test raise_on_unmapped=False
|
# Test raise_on_unmapped=False
|
||||||
decoder = build_sound_event_decoder(classes, raise_on_unmapped=False)
|
decoder = build_sound_event_decoder(config, raise_on_unmapped=False)
|
||||||
tags = decoder("unknown_class")
|
tags = decoder("unknown_class")
|
||||||
assert len(tags) == 0
|
assert len(tags) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_decoder_from_config_valid(
|
||||||
|
create_temp_yaml: Callable[[str], Path],
|
||||||
|
):
|
||||||
|
yaml_content = """
|
||||||
|
classes:
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
output_tags:
|
||||||
|
- key: call_type
|
||||||
|
value: Echolocation
|
||||||
|
generic_class:
|
||||||
|
- key: order
|
||||||
|
value: Chiroptera
|
||||||
|
"""
|
||||||
|
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||||
|
decoder = load_decoder_from_config(
|
||||||
|
temp_yaml_path,
|
||||||
|
)
|
||||||
|
tags = decoder("pippip")
|
||||||
|
assert len(tags) == 1
|
||||||
|
assert tags[0].term == get_term("call_type")
|
||||||
|
assert tags[0].value == "Echolocation"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_generic_class_tags_from_config():
|
||||||
|
config = ClassesConfig(
|
||||||
|
classes=[
|
||||||
|
TargetClass(
|
||||||
|
name="pippip",
|
||||||
|
tags=[
|
||||||
|
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
generic_class=[
|
||||||
|
TagInfo(key="order", value="Chiroptera"),
|
||||||
|
TagInfo(key="call_type", value="Echolocation"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
generic_tags = build_generic_class_tags(config)
|
||||||
|
assert len(generic_tags) == 2
|
||||||
|
assert generic_tags[0].term == get_term("order")
|
||||||
|
assert generic_tags[0].value == "Chiroptera"
|
||||||
|
assert generic_tags[1].term == get_term("call_type")
|
||||||
|
assert generic_tags[1].value == "Echolocation"
|
||||||
|
|||||||
210
tests/test_targets/test_filtering.py
Normal file
210
tests/test_targets/test_filtering.py
Normal file
@ -0,0 +1,210 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, List, Set
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.targets import build_targets
|
||||||
|
from batdetect2.targets.filtering import (
|
||||||
|
FilterConfig,
|
||||||
|
FilterRule,
|
||||||
|
build_filter_from_rule,
|
||||||
|
build_sound_event_filter,
|
||||||
|
contains_tags,
|
||||||
|
does_not_have_tags,
|
||||||
|
equal_tags,
|
||||||
|
has_any_tag,
|
||||||
|
load_filter_config,
|
||||||
|
load_filter_from_config,
|
||||||
|
)
|
||||||
|
from batdetect2.targets.terms import TagInfo, generic_class
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def create_annotation(
|
||||||
|
sound_event: data.SoundEvent,
|
||||||
|
) -> Callable[[List[str]], data.SoundEventAnnotation]:
|
||||||
|
"""Helper function to create a SoundEventAnnotation with given tags."""
|
||||||
|
|
||||||
|
def factory(tags: List[str]) -> data.SoundEventAnnotation:
|
||||||
|
return data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[
|
||||||
|
data.Tag(
|
||||||
|
term=generic_class,
|
||||||
|
value=tag,
|
||||||
|
)
|
||||||
|
for tag in tags
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
return factory
|
||||||
|
|
||||||
|
|
||||||
|
def create_tag_set(tags: List[str]) -> Set[data.Tag]:
|
||||||
|
"""Helper function to create a set of data.Tag objects from a list of strings."""
|
||||||
|
return {
|
||||||
|
data.Tag(
|
||||||
|
term=generic_class,
|
||||||
|
value=tag,
|
||||||
|
)
|
||||||
|
for tag in tags
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_any_tag(create_annotation):
|
||||||
|
annotation = create_annotation(["tag1", "tag2"])
|
||||||
|
tags = create_tag_set(["tag1", "tag3"])
|
||||||
|
assert has_any_tag(annotation, tags) is True
|
||||||
|
|
||||||
|
annotation = create_annotation(["tag2", "tag4"])
|
||||||
|
tags = create_tag_set(["tag1", "tag3"])
|
||||||
|
assert has_any_tag(annotation, tags) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_contains_tags(create_annotation):
|
||||||
|
annotation = create_annotation(["tag1", "tag2", "tag3"])
|
||||||
|
tags = create_tag_set(["tag1", "tag2"])
|
||||||
|
assert contains_tags(annotation, tags) is True
|
||||||
|
|
||||||
|
annotation = create_annotation(["tag1", "tag2"])
|
||||||
|
tags = create_tag_set(["tag1", "tag2", "tag3"])
|
||||||
|
assert contains_tags(annotation, tags) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_does_not_have_tags(create_annotation):
|
||||||
|
annotation = create_annotation(["tag1", "tag2"])
|
||||||
|
tags = create_tag_set(["tag3", "tag4"])
|
||||||
|
assert does_not_have_tags(annotation, tags) is True
|
||||||
|
|
||||||
|
annotation = create_annotation(["tag1", "tag2"])
|
||||||
|
tags = create_tag_set(["tag1", "tag3"])
|
||||||
|
assert does_not_have_tags(annotation, tags) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_equal_tags(create_annotation):
|
||||||
|
annotation = create_annotation(["tag1", "tag2"])
|
||||||
|
tags = create_tag_set(["tag1", "tag2"])
|
||||||
|
assert equal_tags(annotation, tags) is True
|
||||||
|
|
||||||
|
annotation = create_annotation(["tag1", "tag2", "tag3"])
|
||||||
|
tags = create_tag_set(["tag1", "tag2"])
|
||||||
|
assert equal_tags(annotation, tags) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_filter_from_rule():
|
||||||
|
rule_any = FilterRule(match_type="any", tags=[TagInfo(value="tag1")])
|
||||||
|
build_filter_from_rule(rule_any)
|
||||||
|
|
||||||
|
rule_all = FilterRule(match_type="all", tags=[TagInfo(value="tag1")])
|
||||||
|
build_filter_from_rule(rule_all)
|
||||||
|
|
||||||
|
rule_exclude = FilterRule(
|
||||||
|
match_type="exclude", tags=[TagInfo(value="tag1")]
|
||||||
|
)
|
||||||
|
build_filter_from_rule(rule_exclude)
|
||||||
|
|
||||||
|
rule_equal = FilterRule(match_type="equal", tags=[TagInfo(value="tag1")])
|
||||||
|
build_filter_from_rule(rule_equal)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
FilterRule(match_type="invalid", tags=[TagInfo(value="tag1")]) # type: ignore
|
||||||
|
build_filter_from_rule(
|
||||||
|
FilterRule(match_type="invalid", tags=[TagInfo(value="tag1")]) # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_filter_from_config(create_annotation):
|
||||||
|
config = FilterConfig(
|
||||||
|
rules=[
|
||||||
|
FilterRule(match_type="any", tags=[TagInfo(value="tag1")]),
|
||||||
|
FilterRule(match_type="any", tags=[TagInfo(value="tag2")]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
filter_from_config = build_sound_event_filter(config)
|
||||||
|
|
||||||
|
annotation_pass = create_annotation(["tag1", "tag2"])
|
||||||
|
assert filter_from_config(annotation_pass)
|
||||||
|
|
||||||
|
annotation_fail = create_annotation(["tag1"])
|
||||||
|
assert not filter_from_config(annotation_fail)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_filter_config(tmp_path: Path):
|
||||||
|
test_config_path = tmp_path / "filtering.yaml"
|
||||||
|
test_config_path.write_text(
|
||||||
|
"""
|
||||||
|
rules:
|
||||||
|
- match_type: any
|
||||||
|
tags:
|
||||||
|
- value: tag1
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
config = load_filter_config(test_config_path)
|
||||||
|
assert isinstance(config, FilterConfig)
|
||||||
|
assert len(config.rules) == 1
|
||||||
|
rule = config.rules[0]
|
||||||
|
assert rule.match_type == "any"
|
||||||
|
assert len(rule.tags) == 1
|
||||||
|
assert rule.tags[0].value == "tag1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_filter_from_config(tmp_path: Path, create_annotation):
|
||||||
|
test_config_path = tmp_path / "filtering.yaml"
|
||||||
|
test_config_path.write_text(
|
||||||
|
"""
|
||||||
|
rules:
|
||||||
|
- match_type: any
|
||||||
|
tags:
|
||||||
|
- value: tag1
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
filter_result = load_filter_from_config(test_config_path)
|
||||||
|
annotation = create_annotation(["tag1", "tag3"])
|
||||||
|
assert filter_result(annotation)
|
||||||
|
|
||||||
|
test_config_path = tmp_path / "filtering.yaml"
|
||||||
|
test_config_path.write_text(
|
||||||
|
"""
|
||||||
|
rules:
|
||||||
|
- match_type: any
|
||||||
|
tags:
|
||||||
|
- value: tag2
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
filter_result = load_filter_from_config(test_config_path)
|
||||||
|
annotation = create_annotation(["tag1", "tag3"])
|
||||||
|
assert filter_result(annotation) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_filtering_over_example_dataset(
|
||||||
|
example_annotations: List[data.ClipAnnotation],
|
||||||
|
):
|
||||||
|
targets = build_targets()
|
||||||
|
|
||||||
|
clip1 = example_annotations[0]
|
||||||
|
clip2 = example_annotations[1]
|
||||||
|
clip3 = example_annotations[2]
|
||||||
|
|
||||||
|
assert (
|
||||||
|
sum(
|
||||||
|
[targets.filter(sound_event) for sound_event in clip1.sound_events]
|
||||||
|
)
|
||||||
|
== 9
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
sum(
|
||||||
|
[targets.filter(sound_event) for sound_event in clip2.sound_events]
|
||||||
|
)
|
||||||
|
== 15
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
sum(
|
||||||
|
[targets.filter(sound_event) for sound_event in clip3.sound_events]
|
||||||
|
)
|
||||||
|
== 20
|
||||||
|
)
|
||||||
@ -8,11 +8,6 @@ from batdetect2.preprocess import (
|
|||||||
build_preprocessor,
|
build_preprocessor,
|
||||||
)
|
)
|
||||||
from batdetect2.preprocess.audio import build_audio_loader
|
from batdetect2.preprocess.audio import build_audio_loader
|
||||||
from batdetect2.preprocess.spectrogram import (
|
|
||||||
ScaleAmplitudeConfig,
|
|
||||||
SpectralMeanSubstractionConfig,
|
|
||||||
SpectrogramConfig,
|
|
||||||
)
|
|
||||||
from batdetect2.targets.rois import (
|
from batdetect2.targets.rois import (
|
||||||
DEFAULT_ANCHOR,
|
DEFAULT_ANCHOR,
|
||||||
DEFAULT_FREQUENCY_SCALE,
|
DEFAULT_FREQUENCY_SCALE,
|
||||||
@ -553,7 +548,14 @@ def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle):
|
|||||||
|
|
||||||
# Instantiate the mapper.
|
# Instantiate the mapper.
|
||||||
preprocessor = build_preprocessor(
|
preprocessor = build_preprocessor(
|
||||||
PreprocessingConfig(spectrogram=SpectrogramConfig(transforms=[]))
|
PreprocessingConfig.model_validate(
|
||||||
|
{
|
||||||
|
"spectrogram": {
|
||||||
|
"pcen": None,
|
||||||
|
"spectral_mean_substraction": False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
audio_loader = build_audio_loader()
|
audio_loader = build_audio_loader()
|
||||||
mapper = PeakEnergyBBoxMapper(
|
mapper = PeakEnergyBBoxMapper(
|
||||||
@ -595,13 +597,14 @@ def test_build_roi_mapper_for_anchor_bbox():
|
|||||||
|
|
||||||
def test_build_roi_mapper_for_peak_energy_bbox():
|
def test_build_roi_mapper_for_peak_energy_bbox():
|
||||||
# Given
|
# Given
|
||||||
preproc_config = PreprocessingConfig(
|
preproc_config = PreprocessingConfig.model_validate(
|
||||||
spectrogram=SpectrogramConfig(
|
{
|
||||||
transforms=[
|
"spectrogram": {
|
||||||
ScaleAmplitudeConfig(scale="db"),
|
"pcen": None,
|
||||||
SpectralMeanSubstractionConfig(),
|
"spectral_mean_substraction": True,
|
||||||
]
|
"scale": "dB",
|
||||||
),
|
}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
config = PeakEnergyBBoxMapperConfig(
|
config = PeakEnergyBBoxMapperConfig(
|
||||||
loading_buffer=0.99,
|
loading_buffer=0.99,
|
||||||
|
|||||||
@ -1,55 +1,46 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from soundevent import data, terms
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.targets import build_targets, load_target_config
|
from batdetect2.targets import build_targets, load_target_config
|
||||||
|
from batdetect2.targets.terms import get_term_from_key
|
||||||
|
|
||||||
|
|
||||||
def test_can_override_default_roi_mapper_per_class(
|
def test_can_override_default_roi_mapper_per_class(
|
||||||
create_temp_yaml: Callable[..., Path],
|
create_temp_yaml: Callable[..., Path],
|
||||||
recording: data.Recording,
|
recording: data.Recording,
|
||||||
|
sample_term_registry,
|
||||||
):
|
):
|
||||||
yaml_content = """
|
yaml_content = """
|
||||||
detection_target:
|
|
||||||
name: bat
|
|
||||||
match_if:
|
|
||||||
name: has_tag
|
|
||||||
tag:
|
|
||||||
key: order
|
|
||||||
value: Chiroptera
|
|
||||||
assign_tags:
|
|
||||||
- key: order
|
|
||||||
value: Chiroptera
|
|
||||||
|
|
||||||
classification_targets:
|
|
||||||
- name: pippip
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Pipistrellus pipistrellus
|
|
||||||
|
|
||||||
- name: myomyo
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Myotis myotis
|
|
||||||
roi:
|
|
||||||
name: anchor_bbox
|
|
||||||
anchor: top-left
|
|
||||||
|
|
||||||
roi:
|
roi:
|
||||||
name: anchor_bbox
|
name: anchor_bbox
|
||||||
anchor: bottom-left
|
anchor: bottom-left
|
||||||
|
classes:
|
||||||
|
classes:
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
- name: myomyo
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
roi:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: top-left
|
||||||
|
generic_class:
|
||||||
|
- key: order
|
||||||
|
value: Chiroptera
|
||||||
"""
|
"""
|
||||||
config_path = create_temp_yaml(yaml_content)
|
config_path = create_temp_yaml(yaml_content)
|
||||||
|
|
||||||
config = load_target_config(config_path)
|
config = load_target_config(config_path)
|
||||||
targets = build_targets(config)
|
targets = build_targets(config, term_registry=sample_term_registry)
|
||||||
|
|
||||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||||
|
|
||||||
species = terms.get_term("species")
|
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||||
assert species is not None
|
|
||||||
|
|
||||||
se1 = data.SoundEventAnnotation(
|
se1 = data.SoundEventAnnotation(
|
||||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||||
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||||
@ -71,47 +62,38 @@ def test_can_override_default_roi_mapper_per_class(
|
|||||||
# TODO: rename this test function
|
# TODO: rename this test function
|
||||||
def test_roi_is_recovered_roundtrip_even_with_overriders(
|
def test_roi_is_recovered_roundtrip_even_with_overriders(
|
||||||
create_temp_yaml,
|
create_temp_yaml,
|
||||||
|
sample_term_registry,
|
||||||
recording,
|
recording,
|
||||||
):
|
):
|
||||||
yaml_content = """
|
yaml_content = """
|
||||||
detection_target:
|
|
||||||
name: bat
|
|
||||||
match_if:
|
|
||||||
name: has_tag
|
|
||||||
tag:
|
|
||||||
key: order
|
|
||||||
value: Chiroptera
|
|
||||||
assign_tags:
|
|
||||||
- key: order
|
|
||||||
value: Chiroptera
|
|
||||||
|
|
||||||
classification_targets:
|
|
||||||
- name: pippip
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Pipistrellus pipistrellus
|
|
||||||
|
|
||||||
- name: myomyo
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Myotis myotis
|
|
||||||
roi:
|
|
||||||
name: anchor_bbox
|
|
||||||
anchor: top-left
|
|
||||||
|
|
||||||
roi:
|
roi:
|
||||||
name: anchor_bbox
|
name: anchor_bbox
|
||||||
anchor: bottom-left
|
anchor: bottom-left
|
||||||
|
classes:
|
||||||
|
classes:
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
- name: myomyo
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
roi:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: top-left
|
||||||
|
generic_class:
|
||||||
|
- key: order
|
||||||
|
value: Chiroptera
|
||||||
"""
|
"""
|
||||||
config_path = create_temp_yaml(yaml_content)
|
config_path = create_temp_yaml(yaml_content)
|
||||||
|
|
||||||
config = load_target_config(config_path)
|
config = load_target_config(config_path)
|
||||||
targets = build_targets(config)
|
targets = build_targets(config, term_registry=sample_term_registry)
|
||||||
|
|
||||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||||
|
|
||||||
species = terms.get_term("species")
|
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||||
assert species is not None
|
|
||||||
se1 = data.SoundEventAnnotation(
|
se1 = data.SoundEventAnnotation(
|
||||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||||
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||||
|
|||||||
179
tests/test_targets/test_terms.py
Normal file
179
tests/test_targets/test_terms.py
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.targets import terms
|
||||||
|
from batdetect2.targets.terms import (
|
||||||
|
TagInfo,
|
||||||
|
TermRegistry,
|
||||||
|
load_terms_from_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_term_registry_initialization():
|
||||||
|
registry = TermRegistry()
|
||||||
|
assert registry._terms == {}
|
||||||
|
|
||||||
|
initial_terms = {
|
||||||
|
"test_term": data.Term(name="test", label="Test", definition="test")
|
||||||
|
}
|
||||||
|
registry = TermRegistry(terms=initial_terms)
|
||||||
|
assert registry._terms == initial_terms
|
||||||
|
|
||||||
|
|
||||||
|
def test_term_registry_add_term():
|
||||||
|
registry = TermRegistry()
|
||||||
|
term = data.Term(name="test", label="Test", definition="test")
|
||||||
|
registry.add_term("test_key", term)
|
||||||
|
assert registry._terms["test_key"] == term
|
||||||
|
|
||||||
|
|
||||||
|
def test_term_registry_get_term():
|
||||||
|
registry = TermRegistry()
|
||||||
|
term = data.Term(name="test", label="Test", definition="test")
|
||||||
|
registry.add_term("test_key", term)
|
||||||
|
retrieved_term = registry.get_term("test_key")
|
||||||
|
assert retrieved_term == term
|
||||||
|
|
||||||
|
|
||||||
|
def test_term_registry_add_custom_term():
|
||||||
|
registry = TermRegistry()
|
||||||
|
term = registry.add_custom_term(
|
||||||
|
"custom_key", name="custom", label="Custom", definition="A custom term"
|
||||||
|
)
|
||||||
|
assert registry._terms["custom_key"] == term
|
||||||
|
assert term.name == "custom"
|
||||||
|
assert term.label == "Custom"
|
||||||
|
assert term.definition == "A custom term"
|
||||||
|
|
||||||
|
|
||||||
|
def test_term_registry_add_duplicate_term():
|
||||||
|
registry = TermRegistry()
|
||||||
|
term = data.Term(name="test", label="Test", definition="test")
|
||||||
|
registry.add_term("test_key", term)
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
registry.add_term("test_key", term)
|
||||||
|
|
||||||
|
|
||||||
|
def test_term_registry_get_term_not_found():
|
||||||
|
registry = TermRegistry()
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
registry.get_term("non_existent_key")
|
||||||
|
|
||||||
|
|
||||||
|
def test_term_registry_get_keys():
|
||||||
|
registry = TermRegistry()
|
||||||
|
term1 = data.Term(name="test1", label="Test1", definition="test")
|
||||||
|
term2 = data.Term(name="test2", label="Test2", definition="test")
|
||||||
|
registry.add_term("key1", term1)
|
||||||
|
registry.add_term("key2", term2)
|
||||||
|
keys = registry.get_keys()
|
||||||
|
assert set(keys) == {"key1", "key2"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_term_from_key():
|
||||||
|
term = terms.get_term_from_key("event")
|
||||||
|
assert term == terms.call_type
|
||||||
|
|
||||||
|
custom_registry = TermRegistry()
|
||||||
|
custom_term = data.Term(name="custom", label="Custom", definition="test")
|
||||||
|
custom_registry.add_term("custom_key", custom_term)
|
||||||
|
term = terms.get_term_from_key("custom_key", term_registry=custom_registry)
|
||||||
|
assert term == custom_term
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_term_keys():
|
||||||
|
keys = terms.get_term_keys()
|
||||||
|
assert "event" in keys
|
||||||
|
assert "individual" in keys
|
||||||
|
assert terms.GENERIC_CLASS_KEY in keys
|
||||||
|
|
||||||
|
custom_registry = TermRegistry()
|
||||||
|
custom_term = data.Term(name="custom", label="Custom", definition="test")
|
||||||
|
custom_registry.add_term("custom_key", custom_term)
|
||||||
|
keys = terms.get_term_keys(term_registry=custom_registry)
|
||||||
|
assert "custom_key" in keys
|
||||||
|
|
||||||
|
|
||||||
|
def test_tag_info_and_get_tag_from_info():
|
||||||
|
tag_info = TagInfo(value="Myotis myotis", key="event")
|
||||||
|
tag = terms.get_tag_from_info(tag_info)
|
||||||
|
assert tag.value == "Myotis myotis"
|
||||||
|
assert tag.term == terms.call_type
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_tag_from_info_key_not_found():
|
||||||
|
tag_info = TagInfo(value="test", key="non_existent_key")
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
terms.get_tag_from_info(tag_info)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_terms_from_config(tmp_path):
|
||||||
|
term_registry = TermRegistry()
|
||||||
|
config_data = {
|
||||||
|
"terms": [
|
||||||
|
{
|
||||||
|
"key": "species",
|
||||||
|
"name": "dwc:scientificName",
|
||||||
|
"label": "Scientific Name",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"key": "my_custom_term",
|
||||||
|
"name": "soundevent:custom_term",
|
||||||
|
"definition": "Describes a specific project attribute",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
config_file = tmp_path / "config.yaml"
|
||||||
|
with open(config_file, "w") as f:
|
||||||
|
yaml.dump(config_data, f)
|
||||||
|
|
||||||
|
loaded_terms = load_terms_from_config(
|
||||||
|
config_file,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
assert "species" in loaded_terms
|
||||||
|
assert "my_custom_term" in loaded_terms
|
||||||
|
assert loaded_terms["species"].name == "dwc:scientificName"
|
||||||
|
assert loaded_terms["my_custom_term"].name == "soundevent:custom_term"
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_terms_from_config_file_not_found():
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
load_terms_from_config("non_existent_file.yaml")
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_terms_from_config_validation_error(tmp_path):
|
||||||
|
config_data = {
|
||||||
|
"terms": [
|
||||||
|
{
|
||||||
|
"key": "species",
|
||||||
|
"uri": "dwc:scientificName",
|
||||||
|
"label": 123,
|
||||||
|
}, # Invalid label type
|
||||||
|
]
|
||||||
|
}
|
||||||
|
config_file = tmp_path / "config.yaml"
|
||||||
|
with open(config_file, "w") as f:
|
||||||
|
yaml.dump(config_data, f)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
load_terms_from_config(config_file)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_terms_from_config_key_already_exists(tmp_path):
|
||||||
|
config_data = {
|
||||||
|
"terms": [
|
||||||
|
{
|
||||||
|
"key": "event",
|
||||||
|
"uri": "dwc:scientificName",
|
||||||
|
"label": "Scientific Name",
|
||||||
|
}, # Duplicate key
|
||||||
|
]
|
||||||
|
}
|
||||||
|
config_file = tmp_path / "config.yaml"
|
||||||
|
with open(config_file, "w") as f:
|
||||||
|
yaml.dump(config_data, f)
|
||||||
|
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
load_terms_from_config(config_file)
|
||||||
363
tests/test_targets/test_transform.py
Normal file
363
tests/test_targets/test_transform.py
Normal file
@ -0,0 +1,363 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.targets import (
|
||||||
|
DeriveTagRule,
|
||||||
|
MapValueRule,
|
||||||
|
ReplaceRule,
|
||||||
|
TagInfo,
|
||||||
|
TransformConfig,
|
||||||
|
build_transformation_from_config,
|
||||||
|
)
|
||||||
|
from batdetect2.targets.terms import TermRegistry
|
||||||
|
from batdetect2.targets.transform import (
|
||||||
|
DerivationRegistry,
|
||||||
|
build_transform_from_rule,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def term_registry():
|
||||||
|
return TermRegistry()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def derivation_registry():
|
||||||
|
return DerivationRegistry()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def term1(term_registry: TermRegistry) -> data.Term:
|
||||||
|
return term_registry.add_custom_term(key="term1")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def term2(term_registry: TermRegistry) -> data.Term:
|
||||||
|
return term_registry.add_custom_term(key="term2")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def annotation(
|
||||||
|
sound_event: data.SoundEvent,
|
||||||
|
term1: data.Term,
|
||||||
|
) -> data.SoundEventAnnotation:
|
||||||
|
return data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event, tags=[data.Tag(term=term1, value="value1")]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_map_value_rule(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
):
|
||||||
|
rule = MapValueRule(
|
||||||
|
rule_type="map_value",
|
||||||
|
source_term_key="term1",
|
||||||
|
value_mapping={"value1": "value2"},
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
assert transformed_annotation.tags[0].value == "value2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_map_value_rule_no_match(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
):
|
||||||
|
rule = MapValueRule(
|
||||||
|
rule_type="map_value",
|
||||||
|
source_term_key="term1",
|
||||||
|
value_mapping={"other_value": "value2"},
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
assert transformed_annotation.tags[0].value == "value1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_replace_rule(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term2: data.Term,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
):
|
||||||
|
rule = ReplaceRule(
|
||||||
|
rule_type="replace",
|
||||||
|
original=TagInfo(key="term1", value="value1"),
|
||||||
|
replacement=TagInfo(key="term2", value="value2"),
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
assert transformed_annotation.tags[0].term == term2
|
||||||
|
assert transformed_annotation.tags[0].value == "value2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_replace_rule_no_match(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
term2: data.Term,
|
||||||
|
):
|
||||||
|
rule = ReplaceRule(
|
||||||
|
rule_type="replace",
|
||||||
|
original=TagInfo(key="term1", value="wrong_value"),
|
||||||
|
replacement=TagInfo(key="term2", value="value2"),
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
assert transformed_annotation.tags[0].key == "term1"
|
||||||
|
assert transformed_annotation.tags[0].term != term2
|
||||||
|
assert transformed_annotation.tags[0].value == "value1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_transformation_from_config(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
):
|
||||||
|
config = TransformConfig(
|
||||||
|
rules=[
|
||||||
|
MapValueRule(
|
||||||
|
rule_type="map_value",
|
||||||
|
source_term_key="term1",
|
||||||
|
value_mapping={"value1": "value2"},
|
||||||
|
),
|
||||||
|
ReplaceRule(
|
||||||
|
rule_type="replace",
|
||||||
|
original=TagInfo(key="term2", value="value2"),
|
||||||
|
replacement=TagInfo(key="term3", value="value3"),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
term_registry.add_custom_term("term2")
|
||||||
|
term_registry.add_custom_term("term3")
|
||||||
|
transform = build_transformation_from_config(
|
||||||
|
config,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
transformed_annotation = transform(annotation)
|
||||||
|
assert transformed_annotation.tags[0].key == "term1"
|
||||||
|
assert transformed_annotation.tags[0].value == "value2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_derive_tag_rule(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
derivation_registry: DerivationRegistry,
|
||||||
|
term1: data.Term,
|
||||||
|
):
|
||||||
|
def derivation_func(x: str) -> str:
|
||||||
|
return x + "_derived"
|
||||||
|
|
||||||
|
derivation_registry.register("my_derivation", derivation_func)
|
||||||
|
|
||||||
|
rule = DeriveTagRule(
|
||||||
|
rule_type="derive_tag",
|
||||||
|
source_term_key="term1",
|
||||||
|
derivation_function="my_derivation",
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(
|
||||||
|
rule,
|
||||||
|
term_registry=term_registry,
|
||||||
|
derivation_registry=derivation_registry,
|
||||||
|
)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
|
||||||
|
assert len(transformed_annotation.tags) == 2
|
||||||
|
assert transformed_annotation.tags[0].term == term1
|
||||||
|
assert transformed_annotation.tags[0].value == "value1"
|
||||||
|
assert transformed_annotation.tags[1].term == term1
|
||||||
|
assert transformed_annotation.tags[1].value == "value1_derived"
|
||||||
|
|
||||||
|
|
||||||
|
def test_derive_tag_rule_keep_source_false(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
derivation_registry: DerivationRegistry,
|
||||||
|
term1: data.Term,
|
||||||
|
):
|
||||||
|
def derivation_func(x: str) -> str:
|
||||||
|
return x + "_derived"
|
||||||
|
|
||||||
|
derivation_registry.register("my_derivation", derivation_func)
|
||||||
|
|
||||||
|
rule = DeriveTagRule(
|
||||||
|
rule_type="derive_tag",
|
||||||
|
source_term_key="term1",
|
||||||
|
derivation_function="my_derivation",
|
||||||
|
keep_source=False,
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(
|
||||||
|
rule,
|
||||||
|
term_registry=term_registry,
|
||||||
|
derivation_registry=derivation_registry,
|
||||||
|
)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
|
||||||
|
assert len(transformed_annotation.tags) == 1
|
||||||
|
assert transformed_annotation.tags[0].term == term1
|
||||||
|
assert transformed_annotation.tags[0].value == "value1_derived"
|
||||||
|
|
||||||
|
|
||||||
|
def test_derive_tag_rule_target_term(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
derivation_registry: DerivationRegistry,
|
||||||
|
term1: data.Term,
|
||||||
|
term2: data.Term,
|
||||||
|
):
|
||||||
|
def derivation_func(x: str) -> str:
|
||||||
|
return x + "_derived"
|
||||||
|
|
||||||
|
derivation_registry.register("my_derivation", derivation_func)
|
||||||
|
|
||||||
|
rule = DeriveTagRule(
|
||||||
|
rule_type="derive_tag",
|
||||||
|
source_term_key="term1",
|
||||||
|
derivation_function="my_derivation",
|
||||||
|
target_term_key="term2",
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(
|
||||||
|
rule,
|
||||||
|
term_registry=term_registry,
|
||||||
|
derivation_registry=derivation_registry,
|
||||||
|
)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
|
||||||
|
assert len(transformed_annotation.tags) == 2
|
||||||
|
assert transformed_annotation.tags[0].term == term1
|
||||||
|
assert transformed_annotation.tags[0].value == "value1"
|
||||||
|
assert transformed_annotation.tags[1].term == term2
|
||||||
|
assert transformed_annotation.tags[1].value == "value1_derived"
|
||||||
|
|
||||||
|
|
||||||
|
def test_derive_tag_rule_import_derivation(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
term1: data.Term,
|
||||||
|
tmp_path: Path,
|
||||||
|
):
|
||||||
|
# Create a dummy derivation function in a temporary file
|
||||||
|
derivation_module_path = (
|
||||||
|
tmp_path / "temp_derivation.py"
|
||||||
|
) # Changed to /tmp since /home/santiago is not writable
|
||||||
|
derivation_module_path.write_text(
|
||||||
|
"""
|
||||||
|
def my_imported_derivation(x: str) -> str:
|
||||||
|
return x + "_imported"
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
# Ensure the temporary file is importable by adding its directory to sys.path
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
rule = DeriveTagRule(
|
||||||
|
rule_type="derive_tag",
|
||||||
|
source_term_key="term1",
|
||||||
|
derivation_function="temp_derivation.my_imported_derivation",
|
||||||
|
import_derivation=True,
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
|
||||||
|
assert len(transformed_annotation.tags) == 2
|
||||||
|
assert transformed_annotation.tags[0].term == term1
|
||||||
|
assert transformed_annotation.tags[0].value == "value1"
|
||||||
|
assert transformed_annotation.tags[1].term == term1
|
||||||
|
assert transformed_annotation.tags[1].value == "value1_imported"
|
||||||
|
|
||||||
|
# Clean up the temporary file and sys.path
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
|
||||||
|
|
||||||
|
def test_derive_tag_rule_invalid_derivation(term_registry: TermRegistry):
|
||||||
|
rule = DeriveTagRule(
|
||||||
|
rule_type="derive_tag",
|
||||||
|
source_term_key="term1",
|
||||||
|
derivation_function="nonexistent_derivation",
|
||||||
|
)
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
build_transform_from_rule(rule, term_registry=term_registry)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_transform_from_rule_invalid_rule_type():
|
||||||
|
class InvalidRule:
|
||||||
|
rule_type = "invalid"
|
||||||
|
|
||||||
|
rule = InvalidRule() # type: ignore
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
build_transform_from_rule(rule) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def test_map_value_rule_target_term(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
term2: data.Term,
|
||||||
|
):
|
||||||
|
rule = MapValueRule(
|
||||||
|
rule_type="map_value",
|
||||||
|
source_term_key="term1",
|
||||||
|
value_mapping={"value1": "value2"},
|
||||||
|
target_term_key="term2",
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
assert transformed_annotation.tags[0].term == term2
|
||||||
|
assert transformed_annotation.tags[0].value == "value2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_map_value_rule_target_term_none(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
term1: data.Term,
|
||||||
|
):
|
||||||
|
rule = MapValueRule(
|
||||||
|
rule_type="map_value",
|
||||||
|
source_term_key="term1",
|
||||||
|
value_mapping={"value1": "value2"},
|
||||||
|
target_term_key=None,
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
assert transformed_annotation.tags[0].term == term1
|
||||||
|
assert transformed_annotation.tags[0].value == "value2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_derive_tag_rule_target_term_none(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
derivation_registry: DerivationRegistry,
|
||||||
|
term1: data.Term,
|
||||||
|
):
|
||||||
|
def derivation_func(x: str) -> str:
|
||||||
|
return x + "_derived"
|
||||||
|
|
||||||
|
derivation_registry.register("my_derivation", derivation_func)
|
||||||
|
|
||||||
|
rule = DeriveTagRule(
|
||||||
|
rule_type="derive_tag",
|
||||||
|
source_term_key="term1",
|
||||||
|
derivation_function="my_derivation",
|
||||||
|
target_term_key=None,
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(
|
||||||
|
rule,
|
||||||
|
term_registry=term_registry,
|
||||||
|
derivation_registry=derivation_registry,
|
||||||
|
)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
|
||||||
|
assert len(transformed_annotation.tags) == 2
|
||||||
|
assert transformed_annotation.tags[0].term == term1
|
||||||
|
assert transformed_annotation.tags[0].value == "value1"
|
||||||
|
assert transformed_annotation.tags[1].term == term1
|
||||||
|
assert transformed_annotation.tags[1].value == "value1_derived"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_transformation_from_config_empty(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
):
|
||||||
|
config = TransformConfig(rules=[])
|
||||||
|
transform = build_transformation_from_config(config)
|
||||||
|
transformed_annotation = transform(annotation)
|
||||||
|
assert transformed_annotation == annotation
|
||||||
170
tests/test_train/test_augmentations.py
Normal file
170
tests/test_train/test_augmentations.py
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.train.augmentations import (
|
||||||
|
add_echo,
|
||||||
|
mix_audio,
|
||||||
|
)
|
||||||
|
from batdetect2.train.clips import select_subclip
|
||||||
|
from batdetect2.train.preprocess import generate_train_example
|
||||||
|
from batdetect2.typing import AudioLoader, ClipLabeller, PreprocessorProtocol
|
||||||
|
|
||||||
|
|
||||||
|
def test_mix_examples(
|
||||||
|
sample_preprocessor: PreprocessorProtocol,
|
||||||
|
sample_audio_loader: AudioLoader,
|
||||||
|
sample_labeller: ClipLabeller,
|
||||||
|
create_recording: Callable[..., data.Recording],
|
||||||
|
):
|
||||||
|
recording1 = create_recording()
|
||||||
|
recording2 = create_recording()
|
||||||
|
|
||||||
|
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
||||||
|
clip2 = data.Clip(recording=recording2, start_time=0.3, end_time=0.8)
|
||||||
|
|
||||||
|
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
||||||
|
clip_annotation_2 = data.ClipAnnotation(clip=clip2)
|
||||||
|
|
||||||
|
example1 = generate_train_example(
|
||||||
|
clip_annotation_1,
|
||||||
|
audio_loader=sample_audio_loader,
|
||||||
|
preprocessor=sample_preprocessor,
|
||||||
|
labeller=sample_labeller,
|
||||||
|
)
|
||||||
|
example2 = generate_train_example(
|
||||||
|
clip_annotation_2,
|
||||||
|
audio_loader=sample_audio_loader,
|
||||||
|
preprocessor=sample_preprocessor,
|
||||||
|
labeller=sample_labeller,
|
||||||
|
)
|
||||||
|
|
||||||
|
mixed = mix_audio(
|
||||||
|
example1,
|
||||||
|
example2,
|
||||||
|
weight=0.3,
|
||||||
|
preprocessor=sample_preprocessor,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mixed.spectrogram.shape == example1.spectrogram.shape
|
||||||
|
assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
|
||||||
|
assert mixed.size_heatmap.shape == example1.size_heatmap.shape
|
||||||
|
assert mixed.class_heatmap.shape == example1.class_heatmap.shape
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
|
||||||
|
@pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
|
||||||
|
def test_mix_examples_of_different_durations(
|
||||||
|
sample_preprocessor: PreprocessorProtocol,
|
||||||
|
sample_audio_loader: AudioLoader,
|
||||||
|
sample_labeller: ClipLabeller,
|
||||||
|
create_recording: Callable[..., data.Recording],
|
||||||
|
duration1: float,
|
||||||
|
duration2: float,
|
||||||
|
):
|
||||||
|
recording1 = create_recording()
|
||||||
|
recording2 = create_recording()
|
||||||
|
|
||||||
|
clip1 = data.Clip(recording=recording1, start_time=0, end_time=duration1)
|
||||||
|
clip2 = data.Clip(recording=recording2, start_time=0, end_time=duration2)
|
||||||
|
|
||||||
|
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
||||||
|
clip_annotation_2 = data.ClipAnnotation(clip=clip2)
|
||||||
|
|
||||||
|
example1 = generate_train_example(
|
||||||
|
clip_annotation_1,
|
||||||
|
audio_loader=sample_audio_loader,
|
||||||
|
preprocessor=sample_preprocessor,
|
||||||
|
labeller=sample_labeller,
|
||||||
|
)
|
||||||
|
example2 = generate_train_example(
|
||||||
|
clip_annotation_2,
|
||||||
|
audio_loader=sample_audio_loader,
|
||||||
|
preprocessor=sample_preprocessor,
|
||||||
|
labeller=sample_labeller,
|
||||||
|
)
|
||||||
|
|
||||||
|
mixed = mix_audio(
|
||||||
|
example1,
|
||||||
|
example2,
|
||||||
|
weight=0.3,
|
||||||
|
preprocessor=sample_preprocessor,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mixed.spectrogram.shape == example1.spectrogram.shape
|
||||||
|
assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape
|
||||||
|
assert mixed.size_heatmap.shape == example1.size_heatmap.shape
|
||||||
|
assert mixed.class_heatmap.shape == example1.class_heatmap.shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_echo(
|
||||||
|
sample_preprocessor: PreprocessorProtocol,
|
||||||
|
sample_audio_loader: AudioLoader,
|
||||||
|
sample_labeller: ClipLabeller,
|
||||||
|
create_recording: Callable[..., data.Recording],
|
||||||
|
):
|
||||||
|
recording1 = create_recording()
|
||||||
|
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
||||||
|
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
||||||
|
|
||||||
|
original = generate_train_example(
|
||||||
|
clip_annotation_1,
|
||||||
|
audio_loader=sample_audio_loader,
|
||||||
|
preprocessor=sample_preprocessor,
|
||||||
|
labeller=sample_labeller,
|
||||||
|
)
|
||||||
|
with_echo = add_echo(
|
||||||
|
original,
|
||||||
|
preprocessor=sample_preprocessor,
|
||||||
|
delay=0.1,
|
||||||
|
weight=0.3,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert with_echo.spectrogram.shape == original.spectrogram.shape
|
||||||
|
torch.testing.assert_close(
|
||||||
|
with_echo.size_heatmap,
|
||||||
|
original.size_heatmap,
|
||||||
|
atol=0,
|
||||||
|
rtol=0,
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
with_echo.class_heatmap,
|
||||||
|
original.class_heatmap,
|
||||||
|
atol=0,
|
||||||
|
rtol=0,
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
with_echo.detection_heatmap,
|
||||||
|
original.detection_heatmap,
|
||||||
|
atol=0,
|
||||||
|
rtol=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_selected_random_subclip_has_the_correct_width(
|
||||||
|
sample_preprocessor: PreprocessorProtocol,
|
||||||
|
sample_audio_loader: AudioLoader,
|
||||||
|
sample_labeller: ClipLabeller,
|
||||||
|
create_recording: Callable[..., data.Recording],
|
||||||
|
):
|
||||||
|
recording1 = create_recording()
|
||||||
|
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
||||||
|
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
||||||
|
|
||||||
|
original = generate_train_example(
|
||||||
|
clip_annotation_1,
|
||||||
|
audio_loader=sample_audio_loader,
|
||||||
|
preprocessor=sample_preprocessor,
|
||||||
|
labeller=sample_labeller,
|
||||||
|
)
|
||||||
|
|
||||||
|
subclip = select_subclip(
|
||||||
|
original,
|
||||||
|
input_samplerate=256_000,
|
||||||
|
output_samplerate=1000,
|
||||||
|
start=0,
|
||||||
|
duration=0.512,
|
||||||
|
)
|
||||||
|
assert subclip.spectrogram.shape[1] == 512
|
||||||
27
tests/test_train/test_clips.py
Normal file
27
tests/test_train/test_clips.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.train import generate_train_example
|
||||||
|
from batdetect2.typing import (
|
||||||
|
AudioLoader,
|
||||||
|
ClipLabeller,
|
||||||
|
ClipperProtocol,
|
||||||
|
PreprocessorProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_clip_size_is_correct(
|
||||||
|
sample_clipper: ClipperProtocol,
|
||||||
|
sample_labeller: ClipLabeller,
|
||||||
|
sample_audio_loader: AudioLoader,
|
||||||
|
clip_annotation: data.ClipAnnotation,
|
||||||
|
sample_preprocessor: PreprocessorProtocol,
|
||||||
|
):
|
||||||
|
example = generate_train_example(
|
||||||
|
clip_annotation=clip_annotation,
|
||||||
|
audio_loader=sample_audio_loader,
|
||||||
|
preprocessor=sample_preprocessor,
|
||||||
|
labeller=sample_labeller,
|
||||||
|
)
|
||||||
|
|
||||||
|
clip, _, _ = sample_clipper(example)
|
||||||
|
assert clip.spectrogram.shape == (1, 128, 256)
|
||||||
@ -5,6 +5,7 @@ from soundevent import data
|
|||||||
|
|
||||||
from batdetect2.targets import TargetConfig, build_targets
|
from batdetect2.targets import TargetConfig, build_targets
|
||||||
from batdetect2.targets.rois import AnchorBBoxMapperConfig
|
from batdetect2.targets.rois import AnchorBBoxMapperConfig
|
||||||
|
from batdetect2.targets.terms import TagInfo
|
||||||
from batdetect2.train.labels import generate_heatmaps
|
from batdetect2.train.labels import generate_heatmaps
|
||||||
|
|
||||||
recording = data.Recording(
|
recording = data.Recording(
|
||||||
@ -25,8 +26,7 @@ clip = data.Clip(
|
|||||||
|
|
||||||
def test_generated_heatmap_are_non_zero_at_correct_positions(
|
def test_generated_heatmap_are_non_zero_at_correct_positions(
|
||||||
sample_target_config: TargetConfig,
|
sample_target_config: TargetConfig,
|
||||||
pippip_tag: data.Tag,
|
pippip_tag: TagInfo,
|
||||||
bat_tag: data.Tag,
|
|
||||||
):
|
):
|
||||||
config = sample_target_config.model_copy(
|
config = sample_target_config.model_copy(
|
||||||
update=dict(
|
update=dict(
|
||||||
@ -49,14 +49,14 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
|||||||
coordinates=[10, 10, 20, 30],
|
coordinates=[10, 10, 20, 30],
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
tags=[pippip_tag, bat_tag],
|
tags=[data.Tag(key=pippip_tag.key, value=pippip_tag.value)], # type: ignore
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
||||||
clip_annotation,
|
clip_annotation,
|
||||||
torch.rand([1, 100, 100]),
|
torch.rand([100, 100]),
|
||||||
min_freq=0,
|
min_freq=0,
|
||||||
max_freq=100,
|
max_freq=100,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
@ -67,4 +67,4 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
|||||||
assert size_heatmap[1, 10, 10] == 20
|
assert size_heatmap[1, 10, 10] == 20
|
||||||
assert class_heatmap[pippip_index, 10, 10] == 1.0
|
assert class_heatmap[pippip_index, 10, 10] == 1.0
|
||||||
assert class_heatmap[myomyo_index, 10, 10] == 0.0
|
assert class_heatmap[myomyo_index, 10, 10] == 0.0
|
||||||
assert detection_heatmap[0, 10, 10] == 1.0
|
assert detection_heatmap[10, 10] == 1.0
|
||||||
|
|||||||
@ -4,16 +4,14 @@ import lightning as L
|
|||||||
import torch
|
import torch
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.models import build_model
|
|
||||||
from batdetect2.train import FullTrainingConfig, TrainingModule
|
from batdetect2.train import FullTrainingConfig, TrainingModule
|
||||||
from batdetect2.train.train import build_training_module
|
from batdetect2.train.train import build_training_module
|
||||||
from batdetect2.typing.preprocess import AudioLoader
|
from batdetect2.typing.preprocess import AudioLoader
|
||||||
|
|
||||||
|
|
||||||
def build_default_module():
|
def build_default_module():
|
||||||
model = build_model()
|
|
||||||
config = FullTrainingConfig()
|
config = FullTrainingConfig()
|
||||||
return build_training_module(model, config=config)
|
return build_training_module(config)
|
||||||
|
|
||||||
|
|
||||||
def test_can_initialize_default_module():
|
def test_can_initialize_default_module():
|
||||||
@ -34,14 +32,14 @@ def test_can_save_checkpoint(
|
|||||||
|
|
||||||
recovered = TrainingModule.load_from_checkpoint(path)
|
recovered = TrainingModule.load_from_checkpoint(path)
|
||||||
|
|
||||||
wav = torch.tensor(sample_audio_loader.load_clip(clip)).unsqueeze(0)
|
wav = torch.tensor(sample_audio_loader.load_clip(clip))
|
||||||
|
|
||||||
spec1 = module.model.preprocessor(wav)
|
spec1 = module.model.preprocessor(wav)
|
||||||
spec2 = recovered.model.preprocessor(wav)
|
spec2 = recovered.model.preprocessor(wav)
|
||||||
|
|
||||||
torch.testing.assert_close(spec1, spec2, rtol=0, atol=0)
|
torch.testing.assert_close(spec1, spec2, rtol=0, atol=0)
|
||||||
|
|
||||||
output1 = module(spec1.unsqueeze(0))
|
output1 = module(spec1.unsqueeze(0).unsqueeze(0))
|
||||||
output2 = recovered(spec2.unsqueeze(0))
|
output2 = recovered(spec2.unsqueeze(0).unsqueeze(0))
|
||||||
|
|
||||||
torch.testing.assert_close(output1, output2, rtol=0, atol=0)
|
torch.testing.assert_close(output1, output2, rtol=0, atol=0)
|
||||||
|
|||||||
230
tests/test_train/test_preprocessing.py
Normal file
230
tests/test_train/test_preprocessing.py
Normal file
@ -0,0 +1,230 @@
|
|||||||
|
import pytest
|
||||||
|
from soundevent import data
|
||||||
|
from soundevent.terms import get_term
|
||||||
|
|
||||||
|
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
|
||||||
|
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
||||||
|
from batdetect2.targets import build_targets, load_target_config
|
||||||
|
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
||||||
|
from batdetect2.train.preprocess import generate_train_example
|
||||||
|
from batdetect2.typing import ModelOutput
|
||||||
|
from batdetect2.typing.preprocess import AudioLoader
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def build_from_config(
|
||||||
|
create_temp_yaml,
|
||||||
|
):
|
||||||
|
def build(yaml_content):
|
||||||
|
config_path = create_temp_yaml(yaml_content)
|
||||||
|
|
||||||
|
targets_config = load_target_config(config_path, field="targets")
|
||||||
|
preprocessing_config = load_preprocessing_config(
|
||||||
|
config_path,
|
||||||
|
field="preprocessing",
|
||||||
|
)
|
||||||
|
labels_config = load_label_config(config_path, field="labels")
|
||||||
|
postprocessing_config = load_postprocess_config(
|
||||||
|
config_path,
|
||||||
|
field="postprocessing",
|
||||||
|
)
|
||||||
|
|
||||||
|
targets = build_targets(targets_config)
|
||||||
|
preprocessor = build_preprocessor(preprocessing_config)
|
||||||
|
labeller = build_clip_labeler(
|
||||||
|
targets=targets,
|
||||||
|
config=labels_config,
|
||||||
|
min_freq=preprocessor.min_freq,
|
||||||
|
max_freq=preprocessor.max_freq,
|
||||||
|
)
|
||||||
|
postprocessor = build_postprocessor(
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
config=postprocessing_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
return targets, preprocessor, labeller, postprocessor
|
||||||
|
|
||||||
|
return build
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoding_decoding_roundtrip_recovers_object(
|
||||||
|
sample_audio_loader: AudioLoader,
|
||||||
|
build_from_config,
|
||||||
|
recording,
|
||||||
|
):
|
||||||
|
yaml_content = """
|
||||||
|
labels:
|
||||||
|
targets:
|
||||||
|
roi:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: bottom-left
|
||||||
|
classes:
|
||||||
|
classes:
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
generic_class:
|
||||||
|
- key: order
|
||||||
|
value: Chiroptera
|
||||||
|
preprocessing:
|
||||||
|
"""
|
||||||
|
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
||||||
|
|
||||||
|
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
||||||
|
se1 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||||
|
tags=[
|
||||||
|
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||||
|
],
|
||||||
|
)
|
||||||
|
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
||||||
|
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
||||||
|
|
||||||
|
encoded = generate_train_example(
|
||||||
|
clip_annotation,
|
||||||
|
sample_audio_loader,
|
||||||
|
preprocessor,
|
||||||
|
labeller,
|
||||||
|
)
|
||||||
|
predictions = postprocessor.get_predictions(
|
||||||
|
ModelOutput(
|
||||||
|
detection_probs=encoded.detection_heatmap.unsqueeze(0).unsqueeze(
|
||||||
|
0
|
||||||
|
),
|
||||||
|
size_preds=encoded.size_heatmap.unsqueeze(0),
|
||||||
|
class_probs=encoded.class_heatmap.unsqueeze(0),
|
||||||
|
features=encoded.spectrogram.unsqueeze(0).unsqueeze(0),
|
||||||
|
),
|
||||||
|
[clip],
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
assert isinstance(predictions, data.ClipPrediction)
|
||||||
|
assert len(predictions.sound_events) == 1
|
||||||
|
|
||||||
|
recovered = predictions.sound_events[0]
|
||||||
|
assert recovered.sound_event.geometry is not None
|
||||||
|
assert isinstance(recovered.sound_event.geometry, data.BoundingBox)
|
||||||
|
start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
|
||||||
|
recovered.sound_event.geometry.coordinates
|
||||||
|
)
|
||||||
|
start_time_or, low_freq_or, end_time_or, high_freq_or = (
|
||||||
|
geometry.coordinates
|
||||||
|
)
|
||||||
|
|
||||||
|
assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
|
||||||
|
assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
|
||||||
|
assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
|
||||||
|
assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)
|
||||||
|
|
||||||
|
assert len(recovered.tags) == 2
|
||||||
|
|
||||||
|
predicted_species_tag = next(
|
||||||
|
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert predicted_species_tag is not None
|
||||||
|
assert predicted_species_tag.score == 1
|
||||||
|
assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"
|
||||||
|
|
||||||
|
predicted_order_tag = next(
|
||||||
|
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert predicted_order_tag is not None
|
||||||
|
assert predicted_order_tag.score == 1
|
||||||
|
assert predicted_order_tag.tag.value == "Chiroptera"
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
|
||||||
|
sample_audio_loader: AudioLoader,
|
||||||
|
build_from_config,
|
||||||
|
recording,
|
||||||
|
):
|
||||||
|
yaml_content = """
|
||||||
|
labels:
|
||||||
|
targets:
|
||||||
|
roi:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: bottom-left
|
||||||
|
classes:
|
||||||
|
classes:
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
- name: myomyo
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
roi:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: top-left
|
||||||
|
generic_class:
|
||||||
|
- key: order
|
||||||
|
value: Chiroptera
|
||||||
|
preprocessing:
|
||||||
|
"""
|
||||||
|
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
||||||
|
|
||||||
|
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
||||||
|
se1 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||||
|
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||||
|
)
|
||||||
|
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
||||||
|
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
||||||
|
|
||||||
|
encoded = generate_train_example(
|
||||||
|
clip_annotation,
|
||||||
|
sample_audio_loader,
|
||||||
|
preprocessor,
|
||||||
|
labeller,
|
||||||
|
)
|
||||||
|
predictions = postprocessor.get_predictions(
|
||||||
|
ModelOutput(
|
||||||
|
detection_probs=encoded.detection_heatmap.unsqueeze(0).unsqueeze(
|
||||||
|
0
|
||||||
|
),
|
||||||
|
size_preds=encoded.size_heatmap.unsqueeze(0),
|
||||||
|
class_probs=encoded.class_heatmap.unsqueeze(0),
|
||||||
|
features=encoded.spectrogram.unsqueeze(0).unsqueeze(0),
|
||||||
|
),
|
||||||
|
[clip],
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
assert isinstance(predictions, data.ClipPrediction)
|
||||||
|
assert len(predictions.sound_events) == 1
|
||||||
|
|
||||||
|
recovered = predictions.sound_events[0]
|
||||||
|
assert recovered.sound_event.geometry is not None
|
||||||
|
assert isinstance(recovered.sound_event.geometry, data.BoundingBox)
|
||||||
|
start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
|
||||||
|
recovered.sound_event.geometry.coordinates
|
||||||
|
)
|
||||||
|
start_time_or, low_freq_or, end_time_or, high_freq_or = (
|
||||||
|
geometry.coordinates
|
||||||
|
)
|
||||||
|
|
||||||
|
assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
|
||||||
|
assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
|
||||||
|
assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
|
||||||
|
assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)
|
||||||
|
|
||||||
|
assert len(recovered.tags) == 2
|
||||||
|
|
||||||
|
predicted_species_tag = next(
|
||||||
|
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert predicted_species_tag is not None
|
||||||
|
assert predicted_species_tag.score == 1
|
||||||
|
assert predicted_species_tag.tag.value == "Myotis myotis"
|
||||||
|
|
||||||
|
predicted_order_tag = next(
|
||||||
|
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert predicted_order_tag is not None
|
||||||
|
assert predicted_order_tag.score == 1
|
||||||
|
assert predicted_order_tag.tag.value == "Chiroptera"
|
||||||
Loading…
Reference in New Issue
Block a user