mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
WIP
This commit is contained in:
parent
e9e1f7ce2f
commit
f7d6516550
@ -215,6 +215,7 @@ def annotation_to_sound_event(
|
||||
def file_annotation_to_clip(
|
||||
file_annotation: FileAnnotation,
|
||||
audio_dir: Optional[PathLike] = None,
|
||||
label_key: str = "class",
|
||||
) -> data.Clip:
|
||||
"""Convert file annotation to recording."""
|
||||
audio_dir = audio_dir or Path.cwd()
|
||||
@ -227,6 +228,12 @@ def file_annotation_to_clip(
|
||||
recording = data.Recording.from_file(
|
||||
full_path,
|
||||
time_expansion=file_annotation.time_exp,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=data.term_from_key(label_key),
|
||||
value=file_annotation.label,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
return data.Clip(
|
||||
|
@ -1,5 +1,5 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Tuple, Union
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data, io
|
||||
@ -8,7 +8,11 @@ from batdetect2.compat.data import (
|
||||
load_annotation_project_from_dir,
|
||||
load_annotation_project_from_file,
|
||||
)
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
|
||||
__all__ = [
|
||||
"load_datasets_from_config",
|
||||
]
|
||||
|
||||
|
||||
class BatDetect2AnnotationFiles(BaseConfig):
|
||||
@ -23,7 +27,7 @@ class BatDetect2AnnotationFile(BaseConfig):
|
||||
|
||||
class AOEFAnnotationFile(BaseConfig):
|
||||
format: Literal["aoef"] = "aoef"
|
||||
annotations_file: Path
|
||||
path: Path
|
||||
|
||||
|
||||
AnnotationFormats = Union[
|
||||
@ -44,25 +48,39 @@ class DatasetsConfig(BaseConfig):
|
||||
test: List[DatasetInfo] = Field(default_factory=list)
|
||||
|
||||
|
||||
def load_dataset(info: DatasetInfo) -> data.AnnotationProject:
|
||||
def load_dataset(
|
||||
info: DatasetInfo,
|
||||
audio_dir: Optional[Path] = None,
|
||||
base_dir: Optional[Path] = None,
|
||||
) -> data.AnnotationProject:
|
||||
audio_dir = (
|
||||
info.audio_dir if base_dir is None else base_dir / info.audio_dir
|
||||
)
|
||||
|
||||
path = (
|
||||
info.annotations.path
|
||||
if base_dir is None
|
||||
else base_dir / info.annotations.path
|
||||
)
|
||||
|
||||
if info.annotations.format == "batdetect2":
|
||||
return load_annotation_project_from_dir(
|
||||
info.annotations.path,
|
||||
path,
|
||||
name=info.name,
|
||||
audio_dir=info.audio_dir,
|
||||
audio_dir=audio_dir,
|
||||
)
|
||||
|
||||
if info.annotations.format == "batdetect2_file":
|
||||
return load_annotation_project_from_file(
|
||||
info.annotations.path,
|
||||
path,
|
||||
name=info.name,
|
||||
audio_dir=info.audio_dir,
|
||||
audio_dir=audio_dir,
|
||||
)
|
||||
|
||||
if info.annotations.format == "aoef":
|
||||
return io.load( # type: ignore
|
||||
info.annotations.annotations_file,
|
||||
audio_dir=info.audio_dir,
|
||||
info.annotations.path,
|
||||
audio_dir=audio_dir,
|
||||
)
|
||||
|
||||
raise NotImplementedError(
|
||||
@ -72,16 +90,30 @@ def load_dataset(info: DatasetInfo) -> data.AnnotationProject:
|
||||
|
||||
def load_datasets(
    config: DatasetsConfig,
    base_dir: Optional[Path] = None,
) -> Tuple[List[data.ClipAnnotation], List[data.ClipAnnotation]]:
    """Load the clip annotations of every train and test dataset.

    Parameters
    ----------
    config
        Datasets configuration; its ``train`` and ``test`` lists are
        iterated in order.
    base_dir
        Forwarded unchanged to ``load_dataset`` for every dataset.

    Returns
    -------
    tuple
        ``(train_annotations, test_annotations)``, each a flat list with the
        clip annotations of all projects in the corresponding split.
    """
    train_annotations: List[data.ClipAnnotation] = []
    test_annotations: List[data.ClipAnnotation] = []

    # Each dataset is loaded exactly once; the former call without
    # ``base_dir`` loaded the project a first time only to discard it.
    for dataset in config.train:
        project = load_dataset(dataset, base_dir=base_dir)
        train_annotations.extend(project.clip_annotations)

    for dataset in config.test:
        project = load_dataset(dataset, base_dir=base_dir)
        test_annotations.extend(project.clip_annotations)

    return train_annotations, test_annotations
|
||||
|
||||
|
||||
def load_datasets_from_config(
    path: data.PathLike,
    field: Optional[str] = None,
    base_dir: Optional[Path] = None,
):
    """Read a datasets configuration file and load the datasets it lists.

    The file at ``path`` is parsed with ``load_config`` against the
    ``DatasetsConfig`` schema (``field`` is passed straight through), and the
    resulting configuration is handed to ``load_datasets`` together with
    ``base_dir``. Whatever ``load_datasets`` returns is returned as-is.
    """
    datasets_config = load_config(
        path=path,
        schema=DatasetsConfig,
        field=field,
    )
    annotations = load_datasets(datasets_config, base_dir=base_dir)
    return annotations
|
||||
|
@ -28,7 +28,7 @@ class ModelType(str, Enum):
|
||||
|
||||
class ModelConfig(BaseConfig):
|
||||
name: ModelType = ModelType.Net2DFast
|
||||
num_features: int = 128
|
||||
num_features: int = 32
|
||||
|
||||
|
||||
def get_backbone(
|
||||
|
@ -41,8 +41,8 @@ class BBoxHead(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.bbox = nn.Conv2d(
|
||||
self.feature_extractor.out_channels,
|
||||
2,
|
||||
in_channels=self.in_channels,
|
||||
out_channels=2,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
)
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Module for postprocessing model outputs."""
|
||||
|
||||
from typing import Callable, List, Tuple, Union
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -30,15 +30,12 @@ class PostprocessConfig(BaseModel):
|
||||
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
|
||||
|
||||
|
||||
TagFunction = Callable[[int], List[data.Tag]]
|
||||
|
||||
|
||||
def postprocess_model_outputs(
|
||||
outputs: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
classes: List[str],
|
||||
decoder: Callable[[str], List[data.Tag]],
|
||||
config: PostprocessConfig,
|
||||
config: Optional[PostprocessConfig] = None,
|
||||
) -> List[data.ClipPrediction]:
|
||||
"""Postprocesses model outputs to generate clip predictions.
|
||||
|
||||
@ -68,6 +65,9 @@ def postprocess_model_outputs(
|
||||
ValueError
|
||||
If the number of predictions does not match the number of clips.
|
||||
"""
|
||||
|
||||
config = config or PostprocessConfig()
|
||||
|
||||
num_predictions = len(outputs.detection_probs)
|
||||
|
||||
if num_predictions == 0:
|
||||
@ -189,7 +189,7 @@ def compute_sound_events_from_outputs(
|
||||
[
|
||||
data.PredictedTag(
|
||||
tag=tag,
|
||||
score=class_score.item(),
|
||||
score=max(min(class_score.item(), 1), 0),
|
||||
)
|
||||
for tag in corresponding_tags
|
||||
]
|
||||
@ -220,7 +220,7 @@ def compute_sound_events_from_outputs(
|
||||
predictions.append(
|
||||
data.SoundEventPrediction(
|
||||
sound_event=sound_event,
|
||||
score=score.item(),
|
||||
score=max(min(score.item(), 1), 0),
|
||||
tags=predicted_tags,
|
||||
)
|
||||
)
|
||||
|
@ -12,7 +12,7 @@ from batdetect2.configs import BaseConfig
|
||||
|
||||
TARGET_SAMPLERATE_HZ = 256_000
|
||||
SCALE_RAW_AUDIO = False
|
||||
DEFAULT_DURATION = 1
|
||||
DEFAULT_DURATION = None
|
||||
|
||||
|
||||
class ResampleConfig(BaseConfig):
|
||||
|
@ -71,11 +71,12 @@ def compute_spectrogram(
|
||||
if config.size.divide_factor:
|
||||
# Need to pad the audio to make sure the spectrogram has a
|
||||
# width compatible with the divide factor
|
||||
resize_factor = config.size.resize_factor or 1
|
||||
wav = pad_audio(
|
||||
wav,
|
||||
window_duration=config.fft.window_duration,
|
||||
window_overlap=config.fft.window_overlap,
|
||||
divide_factor=config.size.divide_factor,
|
||||
divide_factor=int(config.size.divide_factor / resize_factor),
|
||||
)
|
||||
|
||||
spec = stft(
|
||||
|
@ -20,6 +20,37 @@ class AugmentationConfig(BaseConfig):
|
||||
class SubclipConfig(BaseConfig):
|
||||
enable: bool = True
|
||||
duration: Optional[float] = None
|
||||
width: Optional[int] = 512
|
||||
|
||||
|
||||
def adjust_dataset_width(
    example: xr.Dataset,
    duration: Optional[float] = None,
    width: Optional[int] = None,
) -> xr.Dataset:
    """Adjust every array of a training example to a target time width.

    Parameters
    ----------
    example
        Dataset holding an ``audio`` variable (dimension ``audio_time``) and
        a ``spectrogram`` variable with a ``time`` dimension; all variables
        other than ``audio`` are adjusted along ``time``.
    duration
        Target duration, expressed in the units of the ``time`` coordinate.
        Only used when ``width`` is not given.
    width
        Target number of steps along ``time``. Takes precedence over
        ``duration``.

    Returns
    -------
    xr.Dataset
        New dataset with the adjusted arrays and the attributes of
        ``example``.

    Raises
    ------
    ValueError
        If neither ``duration`` nor ``width`` is provided.
    """
    step = arrays.get_dim_step(example, "time")  # type: ignore

    if width is None:
        if duration is None:
            raise ValueError("Either duration or width must be provided")

        width = int(np.floor(duration / step))

    adjusted_arrays = {
        name: ops.adjust_dim_width(array, "time", width)
        for name, array in example.items()
        if name != "audio"
    }

    # The audio has its own time axis, so it is scaled by the same factor
    # that is applied to the spectrogram width.
    ratio = width / example.spectrogram.sizes["time"]
    audio_width = int(example.audio.sizes["audio_time"] * ratio)
    adjusted_arrays["audio"] = ops.adjust_dim_width(
        example["audio"],
        "audio_time",
        audio_width,
    )

    # Carry the dataset-level attributes over instead of silently dropping
    # them, in line with mix_examples and add_echo.
    return xr.Dataset(data_vars=adjusted_arrays, attrs=example.attrs)
|
||||
|
||||
|
||||
def select_random_subclip(
|
||||
@ -30,6 +61,7 @@ def select_random_subclip(
|
||||
) -> xr.Dataset:
|
||||
"""Select a random subclip from a clip."""
|
||||
step = arrays.get_dim_step(example, "time") # type: ignore
|
||||
start, stop = arrays.get_dim_range(example, "time") # type: ignore
|
||||
|
||||
if width is None:
|
||||
if duration is None:
|
||||
@ -41,15 +73,19 @@ def select_random_subclip(
|
||||
duration = width * step
|
||||
|
||||
if start_time is None:
|
||||
start, stop = arrays.get_dim_range(example, "time") # type: ignore
|
||||
start_time = np.random.uniform(start, stop - duration)
|
||||
start_time = np.random.uniform(start, max(stop - duration, start))
|
||||
|
||||
if start_time + duration > stop:
|
||||
example = adjust_dataset_width(example, width=width)
|
||||
|
||||
start_index = arrays.get_coord_index(
|
||||
example, # type: ignore
|
||||
"time",
|
||||
start_time,
|
||||
)
|
||||
|
||||
end_index = start_index + width - 1
|
||||
|
||||
start_time = example.time.values[start_index]
|
||||
end_time = example.time.values[end_index]
|
||||
|
||||
@ -78,8 +114,15 @@ def mix_examples(
|
||||
if weight is None:
|
||||
weight = np.random.uniform(min_weight, max_weight)
|
||||
|
||||
audio2 = other["audio"].values
|
||||
audio1 = ops.adjust_dim_width(example["audio"], "audio_time", len(audio2))
|
||||
audio1 = example["audio"]
|
||||
|
||||
audio2 = ops.adjust_dim_width(
|
||||
other["audio"], "audio_time", len(audio1)
|
||||
).values
|
||||
|
||||
if len(audio2) > len(audio1):
|
||||
audio2 = audio2[: len(audio1)]
|
||||
|
||||
combined = weight * audio1 + (1 - weight) * audio2
|
||||
|
||||
spec = compute_spectrogram(
|
||||
@ -104,11 +147,16 @@ def mix_examples(
|
||||
return xr.Dataset(
|
||||
{
|
||||
"audio": combined,
|
||||
"spectrogram": spec,
|
||||
"spectrogram": xr.DataArray(
|
||||
data=spec.data,
|
||||
dims=example["spectrogram"].dims,
|
||||
coords=example["spectrogram"].coords,
|
||||
),
|
||||
"detection": detection_heatmap,
|
||||
"class": class_heatmap,
|
||||
"size": size_heatmap,
|
||||
}
|
||||
},
|
||||
attrs=example.attrs,
|
||||
)
|
||||
|
||||
|
||||
@ -146,7 +194,14 @@ def add_echo(
|
||||
config=config.spectrogram,
|
||||
)
|
||||
|
||||
return example.assign(audio=audio, spectrogram=spectrogram)
|
||||
return example.assign(
|
||||
audio=audio,
|
||||
spectrogram=xr.DataArray(
|
||||
data=spectrogram.data,
|
||||
dims=example["spectrogram"].dims,
|
||||
coords=example["spectrogram"].coords,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class VolumeAugmentationConfig(AugmentationConfig):
|
||||
@ -340,15 +395,17 @@ def augment_example(
|
||||
example = select_random_subclip(
|
||||
example,
|
||||
duration=config.subclip.duration,
|
||||
width=config.subclip.width,
|
||||
)
|
||||
|
||||
if should_apply(config.mix) and others is not None:
|
||||
if should_apply(config.mix) and (others is not None):
|
||||
other = others()
|
||||
|
||||
if config.subclip.enable:
|
||||
other = select_random_subclip(
|
||||
other,
|
||||
duration=config.subclip.duration,
|
||||
width=config.subclip.width,
|
||||
)
|
||||
|
||||
example = mix_examples(
|
||||
|
@ -74,8 +74,20 @@ class LabeledDataset(Dataset):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_directory(cls, directory: PathLike, extension: str = ".nc"):
|
||||
return cls(get_files(directory, extension))
|
||||
def from_directory(
|
||||
cls,
|
||||
directory: PathLike,
|
||||
extension: str = ".nc",
|
||||
augment: bool = False,
|
||||
preprocessing_config: Optional[PreprocessingConfig] = None,
|
||||
augmentation_config: Optional[AugmentationsConfig] = None,
|
||||
):
|
||||
return cls(
|
||||
get_files(directory, extension),
|
||||
augment=augment,
|
||||
preprocessing_config=preprocessing_config,
|
||||
augmentation_config=augmentation_config,
|
||||
)
|
||||
|
||||
def get_random_example(self) -> xr.Dataset:
|
||||
idx = np.random.randint(0, len(self))
|
||||
|
@ -1,3 +1,4 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -22,7 +23,7 @@ def generate_heatmaps(
|
||||
sound_events: Sequence[data.SoundEventAnnotation],
|
||||
spec: xr.DataArray,
|
||||
class_names: List[str],
|
||||
encoder: Callable[[data.SoundEventAnnotation], Optional[str]],
|
||||
encoder: Callable[[Iterable[data.Tag]], Optional[str]],
|
||||
target_sigma: float = 3.0,
|
||||
position: Positions = "bottom-left",
|
||||
time_scale: float = 1000.0,
|
||||
@ -64,23 +65,29 @@ def generate_heatmaps(
|
||||
time, frequency = geometry.get_geometry_point(geom, position=position)
|
||||
|
||||
# Set 1.0 at the position of the sound event in the detection heatmap
|
||||
detection_heatmap = arrays.set_value_at_pos(
|
||||
detection_heatmap,
|
||||
1.0,
|
||||
time=time,
|
||||
frequency=frequency,
|
||||
)
|
||||
try:
|
||||
detection_heatmap = arrays.set_value_at_pos(
|
||||
detection_heatmap,
|
||||
1.0,
|
||||
time=time,
|
||||
frequency=frequency,
|
||||
)
|
||||
except KeyError:
|
||||
# Skip the sound event if the position is outside the spectrogram
|
||||
continue
|
||||
|
||||
# Set the size of the sound event at the position in the size heatmap
|
||||
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
|
||||
geom
|
||||
)
|
||||
|
||||
size = np.array(
|
||||
[
|
||||
(end_time - start_time) * time_scale,
|
||||
(high_freq - low_freq) * frequency_scale,
|
||||
]
|
||||
)
|
||||
|
||||
size_heatmap = arrays.set_value_at_pos(
|
||||
size_heatmap,
|
||||
size,
|
||||
@ -89,14 +96,12 @@ def generate_heatmaps(
|
||||
)
|
||||
|
||||
# Get the class name of the sound event
|
||||
class_name = encoder(sound_event_annotation)
|
||||
class_name = encoder(sound_event_annotation.tags)
|
||||
|
||||
if class_name is None:
|
||||
# If the label is None skip the sound event
|
||||
continue
|
||||
|
||||
# Set 1.0 at the position and category of the sound event in the class
|
||||
# heatmap
|
||||
class_heatmap = arrays.set_value_at_pos(
|
||||
class_heatmap,
|
||||
1.0,
|
||||
|
@ -6,6 +6,7 @@ from pydantic import Field
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.models.typing import ModelOutput
|
||||
from batdetect2.plot import detection
|
||||
from batdetect2.train.dataset import TrainExample
|
||||
|
||||
|
||||
|
@ -13,6 +13,7 @@ from batdetect2.models import (
|
||||
get_backbone,
|
||||
)
|
||||
from batdetect2.models.typing import ModelOutput
|
||||
from batdetect2.post_process import PostprocessConfig
|
||||
from batdetect2.preprocess import PreprocessingConfig
|
||||
from batdetect2.train.dataset import TrainExample
|
||||
from batdetect2.train.losses import LossConfig, compute_loss
|
||||
@ -37,6 +38,9 @@ class ModuleConfig(BaseConfig):
|
||||
preprocessing: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
postprocessing: PostprocessConfig = Field(
|
||||
default_factory=PostprocessConfig
|
||||
)
|
||||
|
||||
|
||||
class DetectorModel(L.LightningModule):
|
||||
@ -50,8 +54,9 @@ class DetectorModel(L.LightningModule):
|
||||
self.config = config or ModuleConfig()
|
||||
self.save_hyperparameters()
|
||||
|
||||
size = self.config.preprocessing.spectrogram.size
|
||||
self.backbone = get_backbone(
|
||||
input_height=self.config.preprocessing.spectrogram.size.height,
|
||||
input_height=int(size.height * (size.resize_factor or 1)),
|
||||
config=self.config.backbone,
|
||||
)
|
||||
|
||||
@ -62,11 +67,13 @@ class DetectorModel(L.LightningModule):
|
||||
|
||||
self.bbox = BBoxHead(in_channels=self.backbone.out_channels)
|
||||
|
||||
conf = self.training_config.loss.classification
|
||||
conf = self.config.train.loss.classification
|
||||
self.class_weights = (
|
||||
torch.tensor(conf.class_weights) if conf.class_weights else None
|
||||
)
|
||||
|
||||
self.validation_predictions = []
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore
|
||||
features = self.backbone(spec)
|
||||
detection_probs, classification_probs = self.classifier(features)
|
||||
@ -86,8 +93,33 @@ class DetectorModel(L.LightningModule):
|
||||
conf=self.config.train.loss,
|
||||
class_weights=self.class_weights,
|
||||
)
|
||||
|
||||
self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
|
||||
self.log("train/loss/detection", losses.total, logger=True)
|
||||
self.log("train/loss/size", losses.total, logger=True)
|
||||
self.log("train/loss/classification", losses.total, logger=True)
|
||||
|
||||
return losses.total
|
||||
|
||||
def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
    """Run the model on one validation batch and log its losses.

    Parameters
    ----------
    batch
        Validation batch; ``batch.spec`` is fed to ``forward``.
    batch_idx
        Index of the batch (unused).
    """
    outputs = self.forward(batch.spec)

    losses = compute_loss(
        batch,
        outputs,
        conf=self.config.train.loss,
        class_weights=self.class_weights,
    )

    self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
    # NOTE(review): the three component metrics below all log
    # ``losses.total``; presumably each should log its own component of
    # ``losses`` -- confirm the field names before changing.
    self.log("val/loss/detection", losses.total, logger=True)
    self.log("val/loss/size", losses.total, logger=True)
    self.log("val/loss/classification", losses.total, logger=True)
|
||||
|
||||
|
||||
def configure_optimizers(self):
|
||||
conf = self.config.train.optimizer
|
||||
optimizer = optim.Adam(self.parameters(), lr=conf.learning_rate)
|
||||
|
@ -62,12 +62,16 @@ def generate_train_example(
|
||||
include=config.target.include,
|
||||
exclude=config.target.exclude,
|
||||
)
|
||||
|
||||
selected_events = [
|
||||
event for event in clip_annotation.sound_events if filter_fn(event)
|
||||
]
|
||||
|
||||
encoder = build_encoder(
|
||||
config.target.classes,
|
||||
replacement_rules=config.target.replace,
|
||||
)
|
||||
class_names = get_class_names(config.target.classes)
|
||||
encoder = build_encoder(config.target.classes)
|
||||
|
||||
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
||||
selected_events,
|
||||
@ -172,5 +176,11 @@ def preprocess_single_annotation(
|
||||
if path.is_file() and replace:
|
||||
path.unlink()
|
||||
|
||||
sample = generate_train_example(clip_annotation, config=config)
|
||||
try:
|
||||
sample = generate_train_example(clip_annotation, config=config)
|
||||
except Exception as error:
|
||||
raise RuntimeError(
|
||||
f"Failed to process annotation: {clip_annotation.uuid}"
|
||||
) from error
|
||||
|
||||
save_to_file(sample, path)
|
||||
|
@ -1,13 +1,22 @@
|
||||
from collections.abc import Iterable
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Set
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.terms import TagInfo, get_tag_from_info
|
||||
|
||||
|
||||
class ReplaceConfig(BaseConfig):
|
||||
"""Configuration for replacing tags."""
|
||||
|
||||
original: TagInfo
|
||||
replacement: TagInfo
|
||||
|
||||
|
||||
class TargetConfig(BaseConfig):
|
||||
"""Configuration for target generation."""
|
||||
|
||||
@ -23,6 +32,7 @@ class TargetConfig(BaseConfig):
|
||||
include: Optional[List[TagInfo]] = Field(
|
||||
default_factory=lambda: [TagInfo(key="event", value="Echolocation")]
|
||||
)
|
||||
|
||||
exclude: Optional[List[TagInfo]] = Field(
|
||||
default_factory=lambda: [
|
||||
TagInfo(key="class", value=""),
|
||||
@ -31,6 +41,8 @@ class TargetConfig(BaseConfig):
|
||||
]
|
||||
)
|
||||
|
||||
replace: Optional[List[ReplaceConfig]] = None
|
||||
|
||||
|
||||
def build_sound_event_filter(
|
||||
include: Optional[List[TagInfo]] = None,
|
||||
@ -57,9 +69,24 @@ def get_class_names(classes: List[TagInfo]) -> List[str]:
|
||||
return sorted({get_tag_label(tag) for tag in classes})
|
||||
|
||||
|
||||
def build_replacer(
    rules: List[ReplaceConfig],
) -> Callable[[data.Tag], data.Tag]:
    """Create a function that substitutes tags according to ``rules``.

    Every rule is converted into an (original tag, replacement tag) pair via
    ``get_tag_from_info``. The returned callable maps a tag to its
    replacement when a rule matches and returns the tag unchanged otherwise.
    When several rules share the same original tag, the last one wins.
    """
    substitutions = {}
    for rule in rules:
        source_tag = get_tag_from_info(rule.original)
        substitutions[source_tag] = get_tag_from_info(rule.replacement)

    def replacer(tag: data.Tag) -> data.Tag:
        # Tags without a matching rule pass through untouched.
        if tag in substitutions:
            return substitutions[tag]
        return tag

    return replacer
|
||||
|
||||
|
||||
def build_encoder(
|
||||
classes: List[TagInfo],
|
||||
) -> Callable[[data.SoundEventAnnotation], Optional[str]]:
|
||||
replacement_rules: Optional[List[ReplaceConfig]] = None,
|
||||
) -> Callable[[Iterable[data.Tag]], Optional[str]]:
|
||||
target_tags = set([get_tag_from_info(tag) for tag in classes])
|
||||
|
||||
tag_mapping = {
|
||||
@ -67,12 +94,16 @@ def build_encoder(
|
||||
for tag, tag_info in zip(target_tags, classes)
|
||||
}
|
||||
|
||||
def encoder(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
) -> Optional[str]:
|
||||
tags = set(sound_event_annotation.tags)
|
||||
replacer = (
|
||||
build_replacer(replacement_rules) if replacement_rules else lambda x: x
|
||||
)
|
||||
|
||||
intersection = tags & target_tags
|
||||
def encoder(
|
||||
tags: Iterable[data.Tag],
|
||||
) -> Optional[str]:
|
||||
sanitized_tags = {replacer(tag) for tag in tags}
|
||||
|
||||
intersection = sanitized_tags & target_tags
|
||||
|
||||
if not intersection:
|
||||
return None
|
||||
@ -115,6 +146,12 @@ def filter_sound_event(
|
||||
return True
|
||||
|
||||
|
||||
def load_target_config(
    path: Path, field: Optional[str] = None
) -> TargetConfig:
    """Load a ``TargetConfig`` from the configuration file at ``path``.

    ``field`` is passed through to ``load_config`` unchanged.
    """
    target_config = load_config(path, schema=TargetConfig, field=field)
    return target_config
|
||||
|
||||
|
||||
DEFAULT_SPECIES_LIST = [
|
||||
"Barbastellus barbastellus",
|
||||
"Eptesicus serotinus",
|
||||
|
@ -1,10 +1,12 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.train.augmentations import (
|
||||
add_echo,
|
||||
adjust_dataset_width,
|
||||
mix_examples,
|
||||
select_random_subclip,
|
||||
)
|
||||
@ -68,3 +70,67 @@ def test_selected_random_subclip_has_the_correct_width(
|
||||
subclip = select_random_subclip(original, width=100)
|
||||
|
||||
assert subclip["spectrogram"].shape[1] == 100
|
||||
|
||||
|
||||
def test_adjust_dataset_width():
    """Widening a truncated example restores the full-width coordinates."""
    height = 128
    width = 512
    samplerate = 48_000

    # Coordinates of the full-width example.
    times = np.linspace(0, 1, width)
    audio_times = np.linspace(0, 1, samplerate)
    frequency = np.linspace(0, 24_000, height)

    # The dataset under test only covers the first `width_subset` frames,
    # with the audio truncated by the same proportion.
    width_subset = 356
    audio_width_subset = int(samplerate * width_subset / width)

    times_subset = times[:width_subset]
    audio_times_subset = audio_times[:audio_width_subset]
    dimensions = ["width", "height"]
    class_names = [f"species_{i}" for i in range(17)]

    spectrogram = np.random.random([height, width_subset])
    sizes = np.random.random([len(dimensions), height, width_subset])
    classes = np.random.random([len(class_names), height, width_subset])
    audio = np.random.random([audio_width_subset])

    dataset = xr.Dataset(
        data_vars={
            "audio": (("audio_time",), audio),
            "spectrogram": (("frequency", "time"), spectrogram),
            "sizes": (("dimension", "frequency", "time"), sizes),
            "classes": (("class", "frequency", "time"), classes),
        },
        coords={
            "audio_time": audio_times_subset,
            "time": times_subset,
            "frequency": frequency,
            "dimension": dimensions,
            "class": class_names,
        },
    )

    adjusted = adjust_dataset_width(dataset, width=width)

    # Spectrogram was adjusted correctly
    assert np.isclose(adjusted["spectrogram"].time, times).all()
    assert (adjusted["spectrogram"].frequency == frequency).all()

    # Sizes were adjusted correctly
    assert np.isclose(adjusted["sizes"].time, times).all()
    assert (adjusted["sizes"].frequency == frequency).all()
    assert list(adjusted["sizes"].dimension.values) == dimensions

    # Classes were adjusted correctly
    assert np.isclose(adjusted["classes"].time, times).all()
    assert (adjusted["classes"].frequency == frequency).all()
    assert list(adjusted["classes"]["class"].values) == class_names

    # Audio time was adjusted correctly (length may be off by a couple of
    # samples because the audio width is derived by integer truncation)
    assert np.isclose(
        len(adjusted["audio"].audio_time), len(audio_times), atol=2
    )
    assert np.isclose(
        adjusted["audio"].audio_time[-1], audio_times[-1], atol=1e-3
    )
|
||||
|
37
uv.lock
generated
37
uv.lock
generated
@ -548,7 +548,7 @@ name = "click"
|
||||
version = "8.1.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "platform_system == 'Windows'" },
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 }
|
||||
wheels = [
|
||||
@ -630,6 +630,9 @@ name = "configobj"
|
||||
version = "5.0.9"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f5/c4/c7f9e41bc2e5f8eeae4a08a01c91b2aea3dfab40a3e14b25e87e7db8d501/configobj-5.0.9.tar.gz", hash = "sha256:03c881bbf23aa07bccf1b837005975993c4ab4427ba57f959afdd9d1a2386848", size = 101518 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a6/c4/0679472c60052c27efa612b4cd3ddd2a23e885dcdc73461781d2c802d39e/configobj-5.0.9-py2.py3-none-any.whl", hash = "sha256:1ba10c5b6ee16229c79a05047aeda2b55eb4e80d7c7d8ecf17ec1ca600c79882", size = 35615 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "contourpy"
|
||||
@ -1433,7 +1436,7 @@ name = "ipykernel"
|
||||
version = "6.29.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "appnope", marker = "platform_system == 'Darwin'" },
|
||||
{ name = "appnope", marker = "sys_platform == 'darwin'" },
|
||||
{ name = "comm" },
|
||||
{ name = "debugpy" },
|
||||
{ name = "ipython" },
|
||||
@ -3175,6 +3178,7 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7f/b7/20c6f3c0b656fe609675d69bc135c03aac9e3865912444be6339207b6648/ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f66efbc1caa63c088dead1c4170d148eabc9b80d95fb75b6c92ac0aad2437d76", size = 686712 },
|
||||
{ url = "https://files.pythonhosted.org/packages/cd/11/d12dbf683471f888d354dac59593873c2b45feb193c5e3e0f2ebf85e68b9/ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:22353049ba4181685023b25b5b51a574bce33e7f51c759371a7422dcae5402a6", size = 663936 },
|
||||
{ url = "https://files.pythonhosted.org/packages/72/14/4c268f5077db5c83f743ee1daeb236269fa8577133a5cfa49f8b382baf13/ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:932205970b9f9991b34f55136be327501903f7c66830e9760a8ffb15b07f05cd", size = 696580 },
|
||||
{ url = "https://files.pythonhosted.org/packages/30/fc/8cd12f189c6405a4c1cf37bd633aa740a9538c8e40497c231072d0fef5cf/ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a52d48f4e7bf9005e8f0a89209bf9a73f7190ddf0489eee5eb51377385f59f2a", size = 663393 },
|
||||
{ url = "https://files.pythonhosted.org/packages/80/29/c0a017b704aaf3cbf704989785cd9c5d5b8ccec2dae6ac0c53833c84e677/ruamel.yaml.clib-0.2.12-cp310-cp310-win32.whl", hash = "sha256:3eac5a91891ceb88138c113f9db04f3cebdae277f5d44eaa3651a4f573e6a5da", size = 100326 },
|
||||
{ url = "https://files.pythonhosted.org/packages/3a/65/fa39d74db4e2d0cd252355732d966a460a41cd01c6353b820a0952432839/ruamel.yaml.clib-0.2.12-cp310-cp310-win_amd64.whl", hash = "sha256:ab007f2f5a87bd08ab1499bdf96f3d5c6ad4dcfa364884cb4549aa0154b13a28", size = 118079 },
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/8f/683c6ad562f558cbc4f7c029abcd9599148c51c54b5ef0f24f2638da9fbb/ruamel.yaml.clib-0.2.12-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:4a6679521a58256a90b0d89e03992c15144c5f3858f40d7c18886023d7943db6", size = 132224 },
|
||||
@ -3183,6 +3187,7 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/86/29/88c2567bc893c84d88b4c48027367c3562ae69121d568e8a3f3a8d363f4d/ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:811ea1594b8a0fb466172c384267a4e5e367298af6b228931f273b111f17ef52", size = 703012 },
|
||||
{ url = "https://files.pythonhosted.org/packages/11/46/879763c619b5470820f0cd6ca97d134771e502776bc2b844d2adb6e37753/ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cf12567a7b565cbf65d438dec6cfbe2917d3c1bdddfce84a9930b7d35ea59642", size = 704352 },
|
||||
{ url = "https://files.pythonhosted.org/packages/02/80/ece7e6034256a4186bbe50dee28cd032d816974941a6abf6a9d65e4228a7/ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7dd5adc8b930b12c8fc5b99e2d535a09889941aa0d0bd06f4749e9a9397c71d2", size = 737344 },
|
||||
{ url = "https://files.pythonhosted.org/packages/f0/ca/e4106ac7e80efbabdf4bf91d3d32fc424e41418458251712f5672eada9ce/ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1492a6051dab8d912fc2adeef0e8c72216b24d57bd896ea607cb90bb0c4981d3", size = 714498 },
|
||||
{ url = "https://files.pythonhosted.org/packages/67/58/b1f60a1d591b771298ffa0428237afb092c7f29ae23bad93420b1eb10703/ruamel.yaml.clib-0.2.12-cp311-cp311-win32.whl", hash = "sha256:bd0a08f0bab19093c54e18a14a10b4322e1eacc5217056f3c063bd2f59853ce4", size = 100205 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/4f/b52f634c9548a9291a70dfce26ca7ebce388235c93588a1068028ea23fcc/ruamel.yaml.clib-0.2.12-cp311-cp311-win_amd64.whl", hash = "sha256:a274fb2cb086c7a3dea4322ec27f4cb5cc4b6298adb583ab0e211a4682f241eb", size = 118185 },
|
||||
{ url = "https://files.pythonhosted.org/packages/48/41/e7a405afbdc26af961678474a55373e1b323605a4f5e2ddd4a80ea80f628/ruamel.yaml.clib-0.2.12-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:20b0f8dc160ba83b6dcc0e256846e1a02d044e13f7ea74a3d1d56ede4e48c632", size = 133433 },
|
||||
@ -3191,6 +3196,7 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/52/a9/d39f3c5ada0a3bb2870d7db41901125dbe2434fa4f12ca8c5b83a42d7c53/ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:749c16fcc4a2b09f28843cda5a193e0283e47454b63ec4b81eaa2242f50e4ccd", size = 706497 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b0/fa/097e38135dadd9ac25aecf2a54be17ddf6e4c23e43d538492a90ab3d71c6/ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bf165fef1f223beae7333275156ab2022cffe255dcc51c27f066b4370da81e31", size = 698042 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ec/d5/a659ca6f503b9379b930f13bc6b130c9f176469b73b9834296822a83a132/ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:32621c177bbf782ca5a18ba4d7af0f1082a3f6e517ac2a18b3974d4edf349680", size = 745831 },
|
||||
{ url = "https://files.pythonhosted.org/packages/db/5d/36619b61ffa2429eeaefaab4f3374666adf36ad8ac6330d855848d7d36fd/ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b82a7c94a498853aa0b272fd5bc67f29008da798d4f93a2f9f289feb8426a58d", size = 715692 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b1/82/85cb92f15a4231c89b95dfe08b09eb6adca929ef7df7e17ab59902b6f589/ruamel.yaml.clib-0.2.12-cp312-cp312-win32.whl", hash = "sha256:e8c4ebfcfd57177b572e2040777b8abc537cdef58a2120e830124946aa9b42c5", size = 98777 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d7/8f/c3654f6f1ddb75daf3922c3d8fc6005b1ab56671ad56ffb874d908bfa668/ruamel.yaml.clib-0.2.12-cp312-cp312-win_amd64.whl", hash = "sha256:0467c5965282c62203273b838ae77c0d29d7638c8a4e3a1c8bdd3602c10904e4", size = 115523 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e5/46/ccdef7a84ad745c37cb3d9a81790f28fbc9adf9c237dba682017b123294e/ruamel.yaml.clib-0.2.12-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:fc4b630cd3fa2cf7fce38afa91d7cfe844a9f75d7f0f36393fa98815e911d987", size = 131834 },
|
||||
@ -3199,6 +3205,7 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/da/1c/23497017c554fc06ff5701b29355522cff850f626337fff35d9ab352cb18/ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2f1c3765db32be59d18ab3953f43ab62a761327aafc1594a2a1fbe038b8b8a7", size = 689072 },
|
||||
{ url = "https://files.pythonhosted.org/packages/68/e6/f3d4ff3223f9ea49c3b7169ec0268e42bd49f87c70c0e3e853895e4a7ae2/ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d85252669dc32f98ebcd5d36768f5d4faeaeaa2d655ac0473be490ecdae3c285", size = 667091 },
|
||||
{ url = "https://files.pythonhosted.org/packages/84/62/ead07043527642491e5011b143f44b81ef80f1025a96069b7210e0f2f0f3/ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e143ada795c341b56de9418c58d028989093ee611aa27ffb9b7f609c00d813ed", size = 699111 },
|
||||
{ url = "https://files.pythonhosted.org/packages/52/b3/fe4d84446f7e4887e3bea7ceff0a7df23790b5ed625f830e79ace88ebefb/ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2c59aa6170b990d8d2719323e628aaf36f3bfbc1c26279c0eeeb24d05d2d11c7", size = 666365 },
|
||||
{ url = "https://files.pythonhosted.org/packages/6e/b3/7feb99a00bfaa5c6868617bb7651308afde85e5a0b23cd187fe5de65feeb/ruamel.yaml.clib-0.2.12-cp39-cp39-win32.whl", hash = "sha256:beffaed67936fbbeffd10966a4eb53c402fafd3d6833770516bf7314bc6ffa12", size = 100863 },
|
||||
{ url = "https://files.pythonhosted.org/packages/93/07/de635108684b7a5bb06e432b0930c5a04b6c59efe73bd966d8db3cc208f2/ruamel.yaml.clib-0.2.12-cp39-cp39-win_amd64.whl", hash = "sha256:040ae85536960525ea62868b642bdb0c2cc6021c9f9d507810c0c604e66f5a7b", size = 118653 },
|
||||
]
|
||||
@ -3632,19 +3639,19 @@ dependencies = [
|
||||
{ name = "fsspec" },
|
||||
{ name = "jinja2" },
|
||||
{ name = "networkx" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "sympy" },
|
||||
{ name = "triton", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
wheels = [
|
||||
@ -3758,7 +3765,7 @@ name = "tqdm"
|
||||
version = "4.67.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "platform_system == 'Windows'" },
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e8/4f/0153c21dc5779a49a0598c445b1978126b1344bab9ee71e53e44877e14e0/tqdm-4.67.0.tar.gz", hash = "sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a", size = 169739 }
|
||||
wheels = [
|
||||
|
Loading…
Reference in New Issue
Block a user