Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-30 15:12:06 +02:00)

Compare commits

No commits in common. "98c6da6d42434ab5856b9bd27cb564496027b711" and "f63307757cb231866d9d75e386f428acf6b92a57" have entirely different histories.

98c6da6d42 ... f63307757c
@@ -1,8 +0,0 @@
[bumpversion]
current_version = 1.1.1
commit = True
tag = True

[bumpversion:file:batdetect2/__init__.py]

[bumpversion:file:pyproject.toml]
.github/workflows/python-package.yml (vendored, 31 changed lines)

@@ -1,29 +1,34 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Python package

on:
  push:
    branches: ["main"]
    branches: [ "main" ]
  pull_request:
    branches: ["main"]
    branches: [ "main" ]

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.9", "3.10", "3.11", "3.12"]
        python-version: ["3.8", "3.9", "3.10", "3.11"]

    steps:
      - uses: actions/checkout@v4
      - name: Install uv
        uses: astral-sh/setup-uv@v3
        with:
          enable-cache: true
          cache-dependency-glob: "uv.lock"
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        run: uv python install ${{ matrix.python-version }}
      - name: Install the project
        run: uv sync --all-extras --dev
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install pytest
          pip install .
      - name: Test with pytest
        run: uv run pytest
        run: |
          pytest
.github/workflows/python-publish.yml (vendored, 13 changed lines)

@@ -1,3 +1,11 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
@@ -9,14 +17,15 @@ permissions:

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: "3.x"
          python-version: '3.x'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
.gitignore (vendored, 4 changed lines)

@@ -103,11 +103,13 @@ experiments/*
.ipynb_checkpoints
*.ipynb

# Bump2version
.bumpversion.cfg

# DO Include
!batdetect2_notebook.ipynb
!batdetect2/models/checkpoints/*.pth.tar
!tests/data/*.wav
!notebooks/*.ipynb
!tests/data/**/*.wav
notebooks/lightning_logs
example_data/preprocessed
@@ -1,6 +1 @@
import logging

numba_logger = logging.getLogger("numba")
numba_logger.setLevel(logging.WARNING)

__version__ = "1.1.1"
__version__ = '1.0.8'
@@ -1,13 +1,9 @@
from batdetect2.cli.base import cli
from batdetect2.cli.compat import detect
from batdetect2.cli.data import data
from batdetect2.cli.train import train

__all__ = [
    "cli",
    "detect",
    "data",
    "train",
]
@@ -1,22 +0,0 @@
BATDETECT_ASCII_ART = """ .
=#%: .%%#
:%%%: .%%%%.
%%%%.-===::%%%%*
=%%%%+++++++%%%#.
-: .%%%#====+++#%%# .-
.+***= . =++. : .=*+#%*= :***.
=+****+++==:%+#=+% *##%%%%*=##*#**-=
++***+**+=: ##.. +##%%########**++
.++*****#*+- :*:++ ##%#%%%%%####**++
.++***+**++++- :#%%%%%####*##***+=
.+++***+#+++*########%%%##%#+*****++:
.=++++++*+++##%##%%####%%##*:+****+=
=++++++====*#%%#%###%%###- +***+++.
.+*++++= =+==##########= :****++.
=++*+:. .:=#####= .++**++-
.****: . -+**++=
*###= .****==
.#*#- **#*:
-### -*##.
+*= *#*
"""
@@ -1,14 +1,18 @@
"""BatDetect2 command line interface."""

import os

import click

# from batdetect2.cli.ascii import BATDETECT_ASCII_ART

__all__ = [
    "cli",
]


CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))


INFO_STR = """
BatDetect2 - Detection and Classification
Assumes audio files are mono, not stereo.
@@ -21,4 +25,3 @@ BatDetect2 - Detection and Classification
def cli():
    """BatDetect2 - Bat Call Detection and Classification."""
    click.echo(INFO_STR)
    # click.echo(BATDETECT_ASCII_ART)
@@ -1,11 +1,14 @@
"""BatDetect2 command line interface."""

import click

from batdetect2 import api
from batdetect2.cli.base import cli
from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from batdetect2.types import ProcessingConfiguration
from batdetect2.utils.detector_utils import save_results_to_file

from batdetect2.cli.base import cli


@cli.command()
@click.argument(
@@ -111,9 +114,10 @@ def detect(
):
            results_path = audio_file.replace(audio_dir, ann_dir)
            save_results_to_file(results, results_path)
        except (RuntimeError, ValueError, LookupError, EOFError) as err:
        except (RuntimeError, ValueError, LookupError) as err:
            error_files.append(audio_file)
            click.secho(f"Error processing file {audio_file}: {err}", fg="red")
            click.secho(f"Error processing file!: {err}", fg="red")
            raise err

    click.echo(f"\nResults saved to: {ann_dir}")
@@ -1,40 +0,0 @@
from pathlib import Path
from typing import Optional

import click

from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config

__all__ = ["data"]


@cli.group()
def data(): ...


@data.command()
@click.argument(
    "dataset_config",
    type=click.Path(exists=True),
)
@click.option(
    "--field",
    type=str,
    help="If the dataset info is in a nested field please specify here.",
)
@click.option(
    "--base-dir",
    type=click.Path(exists=True),
    help="The base directory to which all recording and annotations paths are relative to.",
)
def summary(
    dataset_config: Path,
    field: Optional[str] = None,
    base_dir: Optional[Path] = None,
):
    base_dir = base_dir or Path.cwd()
    dataset = load_dataset_from_config(
        dataset_config, field=field, base_dir=base_dir
    )
    print(f"Number of annotated clips: {len(dataset.clip_annotations)}")
@@ -1,188 +0,0 @@
from pathlib import Path
from typing import Optional

import click

from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.preprocess import (
    load_preprocessing_config,
)
from batdetect2.train import (
    load_label_config,
    load_target_config,
    preprocess_annotations,
)

__all__ = ["train"]


@cli.group()
def train(): ...


@train.command()
@click.argument(
    "dataset_config",
    type=click.Path(exists=True),
)
@click.argument(
    "output",
    type=click.Path(),
)
@click.option(
    "--dataset-field",
    type=str,
    help=(
        "Specifies the key to access the dataset information within the "
        "dataset configuration file, if the information is nested inside a "
        "dictionary. If the dataset information is at the top level of the "
        "config file, you don't need to specify this."
    ),
)
@click.option(
    "--base-dir",
    type=click.Path(exists=True),
    help=(
        "The main directory where your audio recordings and annotation "
        "files are stored. This helps the program find your data, "
        "especially if the paths in your dataset configuration file "
        "are relative."
    ),
)
@click.option(
    "--preprocess-config",
    type=click.Path(exists=True),
    help=(
        "Path to the preprocessing configuration file. This file tells "
        "the program how to prepare your audio data before training, such "
        "as resampling or applying filters."
    ),
)
@click.option(
    "--preprocess-config-field",
    type=str,
    help=(
        "If the preprocessing settings are inside a nested dictionary "
        "within the preprocessing configuration file, specify the key "
        "here to access them. If the preprocessing settings are at the "
        "top level, you don't need to specify this."
    ),
)
@click.option(
    "--label-config",
    type=click.Path(exists=True),
    help=(
        "Path to the label generation configuration file. This file "
        "contains settings for how to create labels from your "
        "annotations, which the model uses to learn."
    ),
)
@click.option(
    "--label-config-field",
    type=str,
    help=(
        "If the label generation settings are inside a nested dictionary "
        "within the label configuration file, specify the key here. If "
        "the settings are at the top level, leave this blank."
    ),
)
@click.option(
    "--target-config",
    type=click.Path(exists=True),
    help=(
        "Path to the training target configuration file. This file "
        "specifies what sounds the model should learn to predict."
    ),
)
@click.option(
    "--target-config-field",
    type=str,
    help=(
        "If the target settings are inside a nested dictionary "
        "within the target configuration file, specify the key here. "
        "If the settings are at the top level, you don't need to specify this."
    ),
)
@click.option(
    "--force",
    is_flag=True,
    help=(
        "If a preprocessed file already exists, this option tells the "
        "program to overwrite it with the new preprocessed data. Use "
        "this if you want to re-do the preprocessing even if the files "
        "already exist."
    ),
)
@click.option(
    "--num-workers",
    type=int,
    help=(
        "The maximum number of computer cores to use when processing "
        "your audio data. Using more cores can speed up the preprocessing, "
        "but don't use more than your computer has available. By default, "
        "the program will use all available cores."
    ),
)
def preprocess(
    dataset_config: Path,
    output: Path,
    base_dir: Optional[Path] = None,
    preprocess_config: Optional[Path] = None,
    target_config: Optional[Path] = None,
    label_config: Optional[Path] = None,
    force: bool = False,
    num_workers: Optional[int] = None,
    target_config_field: Optional[str] = None,
    preprocess_config_field: Optional[str] = None,
    label_config_field: Optional[str] = None,
    dataset_field: Optional[str] = None,
):
    output = Path(output)
    base_dir = base_dir or Path.cwd()

    preprocess = (
        load_preprocessing_config(
            preprocess_config,
            field=preprocess_config_field,
        )
        if preprocess_config
        else None
    )

    target = (
        load_target_config(
            target_config,
            field=target_config_field,
        )
        if target_config
        else None
    )

    label = (
        load_label_config(
            label_config,
            field=label_config_field,
        )
        if label_config
        else None
    )

    dataset = load_dataset_from_config(
        dataset_config,
        field=dataset_field,
        base_dir=base_dir,
    )

    if not output.exists():
        output.mkdir(parents=True)

    preprocess_annotations(
        dataset.clip_annotations,
        output_dir=output,
        replace=force,
        preprocessing_config=preprocess,
        label_config=label,
        target_config=target,
        max_workers=num_workers,
    )
@@ -1,151 +0,0 @@
from batdetect2.preprocess import (
    AmplitudeScaleConfig,
    AudioConfig,
    FrequencyConfig,
    LogScaleConfig,
    PcenScaleConfig,
    PreprocessingConfig,
    ResampleConfig,
    Scales,
    SpecSizeConfig,
    SpectrogramConfig,
    STFTConfig,
)
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
from batdetect2.terms import TagInfo
from batdetect2.train.preprocess import (
    HeatmapsConfig,
    TargetConfig,
    TrainPreprocessingConfig,
)


def get_spectrogram_scale(scale: str) -> Scales:
    if scale == "pcen":
        return PcenScaleConfig()
    if scale == "log":
        return LogScaleConfig()
    return AmplitudeScaleConfig()


def get_preprocessing_config(params: dict) -> PreprocessingConfig:
    return PreprocessingConfig(
        audio=AudioConfig(
            resample=ResampleConfig(
                samplerate=params["target_samp_rate"],
                mode="poly",
            ),
            scale=params["scale_raw_audio"],
            center=params["scale_raw_audio"],
            duration=None,
        ),
        spectrogram=SpectrogramConfig(
            stft=STFTConfig(
                window_duration=params["fft_win_length"],
                window_overlap=params["fft_overlap"],
                window_fn="hann",
            ),
            frequencies=FrequencyConfig(
                min_freq=params["min_freq"],
                max_freq=params["max_freq"],
            ),
            scale=get_spectrogram_scale(params["spec_scale"]),
            denoise=params["denoise_spec_avg"],
            size=SpecSizeConfig(
                height=params["spec_height"],
                resize_factor=params["resize_factor"],
            ),
            max_scale=params["max_scale_spec"],
        ),
    )


def get_training_preprocessing_config(
    params: dict,
) -> TrainPreprocessingConfig:
    generic = params["generic_class"][0]
    preprocessing = get_preprocessing_config(params)

    freq_bin_width, time_bin_width = get_spectrogram_resolution(
        preprocessing.spectrogram
    )

    return TrainPreprocessingConfig(
        preprocessing=preprocessing,
        target=TargetConfig(
            classes=[
                TagInfo(key="class", value=class_name, label=class_name)
                for class_name in params["class_names"]
            ],
            generic_class=TagInfo(
                key="class",
                value=generic,
                label=generic,
            ),
            include=[
                TagInfo(key="event", value=event)
                for event in params["events_of_interest"]
            ],
            exclude=[
                TagInfo(key="class", value=value)
                for value in params["classes_to_ignore"]
            ],
        ),
        heatmaps=HeatmapsConfig(
            position="bottom-left",
            time_scale=1 / time_bin_width,
            frequency_scale=1 / freq_bin_width,
            sigma=params["target_sigma"],
        ),
    )


# 'standardize_classs_names_ip',
# 'convert_to_genus',
# 'genus_mapping',
# 'standardize_classs_names',
# 'genus_names',

# ['data_dir',
# 'ann_dir',
# 'train_split',
# 'model_name',
# 'num_filters',
# 'experiment',
# 'model_file_name',
# 'op_im_dir',
# 'op_im_dir_test',
# 'notes',
# 'spec_divide_factor',
# 'detection_overlap',
# 'ignore_start_end',
# 'detection_threshold',
# 'nms_kernel_size',
# 'nms_top_k_per_sec',
# 'aug_prob',
# 'augment_at_train',
# 'augment_at_train_combine',
# 'echo_max_delay',
# 'stretch_squeeze_delta',
# 'mask_max_time_perc',
# 'mask_max_freq_perc',
# 'spec_amp_scaling',
# 'aug_sampling_rates',
# 'train_loss',
# 'det_loss_weight',
# 'size_loss_weight',
# 'class_loss_weight',
# 'individual_loss_weight',
# 'emb_dim',
# 'lr',
# 'batch_size',
# 'num_workers',
# 'num_epochs',
# 'num_eval_epochs',
# 'device',
# 'save_test_image_during_train',
# 'save_test_image_after_train',
# 'train_sets',
# 'test_sets',
# 'class_inv_freq',
# 'ip_height']
@@ -1,35 +0,0 @@
from typing import Any, Optional, Type, TypeVar

import yaml
from pydantic import BaseModel, ConfigDict
from soundevent.data import PathLike


class BaseConfig(BaseModel):
    model_config = ConfigDict(extra="forbid")


T = TypeVar("T", bound=BaseModel)


def get_object_field(obj: dict, field: str) -> Any:
    if "." not in field:
        return obj[field]

    field, rest = field.split(".", 1)
    subobj = obj[field]
    return get_object_field(subobj, rest)


def load_config(
    path: PathLike,
    schema: Type[T],
    field: Optional[str] = None,
) -> T:
    with open(path, "r") as file:
        config = yaml.safe_load(file)

    if field:
        config = get_object_field(config, field)

    return schema.model_validate(config)
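The dotted `field` syntax handled by `get_object_field` lets `load_config` pull a nested section out of a larger YAML file before pydantic validation runs. A minimal usage sketch of the functions above; the schema, file name, and YAML layout are invented for illustration:

from pydantic import BaseModel

# Hypothetical schema; any pydantic model can serve as the `schema` argument.
class TrainingSection(BaseModel):
    learning_rate: float
    batch_size: int

# Suppose settings.yaml (a made-up file) contains:
#
#   experiment:
#     training:
#       learning_rate: 0.001
#       batch_size: 32
#
# "experiment.training" is split on "." and each key is looked up in turn,
# so validation is applied to the innermost dictionary only.
section = load_config(
    "settings.yaml",
    schema=TrainingSection,
    field="experiment.training",
)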
@@ -1,14 +0,0 @@
from batdetect2.data.annotations import (
    AnnotatedDataset,
    load_annotated_dataset,
)
from batdetect2.data.data import load_dataset, load_dataset_from_config
from batdetect2.data.types import Dataset

__all__ = [
    "AnnotatedDataset",
    "Dataset",
    "load_annotated_dataset",
    "load_dataset",
    "load_dataset_from_config",
]
@@ -1,36 +0,0 @@
import json
from pathlib import Path
from typing import Literal, Union

from batdetect2.configs import BaseConfig

__all__ = [
    "AOEFAnnotationFile",
    "AnnotationFormats",
    "BatDetect2AnnotationFile",
    "BatDetect2AnnotationFiles",
]


class BatDetect2AnnotationFiles(BaseConfig):
    format: Literal["batdetect2"] = "batdetect2"
    path: Path


class BatDetect2AnnotationFile(BaseConfig):
    format: Literal["batdetect2_file"] = "batdetect2_file"
    path: Path


class AOEFAnnotationFile(BaseConfig):
    format: Literal["aoef"] = "aoef"
    path: Path


AnnotationFormats = Union[
    BatDetect2AnnotationFiles,
    BatDetect2AnnotationFile,
    AOEFAnnotationFile,
]
@@ -1,55 +0,0 @@
from pathlib import Path
from typing import Optional, Union

from soundevent import data

from batdetect2.data.annotations.aeof import (
    AOEFAnnotations,
    load_aoef_annotated_dataset,
)
from batdetect2.data.annotations.batdetect2_files import (
    BatDetect2FilesAnnotations,
    load_batdetect2_files_annotated_dataset,
)
from batdetect2.data.annotations.batdetect2_merged import (
    BatDetect2MergedAnnotations,
    load_batdetect2_merged_annotated_dataset,
)
from batdetect2.data.annotations.types import AnnotatedDataset

__all__ = [
    "load_annotated_dataset",
    "AnnotatedDataset",
    "AOEFAnnotations",
    "BatDetect2FilesAnnotations",
    "BatDetect2MergedAnnotations",
    "AnnotationFormats",
]


AnnotationFormats = Union[
    BatDetect2MergedAnnotations,
    BatDetect2FilesAnnotations,
    AOEFAnnotations,
]


def load_annotated_dataset(
    dataset: AnnotatedDataset,
    base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
    if isinstance(dataset, AOEFAnnotations):
        return load_aoef_annotated_dataset(dataset, base_dir=base_dir)

    if isinstance(dataset, BatDetect2MergedAnnotations):
        return load_batdetect2_merged_annotated_dataset(
            dataset, base_dir=base_dir
        )

    if isinstance(dataset, BatDetect2FilesAnnotations):
        return load_batdetect2_files_annotated_dataset(
            dataset,
            base_dir=base_dir,
        )

    raise NotImplementedError(f"Unknown annotation format: {dataset.name}")
@@ -1,37 +0,0 @@
from pathlib import Path
from typing import Literal, Optional

from soundevent import data, io

from batdetect2.data.annotations.types import AnnotatedDataset

__all__ = [
    "AOEFAnnotations",
    "load_aoef_annotated_dataset",
]


class AOEFAnnotations(AnnotatedDataset):
    format: Literal["aoef"] = "aoef"
    annotations_path: Path


def load_aoef_annotated_dataset(
    dataset: AOEFAnnotations,
    base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
    audio_dir = dataset.audio_dir
    path = dataset.annotations_path

    if base_dir:
        audio_dir = base_dir / audio_dir
        path = base_dir / path

    loaded = io.load(path, audio_dir=audio_dir)

    if not isinstance(loaded, (data.AnnotationSet, data.AnnotationProject)):
        raise ValueError(
            f"The AOEF file at {path} does not contain a set of annotations"
        )

    return loaded
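Taken together with the dispatcher above, declaring and loading a source looks roughly like this. A minimal sketch; the names and paths are invented for illustration, and `base_dir` simply re-roots the relative `audio_dir` and `annotations_path`:

from pathlib import Path

# Hypothetical example values; AOEFAnnotations and load_annotated_dataset
# are the definitions shown above.
source = AOEFAnnotations(
    name="example-deployment",
    description="Illustrative source entry",
    audio_dir=Path("audio"),
    annotations_path=Path("annotations.aoef.json"),
)

# Dispatches on the runtime type of `source`, so the same call works for
# BatDetect2FilesAnnotations and BatDetect2MergedAnnotations sources too.
annotation_set = load_annotated_dataset(source, base_dir=Path("/data/project"))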
@@ -1,80 +0,0 @@
import os
from pathlib import Path
from typing import Literal, Optional, Union

from soundevent import data

from batdetect2.data.annotations.legacy import (
    file_annotation_to_annotation_task,
    file_annotation_to_clip,
    file_annotation_to_clip_annotation,
    list_file_annotations,
    load_file_annotation,
)
from batdetect2.data.annotations.types import AnnotatedDataset

PathLike = Union[Path, str, os.PathLike]


__all__ = [
    "load_batdetect2_files_annotated_dataset",
    "BatDetect2FilesAnnotations",
]


class BatDetect2FilesAnnotations(AnnotatedDataset):
    format: Literal["batdetect2"] = "batdetect2"
    annotations_dir: Path


def load_batdetect2_files_annotated_dataset(
    dataset: BatDetect2FilesAnnotations,
    base_dir: Optional[PathLike] = None,
) -> data.AnnotationProject:
    """Convert annotations to annotation project."""
    audio_dir = dataset.audio_dir
    path = dataset.annotations_dir

    if base_dir:
        audio_dir = base_dir / audio_dir
        path = base_dir / path

    paths = list_file_annotations(path)

    annotations = []
    tasks = []

    for p in paths:
        try:
            file_annotation = load_file_annotation(p)
        except FileNotFoundError:
            continue

        try:
            clip = file_annotation_to_clip(
                file_annotation,
                audio_dir=audio_dir,
            )
        except FileNotFoundError:
            continue

        annotations.append(
            file_annotation_to_clip_annotation(
                file_annotation,
                clip,
            )
        )

        tasks.append(
            file_annotation_to_annotation_task(
                file_annotation,
                clip,
            )
        )

    return data.AnnotationProject(
        name=dataset.name,
        description=dataset.description,
        clip_annotations=annotations,
        tasks=tasks,
    )
@@ -1,64 +0,0 @@
import json
import os
from pathlib import Path
from typing import Literal, Optional, Union

from soundevent import data

from batdetect2.data.annotations.legacy import (
    FileAnnotation,
    file_annotation_to_annotation_task,
    file_annotation_to_clip,
    file_annotation_to_clip_annotation,
)
from batdetect2.data.annotations.types import AnnotatedDataset

PathLike = Union[Path, str, os.PathLike]

__all__ = [
    "BatDetect2MergedAnnotations",
    "load_batdetect2_merged_annotated_dataset",
]


class BatDetect2MergedAnnotations(AnnotatedDataset):
    format: Literal["batdetect2_file"] = "batdetect2_file"
    annotations_path: Path


def load_batdetect2_merged_annotated_dataset(
    dataset: BatDetect2MergedAnnotations,
    base_dir: Optional[PathLike] = None,
) -> data.AnnotationProject:
    audio_dir = dataset.audio_dir
    path = dataset.annotations_path

    if base_dir:
        audio_dir = base_dir / audio_dir
        path = base_dir / path

    content = json.loads(Path(path).read_text())

    annotations = []
    tasks = []

    for ann in content:
        try:
            ann = FileAnnotation.model_validate(ann)
        except ValueError:
            continue

        try:
            clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
        except FileNotFoundError:
            continue

        annotations.append(file_annotation_to_clip_annotation(ann, clip))
        tasks.append(file_annotation_to_annotation_task(ann, clip))

    return data.AnnotationProject(
        name=dataset.name,
        description=dataset.description,
        clip_annotations=annotations,
        tasks=tasks,
    )
@@ -1,304 +0,0 @@
"""Compatibility functions between old and new data structures."""

import os
import uuid
from pathlib import Path
from typing import Callable, List, Optional, Union

import numpy as np
from pydantic import BaseModel, Field
from soundevent import data
from soundevent.geometry import compute_bounds
from soundevent.types import ClassMapper

from batdetect2 import types

PathLike = Union[Path, str, os.PathLike]

__all__ = [
    "convert_to_annotation_group",
]

SPECIES_TAG_KEY = "species"
ECHOLOCATION_EVENT = "Echolocation"
UNKNOWN_CLASS = "__UNKNOWN__"

NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")


EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]

ClassFn = Callable[[data.Recording], int]

IndividualFn = Callable[[data.SoundEventAnnotation], int]


def get_recording_class_name(recording: data.Recording) -> str:
    """Get the class name for a recording."""
    tag = data.find_tag(recording.tags, SPECIES_TAG_KEY)
    if tag is None:
        return UNKNOWN_CLASS
    return tag.value


def get_annotation_notes(annotation: data.ClipAnnotation) -> str:
    """Get the notes for a ClipAnnotation."""
    all_notes = [
        *annotation.notes,
        *annotation.clip.recording.notes,
    ]
    messages = [note.message for note in all_notes if note.message is not None]
    return "\n".join(messages)


def convert_to_annotation_group(
    annotation: data.ClipAnnotation,
    class_mapper: ClassMapper,
    event_fn: EventFn = lambda _: ECHOLOCATION_EVENT,
    class_fn: ClassFn = lambda _: 0,
    individual_fn: IndividualFn = lambda _: 0,
) -> types.AudioLoaderAnnotationGroup:
    """Convert a ClipAnnotation to an AudioLoaderAnnotationGroup."""
    recording = annotation.clip.recording

    start_times = []
    end_times = []
    low_freqs = []
    high_freqs = []
    class_ids = []
    x_inds = []
    y_inds = []
    individual_ids = []
    annotations: List[types.Annotation] = []
    class_id_file = class_fn(recording)

    for sound_event in annotation.sound_events:
        geometry = sound_event.sound_event.geometry

        if geometry is None:
            continue

        start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
        class_id = class_mapper.transform(sound_event) or -1
        event = event_fn(sound_event) or ""
        individual_id = individual_fn(sound_event) or -1

        start_times.append(start_time)
        end_times.append(end_time)
        low_freqs.append(low_freq)
        high_freqs.append(high_freq)
        class_ids.append(class_id)
        individual_ids.append(individual_id)

        # NOTE: This will be computed later so we just put a placeholder
        # here for now.
        x_inds.append(0)
        y_inds.append(0)

        annotations.append(
            {
                "start_time": start_time,
                "end_time": end_time,
                "low_freq": low_freq,
                "high_freq": high_freq,
                "class_prob": 1.0,
                "det_prob": 1.0,
                "individual": "0",
                "event": event,
                "class_id": class_id,  # type: ignore
            }
        )

    return {
        "id": str(recording.path),
        "duration": recording.duration,
        "issues": False,
        "file_path": str(recording.path),
        "time_exp": recording.time_expansion,
        "class_name": get_recording_class_name(recording),
        "notes": get_annotation_notes(annotation),
        "annotated": True,
        "start_times": np.array(start_times),
        "end_times": np.array(end_times),
        "low_freqs": np.array(low_freqs),
        "high_freqs": np.array(high_freqs),
        "class_ids": np.array(class_ids),
        "x_inds": np.array(x_inds),
        "y_inds": np.array(y_inds),
        "individual_ids": np.array(individual_ids),
        "annotation": annotations,
        "class_id_file": class_id_file,
    }


class Annotation(BaseModel):
    """Annotation class to hold batdetect annotations."""

    label: str = Field(alias="class")
    event: str
    individual: int = 0

    start_time: float
    end_time: float
    low_freq: float
    high_freq: float


class FileAnnotation(BaseModel):
    """FileAnnotation class to hold batdetect annotations for a file."""

    id: str
    duration: float
    time_exp: float = 1

    label: str = Field(alias="class_name")

    annotation: List[Annotation]

    annotated: bool = False
    issues: bool = False
    notes: str = ""


def load_file_annotation(path: PathLike) -> FileAnnotation:
    """Load annotation from batdetect format."""
    path = Path(path)
    return FileAnnotation.model_validate_json(path.read_text())


def annotation_to_sound_event(
    annotation: Annotation,
    recording: data.Recording,
    label_key: str = "class",
    event_key: str = "event",
    individual_key: str = "individual",
) -> data.SoundEventAnnotation:
    """Convert annotation to sound event annotation."""
    sound_event = data.SoundEvent(
        uuid=uuid.uuid5(
            NAMESPACE,
            f"{recording.hash}_{annotation.start_time}_{annotation.end_time}",
        ),
        recording=recording,
        geometry=data.BoundingBox(
            coordinates=[
                annotation.start_time,
                annotation.low_freq,
                annotation.end_time,
                annotation.high_freq,
            ],
        ),
    )

    return data.SoundEventAnnotation(
        uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
        sound_event=sound_event,
        tags=[
            data.Tag(
                term=data.term_from_key(label_key),
                value=annotation.label,
            ),
            data.Tag(
                term=data.term_from_key(event_key),
                value=annotation.event,
            ),
            data.Tag(
                term=data.term_from_key(individual_key),
                value=str(annotation.individual),
            ),
        ],
    )


def file_annotation_to_clip(
    file_annotation: FileAnnotation,
    audio_dir: Optional[PathLike] = None,
    label_key: str = "class",
) -> data.Clip:
    """Convert file annotation to recording."""
    audio_dir = audio_dir or Path.cwd()

    full_path = Path(audio_dir) / file_annotation.id

    if not full_path.exists():
        raise FileNotFoundError(f"File {full_path} not found.")

    recording = data.Recording.from_file(
        full_path,
        time_expansion=file_annotation.time_exp,
        tags=[
            data.Tag(
                term=data.term_from_key(label_key),
                value=file_annotation.label,
            )
        ],
    )

    return data.Clip(
        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip"),
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )


def file_annotation_to_clip_annotation(
    file_annotation: FileAnnotation,
    clip: data.Clip,
    label_key: str = "class",
    event_key: str = "event",
    individual_key: str = "individual",
) -> data.ClipAnnotation:
    """Convert file annotation to clip annotation."""
    notes = []
    if file_annotation.notes:
        notes.append(data.Note(message=file_annotation.notes))

    return data.ClipAnnotation(
        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
        clip=clip,
        notes=notes,
        tags=[
            data.Tag(
                term=data.term_from_key(label_key), value=file_annotation.label
            )
        ],
        sound_events=[
            annotation_to_sound_event(
                annotation,
                clip.recording,
                label_key=label_key,
                event_key=event_key,
                individual_key=individual_key,
            )
            for annotation in file_annotation.annotation
        ],
    )


def file_annotation_to_annotation_task(
    file_annotation: FileAnnotation,
    clip: data.Clip,
) -> data.AnnotationTask:
    status_badges = []

    if file_annotation.issues:
        status_badges.append(
            data.StatusBadge(state=data.AnnotationState.rejected)
        )
    elif file_annotation.annotated:
        status_badges.append(
            data.StatusBadge(state=data.AnnotationState.completed)
        )

    return data.AnnotationTask(
        uuid=uuid.uuid5(uuid.NAMESPACE_URL, f"{file_annotation.id}_task"),
        clip=clip,
        status_badges=status_badges,
    )


def list_file_annotations(path: PathLike) -> List[Path]:
    """List all annotations in a directory."""
    path = Path(path)
    return [file for file in path.glob("*.json")]
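The module above derives every identifier with `uuid.uuid5` from a fixed namespace plus a content-based name, so re-importing the same legacy annotation file always yields the same UUIDs. A small standard-library sketch of that property (the name strings are examples):

import uuid

# Same namespace + same name always gives the same UUID (uuid5 is a SHA-1
# hash, not a random draw), which is what makes repeated imports stable.
NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")

a = uuid.uuid5(NAMESPACE, "recording-hash_0.1_0.2")
b = uuid.uuid5(NAMESPACE, "recording-hash_0.1_0.2")
assert a == b  # deterministic

# A different name (e.g. different event bounds) yields a different UUID.
c = uuid.uuid5(NAMESPACE, "recording-hash_0.3_0.4")
assert a != c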
@@ -1,41 +0,0 @@
from pathlib import Path
from typing import Literal, Union

from batdetect2.configs import BaseConfig

__all__ = [
    "AnnotatedDataset",
    "BatDetect2MergedAnnotations",
]


class AnnotatedDataset(BaseConfig):
    """Represents a single, cohesive source of audio recordings and annotations.

    A source typically groups recordings originating from a specific context,
    such as a single project, site, deployment, or recordist. All audio files
    belonging to a source should be located within a single directory,
    specified by `audio_dir`.

    Annotations associated with these recordings are defined by the
    `annotations` field, which supports various formats (e.g., AOEF files,
    specific CSV structures).
    Crucially, file paths referenced within the annotation data *must* be
    relative to the `audio_dir`. This ensures that the dataset definition
    remains portable across different systems and base directories.

    Attributes:
        name: A unique identifier for this data source.
        description: Detailed information about the source, including recording
            methods, annotation procedures, equipment used, potential biases,
            or any important caveats for users.
        audio_dir: The file system path to the directory containing the audio
            recordings for this source.
    """

    name: str
    audio_dir: Path
    description: str = ""
@@ -9,16 +9,15 @@ import numpy as np
from pydantic import BaseModel, Field
from soundevent import data
from soundevent.geometry import compute_bounds
from soundevent.types import ClassMapper

from batdetect2 import types
from batdetect2.data.labels import ClassMapper

PathLike = Union[Path, str, os.PathLike]

__all__ = [
    "convert_to_annotation_group",
    "load_file_annotation",
    "annotation_to_sound_event",
    "load_annotation_project",
]

SPECIES_TAG_KEY = "species"
@@ -196,30 +195,18 @@ def annotation_to_sound_event(
        uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
        sound_event=sound_event,
        tags=[
            data.Tag(
                term=data.term_from_key(label_key),
                value=annotation.label,
            ),
            data.Tag(
                term=data.term_from_key(event_key),
                value=annotation.event,
            ),
            data.Tag(
                term=data.term_from_key(individual_key),
                value=str(annotation.individual),
            ),
            data.Tag(key=label_key, value=annotation.label),
            data.Tag(key=event_key, value=annotation.event),
            data.Tag(key=individual_key, value=str(annotation.individual)),
        ],
    )


def file_annotation_to_clip(
    file_annotation: FileAnnotation,
    audio_dir: Optional[PathLike] = None,
    label_key: str = "class",
    audio_dir: PathLike = Path.cwd(),
) -> data.Clip:
    """Convert file annotation to recording."""
    audio_dir = audio_dir or Path.cwd()

    full_path = Path(audio_dir) / file_annotation.id

    if not full_path.exists():
@@ -228,12 +215,6 @@ def file_annotation_to_clip(
    recording = data.Recording.from_file(
        full_path,
        time_expansion=file_annotation.time_exp,
        tags=[
            data.Tag(
                term=data.term_from_key(label_key),
                value=file_annotation.label,
            )
        ],
    )

    return data.Clip(
@@ -260,11 +241,7 @@ def file_annotation_to_clip_annotation(
        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
        clip=clip,
        notes=notes,
        tags=[
            data.Tag(
                term=data.term_from_key(label_key), value=file_annotation.label
            )
        ],
        tags=[data.Tag(key=label_key, value=file_annotation.label)],
        sound_events=[
            annotation_to_sound_event(
                annotation,
@@ -304,3 +281,52 @@ def list_file_annotations(path: PathLike) -> List[Path]:
    """List all annotations in a directory."""
    path = Path(path)
    return [file for file in path.glob("*.json")]


def load_annotation_project(
    path: PathLike,
    name: Optional[str] = None,
    audio_dir: PathLike = Path.cwd(),
) -> data.AnnotationProject:
    """Convert annotations to annotation project."""
    paths = list_file_annotations(path)

    if name is None:
        name = str(path)

    annotations = []
    tasks = []

    for p in paths:
        try:
            file_annotation = load_file_annotation(p)
        except FileNotFoundError:
            continue

        try:
            clip = file_annotation_to_clip(
                file_annotation,
                audio_dir=audio_dir,
            )
        except FileNotFoundError:
            continue

        annotations.append(
            file_annotation_to_clip_annotation(
                file_annotation,
                clip,
            )
        )

        tasks.append(
            file_annotation_to_annotation_task(
                file_annotation,
                clip,
            )
        )

    return data.AnnotationProject(
        name=name,
        clip_annotations=annotations,
        tasks=tasks,
    )
@@ -1,37 +0,0 @@
from pathlib import Path
from typing import Optional

from soundevent import data

from batdetect2.configs import load_config
from batdetect2.data.annotations import load_annotated_dataset
from batdetect2.data.types import Dataset

__all__ = [
    "load_dataset",
    "load_dataset_from_config",
]


def load_dataset(
    dataset: Dataset,
    base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
    clip_annotations = []
    for source in dataset.sources:
        annotated_source = load_annotated_dataset(source, base_dir=base_dir)
        clip_annotations.extend(annotated_source.clip_annotations)
    return data.AnnotationSet(clip_annotations=clip_annotations)


def load_dataset_from_config(
    path: data.PathLike,
    field: Optional[str] = None,
    base_dir: Optional[Path] = None,
):
    config = load_config(
        path=path,
        schema=Dataset,
        field=field,
    )
    return load_dataset(config, base_dir=base_dir)
batdetect2/data/datasets.py (new file, 33 lines)

@@ -0,0 +1,33 @@
from typing import Callable, Generic, Iterable, List, TypeVar

from soundevent import data
from torch.utils.data import Dataset

__all__ = [
    "ClipDataset",
]


E = TypeVar("E")


class ClipDataset(Dataset, Generic[E]):
    clips: List[data.Clip]

    transform: Callable[[data.Clip], E]

    def __init__(
        self,
        clips: Iterable[data.Clip],
        transform: Callable[[data.Clip], E],
        name: str = "ClipDataset",
    ):
        self.clips = list(clips)
        self.transform = transform
        self.name = name

    def __len__(self) -> int:
        return len(self.clips)

    def __getitem__(self, idx: int) -> E:
        return self.transform(self.clips[idx])
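Because `ClipDataset` is a plain map-style `torch.utils.data.Dataset`, it composes directly with a `DataLoader`, and the transform runs lazily, once per clip, at fetch time. A rough usage sketch; the clip list is a stand-in, and `preprocess_audio_clip` (from the preprocessing module added further down in this compare) is one natural choice of transform:

from torch.utils.data import DataLoader

# `clips` is assumed to be any iterable of soundevent `data.Clip` objects,
# e.g. gathered from a loaded AnnotationSet.
dataset = ClipDataset(
    clips,
    transform=preprocess_audio_clip,
    name="train-clips",
)

# num_workers parallelizes the per-clip preprocessing across processes.
# Note: a custom collate_fn may be needed if the transform returns
# non-tensor types such as xarray DataArrays.
loader = DataLoader(dataset, batch_size=8, num_workers=4)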
@@ -1,43 +1,27 @@
from collections.abc import Iterable
from typing import Callable, List, Optional, Sequence, Tuple
from typing import Tuple

import numpy as np
import xarray as xr
from pydantic import Field
from scipy.ndimage import gaussian_filter
from soundevent import arrays, data, geometry
from soundevent import data, geometry, arrays
from soundevent.geometry.operations import Positions

from batdetect2.configs import BaseConfig, load_config
from soundevent.types import ClassMapper

__all__ = [
    "HeatmapsConfig",
    "LabelConfig",
    "ClassMapper",
    "generate_heatmaps",
    "load_label_config",
]


class HeatmapsConfig(BaseConfig):
    position: Positions = "bottom-left"
    sigma: float = 3.0
    time_scale: float = 1000.0
    frequency_scale: float = 1 / 859.375


class LabelConfig(BaseConfig):
    heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)
TARGET_SIGMA = 3.0


def generate_heatmaps(
    sound_events: Sequence[data.SoundEventAnnotation],
    clip_annotation: data.ClipAnnotation,
    spec: xr.DataArray,
    class_names: List[str],
    encoder: Callable[[Iterable[data.Tag]], Optional[str]],
    target_sigma: float = 3.0,
    class_mapper: ClassMapper,
    target_sigma: float = TARGET_SIGMA,
    position: Positions = "bottom-left",
    time_scale: float = 1000.0,
    frequency_scale: float = 1 / 859.375,
    dtype=np.float32,
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
    shape = dict(zip(spec.dims, spec.shape))
@@ -47,13 +31,20 @@ def generate_heatmaps(
            "Spectrogram must have time and frequency dimensions."
        )

    time_duration = arrays.get_dim_width(spec, dim="time")
    freq_bandwidth = arrays.get_dim_width(spec, dim="frequency")

    # Compute the size factors
    time_scale = 1 / time_duration
    frequency_scale = 1 / freq_bandwidth

    # Initialize heatmaps
    detection_heatmap = xr.zeros_like(spec, dtype=dtype)
    class_heatmap = xr.DataArray(
        data=np.zeros((len(class_names), *spec.shape), dtype=dtype),
        data=np.zeros((class_mapper.num_classes, *spec.shape), dtype=dtype),
        dims=["category", *spec.dims],
        coords={
            "category": [*class_names],
            "category": class_mapper.class_labels,
            **spec.coords,
        },
    )
@@ -66,8 +57,9 @@ def generate_heatmaps(
        },
    )

    for sound_event_annotation in sound_events:
    for sound_event_annotation in clip_annotation.sound_events:
        geom = sound_event_annotation.sound_event.geometry

        if geom is None:
            continue

@@ -75,29 +67,23 @@ def generate_heatmaps(
        time, frequency = geometry.get_geometry_point(geom, position=position)

        # Set 1.0 at the position of the sound event in the detection heatmap
        try:
            detection_heatmap = arrays.set_value_at_pos(
                detection_heatmap,
                1.0,
                time=time,
                frequency=frequency,
            )
        except KeyError:
            # Skip the sound event if the position is outside the spectrogram
            continue

        # Set the size of the sound event at the position in the size heatmap
        start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
            geom
        )

        size = np.array(
            [
                (end_time - start_time) * time_scale,
                (high_freq - low_freq) * frequency_scale,
            ]
        )

        size_heatmap = arrays.set_value_at_pos(
            size_heatmap,
            size,
@@ -106,12 +92,14 @@ def generate_heatmaps(
        )

        # Get the class name of the sound event
        class_name = encoder(sound_event_annotation.tags)
        class_name = class_mapper.transform(sound_event_annotation)

        if class_name is None:
            # If the label is None skip the sound event
            continue

        # Set 1.0 at the position and category of the sound event in the class
        # heatmap
        class_heatmap = arrays.set_value_at_pos(
            class_heatmap,
            1.0,
@@ -142,9 +130,3 @@ def generate_heatmaps(
    ).fillna(0.0)

    return detection_heatmap, class_heatmap, size_heatmap


def load_label_config(
    path: data.PathLike, field: Optional[str] = None
) -> LabelConfig:
    return load_config(path, schema=LabelConfig, field=field)
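Both versions of `generate_heatmaps` share the same core recipe: write a point (or a size vector) into an all-zero array at each event's time/frequency position, then blur the detection and class maps with a Gaussian. A stripped-down numpy sketch of that recipe, independent of the xarray plumbing; the grid shape and event position are invented:

import numpy as np
from scipy.ndimage import gaussian_filter

TARGET_SIGMA = 3.0

# A fake 128-bin x 256-frame spectrogram grid.
detection = np.zeros((128, 256), dtype=np.float32)

# Suppose one sound event maps to frequency bin 40, time frame 100:
# a single 1.0 marks its position...
detection[40, 100] = 1.0

# ...and the Gaussian blur turns the spike into a soft target peak,
# which is what the detection head regresses against during training.
detection = gaussian_filter(detection, sigma=TARGET_SIGMA)

assert detection.argmax() == 40 * 256 + 100  # the peak stays at the event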
batdetect2/data/preprocessing.py (new file, 392 lines)

@@ -0,0 +1,392 @@
"""Module containing functions for preprocessing audio clips."""

from typing import Optional, Union
from pathlib import Path

import librosa
import librosa.core.spectrum
import numpy as np
import xarray as xr
from numpy.typing import DTypeLike
from pydantic import BaseModel, Field
from scipy.signal import resample_poly
from soundevent import audio, data, arrays
from soundevent.arrays import operations as ops

__all__ = [
    "PreprocessingConfig",
    "preprocess_audio_clip",
]


TARGET_SAMPLERATE_HZ = 256000
SCALE_RAW_AUDIO = False
FFT_WIN_LENGTH_S = 512 / 256000.0
FFT_OVERLAP = 0.75
MAX_FREQ_HZ = 120000
MIN_FREQ_HZ = 10000
DEFAULT_DURATION = 1
SPEC_HEIGHT = 128
SPEC_WIDTH = 256
SPEC_SCALE = "pcen"
SPEC_TIME_PERIOD = DEFAULT_DURATION / SPEC_WIDTH
DENOISE_SPEC_AVG = True
MAX_SCALE_SPEC = False


class PreprocessingConfig(BaseModel):
    """Configuration for preprocessing data."""

    target_samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)

    scale_audio: bool = Field(default=SCALE_RAW_AUDIO)

    fft_win_length: float = Field(default=FFT_WIN_LENGTH_S, gt=0)

    fft_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)

    max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)

    min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)

    spec_scale: str = Field(default=SPEC_SCALE)

    denoise_spec_avg: bool = DENOISE_SPEC_AVG

    max_scale_spec: bool = MAX_SCALE_SPEC

    duration: Optional[float] = DEFAULT_DURATION

    spec_height: int = SPEC_HEIGHT

    spec_time_period: float = SPEC_TIME_PERIOD

    @classmethod
    def from_file(
        cls,
        path: Union[str, Path],
    ) -> "PreprocessingConfig":
        """Load configuration from a file.

        Parameters
        ----------
        path
            Path to the configuration file.

        Returns
        -------
        PreprocessingConfig
            The configuration loaded from the file.

        Raises
        ------
        FileNotFoundError
            If the configuration file does not exist.
        pydantic.ValidationError
            If the configuration file is invalid.
        """
        path = Path(path)

        if not path.is_file():
            raise FileNotFoundError(f"Config file not found: {path}")

        return cls.model_validate_json(path.read_text())

    def to_file(self, path: Union[str, Path]) -> None:
        """Save configuration to a file."""
        path = Path(path)

        if not path.parent.exists():
            path.parent.mkdir(parents=True)

        path.write_text(self.model_dump_json())


def preprocess_audio_clip(
    clip: data.Clip,
    config: PreprocessingConfig = PreprocessingConfig(),
) -> xr.DataArray:
    """Preprocesses audio clip to generate spectrogram.

    Parameters
    ----------
    clip
        The audio clip to preprocess.
    config
        Configuration for preprocessing.

    Returns
    -------
    xr.DataArray
        Preprocessed spectrogram.

    """
    wav = load_clip_audio(
        clip,
        target_sampling_rate=config.target_samplerate,
        scale=config.scale_audio,
    )

    spec = compute_spectrogram(
        wav,
        fft_win_length=config.fft_win_length,
        fft_overlap=config.fft_overlap,
        max_freq=config.max_freq,
        min_freq=config.min_freq,
        spec_scale=config.spec_scale,
        denoise_spec_avg=config.denoise_spec_avg,
        max_scale_spec=config.max_scale_spec,
    )

    if config.duration is not None:
        spec = adjust_spec_duration(clip, spec, config.duration)

    duration = arrays.get_dim_width(spec, dim="time")
    return ops.resize(
        spec,
        time=int(np.ceil(duration / config.spec_time_period)),
        frequency=config.spec_height,
        dtype=np.float32,
    )


def adjust_spec_duration(
    clip: data.Clip,
    spec: xr.DataArray,
    duration: float,
) -> xr.DataArray:
    current_duration = clip.end_time - clip.start_time

    if current_duration == duration:
        return spec

    if current_duration > duration:
        return arrays.crop_dim(
            spec,
            dim="time",
            start=clip.start_time,
            stop=clip.start_time + duration,
        )

    return arrays.extend_dim(
        spec,
        dim="time",
        start=clip.start_time,
        stop=clip.start_time + duration,
    )


def load_clip_audio(
    clip: data.Clip,
    target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
    scale: bool = SCALE_RAW_AUDIO,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    wav = audio.load_clip(clip).sel(channel=0).astype(dtype)

    wav = resample_audio(wav, target_sampling_rate, dtype=dtype)

    if scale:
        wav = ops.center(wav)
        wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))

    return wav.astype(dtype)


def resample_audio(
    wav: xr.DataArray,
    target_samplerate: int = TARGET_SAMPLERATE_HZ,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    if "time" not in wav.dims:
        raise ValueError("Audio must have a time dimension")

    time_axis: int = wav.get_axis_num("time")  # type: ignore

    start, stop = arrays.get_dim_range(wav, dim="time")
    step = arrays.get_dim_step(wav, dim="time")
    original_samplerate = int(1 / step)

    if original_samplerate == target_samplerate:
        return wav.astype(dtype)

    gcd = np.gcd(original_samplerate, target_samplerate)
    resampled = resample_poly(
        wav.values,
        target_samplerate // gcd,
        original_samplerate // gcd,
        axis=time_axis,
    )

    resampled_times = np.linspace(
        start,
        stop + step,
        len(resampled),
        endpoint=False,
        dtype=dtype,
    )

    return xr.DataArray(
        data=resampled.astype(dtype),
        dims=wav.dims,
        coords={
            **wav.coords,
            "time": arrays.create_time_dim_from_array(
                resampled_times,
                samplerate=target_samplerate,
            ),
        },
        attrs=wav.attrs,
    )
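`resample_audio` reduces the rate-conversion ratio by the greatest common divisor before calling `scipy.signal.resample_poly`, which keeps the polyphase up/down factors small. A quick standalone sketch of the same arithmetic; the sample rates are illustrative:

import numpy as np
from scipy.signal import resample_poly

original_samplerate = 192_000
target_samplerate = 256_000

# gcd(192000, 256000) = 64000, so the conversion runs as up=4, down=3
# rather than up=256000, down=192000.
gcd = np.gcd(original_samplerate, target_samplerate)
up, down = target_samplerate // gcd, original_samplerate // gcd

x = np.random.randn(192_000).astype(np.float32)  # one second of audio
y = resample_poly(x, up, down)

# Output length scales by up/down: 192000 * 4 / 3 = 256000 samples.
assert len(y) == 256_000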
def compute_spectrogram(
    wav: xr.DataArray,
    fft_win_length: float = FFT_WIN_LENGTH_S,
    fft_overlap: float = FFT_OVERLAP,
    max_freq: int = MAX_FREQ_HZ,
    min_freq: int = MIN_FREQ_HZ,
    spec_scale: str = SPEC_SCALE,
    denoise_spec_avg: bool = True,
    max_scale_spec: bool = False,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    spec = gen_mag_spectrogram(
        wav,
        window_len=fft_win_length,
        overlap_perc=fft_overlap,
        dtype=dtype,
    )

    spec = arrays.crop_dim(
        spec,
        dim="frequency",
        start=min_freq,
        stop=max_freq,
    ).astype(dtype)

    spec = scale_spectrogram(spec, scale=spec_scale)

    if denoise_spec_avg:
        spec = denoise_spectrogram(spec)

    if max_scale_spec:
        spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))

    return spec.astype(dtype)


def gen_mag_spectrogram(
    wave: xr.DataArray,
    window_len: float,
    overlap_perc: float,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    start_time, end_time = arrays.get_dim_range(wave, dim="time")
    step = arrays.get_dim_step(wave, dim="time")
    sampling_rate = 1 / step

    hop_len = window_len * (1 - overlap_perc)
    nfft = int(window_len * sampling_rate)
    noverlap = int(overlap_perc * nfft)

    # compute spec
    spec, _ = librosa.core.spectrum._spectrogram(
        y=wave.data,
        power=1,
        n_fft=nfft,
        hop_length=nfft - noverlap,
        center=False,
    )

    return xr.DataArray(
        data=spec.astype(dtype),
        dims=["frequency", "time"],
        coords={
            "frequency": arrays.create_frequency_dim_from_array(
                np.linspace(
                    0,
                    sampling_rate / 2,
                    spec.shape[0],
                    endpoint=False,
                    dtype=dtype,
                ),
                step=sampling_rate / nfft,
            ),
            "time": arrays.create_time_dim_from_array(
                np.linspace(
                    start_time,
                    end_time - (window_len - hop_len),
                    spec.shape[1],
                    endpoint=False,
                    dtype=dtype,
                ),
                step=hop_len,
            ),
        },
        attrs={
            **wave.attrs,
            "original_samplerate": sampling_rate,
            "nfft": nfft,
            "noverlap": noverlap,
        },
    )
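With the module defaults, the STFT parameters above work out to round numbers, which is worth sanity-checking: the window is specified in seconds and converted to samples against the actual sampling rate. A worked sketch using the defaults defined earlier in this file:

# Defaults: FFT_WIN_LENGTH_S = 512 / 256000.0, FFT_OVERLAP = 0.75,
# TARGET_SAMPLERATE_HZ = 256000.
sampling_rate = 256_000
window_len = 512 / 256_000.0  # a 2 ms analysis window

nfft = int(window_len * sampling_rate)  # 512 samples
noverlap = int(0.75 * nfft)             # 384 samples
hop_length = nfft - noverlap            # 128 samples = 0.5 ms

# Frequency resolution is sampling_rate / nfft = 500 Hz per bin, so the
# 10-120 kHz crop applied in compute_spectrogram keeps 220 of the 256 bins.
print(sampling_rate / nfft)  # 500.0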
def denoise_spectrogram(
    spec: xr.DataArray,
) -> xr.DataArray:
    return xr.DataArray(
        data=(spec - spec.mean("time")).clip(0),
        dims=spec.dims,
        coords=spec.coords,
        attrs=spec.attrs,
    )

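The denoising above subtracts each frequency bin's mean over time and clips negatives; a tiny numpy sketch of the same idea:

import numpy as np

spec = np.array([[1.0, 3.0], [0.0, 2.0]])               # (frequency, time)
denoised = (spec - spec.mean(axis=1, keepdims=True)).clip(0)
print(denoised)                                          # [[0. 1.] [0. 1.]]
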
def scale_spectrogram(
    spec: xr.DataArray,
    scale: str = SPEC_SCALE,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    samplerate = spec.attrs["original_samplerate"]

    if scale == "pcen":
        smoothing_constant = get_pcen_smoothing_constant(samplerate / 10)
        return audio.pcen(
            spec * (2**31),
            smooth=smoothing_constant,
        ).astype(dtype)

    if scale == "log":
        return log_scale(spec, dtype=dtype)

    return spec

def log_scale(
    spec: xr.DataArray,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    samplerate = spec.attrs["original_samplerate"]
    nfft = spec.attrs["nfft"]
    log_scaling = (
        2.0
        * (1.0 / samplerate)
        * (1.0 / (np.abs(np.hanning(nfft)) ** 2).sum())
    )
    return xr.DataArray(
        data=np.log1p(log_scaling * spec).astype(dtype),
        dims=spec.dims,
        coords=spec.coords,
        attrs=spec.attrs,
    )

def get_pcen_smoothing_constant(
    sr: int,
    time_constant: float = 0.4,
    hop_length: int = 512,
) -> float:
    t_frames = time_constant * sr / float(hop_length)
    return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
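For reference, this is the same time-constant-to-coefficient conversion that librosa uses for PCEN; a hedged worked example with illustrative numbers:

import numpy as np

sr, time_constant, hop_length = 25_600, 0.4, 512
t_frames = time_constant * sr / hop_length                       # 20 frames
smooth = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
print(round(float(smooth), 4))   # 0.0488, the per-frame IIR smoothing coefficient
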
@ -1,29 +0,0 @@
from typing import Annotated, List

from pydantic import Field

from batdetect2.configs import BaseConfig
from batdetect2.data.annotations import AnnotationFormats


class Dataset(BaseConfig):
    """Represents a collection of one or more DatasetSources.

    In the context of batdetect2, a Dataset aggregates multiple `DatasetSource`
    instances. It serves as the primary unit for defining data splits,
    typically used for model training, validation, or testing phases.

    Attributes:
        name: A descriptive name for the overall dataset
            (e.g., "UK Training Set").
        description: A detailed explanation of the dataset's purpose,
            composition, how it was assembled, or any specific characteristics.
        sources: A list containing the `DatasetSource` objects included in this
            dataset.
    """

    name: str
    description: str
    sources: List[
        Annotated[AnnotationFormats, Field(..., discriminator="format")]
    ]
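A hypothetical sketch of the discriminated-union pattern used for `sources` above; the source classes here are invented for illustration, and the concrete member types of AnnotationFormats may differ:

from typing import Annotated, List, Literal, Union
from pydantic import BaseModel, Field

class CSVSource(BaseModel):
    format: Literal["csv"] = "csv"
    path: str

class JSONSource(BaseModel):
    format: Literal["json"] = "json"
    path: str

Source = Annotated[Union[CSVSource, JSONSource], Field(discriminator="format")]

class Dataset(BaseModel):
    name: str
    description: str
    sources: List[Source]

ds = Dataset.model_validate(
    {
        "name": "demo",
        "description": "toy example",
        "sources": [{"format": "csv", "path": "annotations.csv"}],
    }
)
assert isinstance(ds.sources[0], CSVSource)  # dispatched on the "format" field
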
@ -1,5 +1,4 @@
"""Functions to compute features from predictions."""

from typing import Dict, List, Optional

import numpy as np
@ -220,6 +219,7 @@ def compute_call_interval(
    return round(prediction["start_time"] - previous["end_time"], 5)



# NOTE: The order of the features in this dictionary is important. The
# features are extracted in this order and the order of the columns in the
# output csv file is determined by this order. In order to avoid breaking
@ -206,10 +206,7 @@ class Net2DFastNoAttn(nn.Module):
            num_filts // 4, 2, kernel_size=1, padding=0
        )
        self.conv_classes_op = nn.Conv2d(
            num_filts // 4,
            self.num_classes + 1,
            kernel_size=1,
            padding=0,
            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0,
        )

        if self.emb_dim > 0:
@ -5,10 +5,7 @@ from typing import List, Optional, Union

from pydantic import BaseModel, Field, computed_field

from batdetect2.train.legacy.train_utils import (
    get_genus_mapping,
    get_short_class_names,
)
from batdetect2.train.train_utils import get_genus_mapping, get_short_class_names
from batdetect2.types import ProcessingConfiguration, SpectrogramParameters

TARGET_SAMPLERATE_HZ = 256000
@ -1,5 +1,4 @@
"""Post-processing of the output of the model."""

from typing import List, Tuple, Union

import numpy as np
@ -12,14 +12,15 @@ import pandas as pd
import torch
from sklearn.ensemble import RandomForestClassifier

import batdetect2.evaluate.legacy.evaluate_models as evl
import batdetect2.train.legacy.train_utils as tu
import batdetect2.train.evaluate as evl
import batdetect2.train.train_utils as tu
import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as pu
from batdetect2.detector import parameters


def get_blank_annotation(ip_str):

    res = {}
    res["class_name"] = ""
    res["duration"] = -1
@ -76,6 +77,7 @@ def create_genus_mapping(gt_test, preds, class_names):


def load_tadarida_pred(ip_dir, dataset, file_of_interest):

    res, ann = get_blank_annotation("Generated by Tadarida")

    # create the annotations in the correct format
@ -118,6 +120,7 @@ load_sonobat_meta(
    class_names,
    only_accepted_species=True,
):

    sp_dict = {}
    for ss in class_names:
        sp_key = ss.split(" ")[0][:3] + ss.split(" ")[1][:3]
@ -179,6 +182,7 @@ load_sonobat_meta(


def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):

    # create the annotations in the correct format
    res, ann = get_blank_annotation("Generated by Sonobat")
    res_c = copy.deepcopy(res)
@ -217,6 +221,7 @@ def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):


def bb_overlap(bb_g_in, bb_p_in):

    freq_scale = 10000000.0  # ensure that both axes are roughly the same range
    bb_g = [
        bb_g_in["start_time"],
@ -460,6 +465,7 @@ def check_classes_in_train(gt_list, class_names):


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "op_dir",
@ -8,10 +8,10 @@ import torch.utils.data
from torch.optim.lr_scheduler import CosineAnnealingLR

import batdetect2.detector.parameters as parameters
import batdetect2.train.legacy.audio_dataloader as adl
import batdetect2.train.legacy.train_model as tm
import batdetect2.train.legacy.train_utils as tu
import batdetect2.train.audio_dataloader as adl
import batdetect2.train.losses as losses
import batdetect2.train.train_model as tm
import batdetect2.train.train_utils as tu
import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as pu
from batdetect2 import types
@ -1,92 +1,11 @@
from enum import Enum
from typing import Optional, Tuple

from soundevent.data import PathLike

from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.backbones import (
from batdetect2.models.feature_extractors import (
    Net2DFast,
    Net2DFastNoAttn,
    Net2DFastNoCoordConv,
    Net2DPlain,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.models.typing import BackboneModel

__all__ = [
    "BBoxHead",
    "ClassifierHead",
    "ModelConfig",
    "ModelType",
    "Net2DFast",
    "Net2DFastNoAttn",
    "Net2DFastNoCoordConv",
    "build_architecture",
    "load_model_config",
]


class ModelType(str, Enum):
    Net2DFast = "Net2DFast"
    Net2DFastNoAttn = "Net2DFastNoAttn"
    Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
    Net2DPlain = "Net2DPlain"


class ModelConfig(BaseConfig):
    name: ModelType = ModelType.Net2DFast
    input_height: int = 128
    encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
    bottleneck_channels: int = 256
    decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
    out_channels: int = 32


def load_model_config(
    path: PathLike, field: Optional[str] = None
) -> ModelConfig:
    return load_config(path, schema=ModelConfig, field=field)


def build_architecture(
    config: Optional[ModelConfig] = None,
) -> BackboneModel:
    config = config or ModelConfig()

    if config.name == ModelType.Net2DFast:
        return Net2DFast(
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )

    if config.name == ModelType.Net2DFastNoAttn:
        return Net2DFastNoAttn(
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )

    if config.name == ModelType.Net2DFastNoCoordConv:
        return Net2DFastNoCoordConv(
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )

    if config.name == ModelType.Net2DPlain:
        return Net2DPlain(
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )

    raise ValueError(f"Unknown model type: {config.name}")
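A hedged usage sketch for the factory above, using only names defined in this file (the config path is illustrative):

config = ModelConfig(name=ModelType.Net2DFastNoAttn, input_height=128)
backbone = build_architecture(config)
# or load the same config from a YAML file:
config = load_model_config("model.yaml")
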
@ -1,185 +0,0 @@
from typing import Sequence, Tuple

import torch
import torch.nn.functional as F

from batdetect2.models.blocks import (
    ConvBlock,
    Decoder,
    DownscalingLayer,
    Encoder,
    SelfAttention,
    UpscalingLayer,
    VerticalConv,
)
from batdetect2.models.typing import BackboneModel

__all__ = [
    "Net2DFast",
    "Net2DFastNoAttn",
    "Net2DFastNoCoordConv",
]


class Net2DPlain(BackboneModel):
    downscaling_layer_type: DownscalingLayer = "ConvBlockDownStandard"
    upscaling_layer_type: UpscalingLayer = "ConvBlockUpStandard"

    def __init__(
        self,
        input_height: int = 128,
        encoder_channels: Sequence[int] = (1, 32, 64, 128),
        bottleneck_channels: int = 256,
        decoder_channels: Sequence[int] = (256, 64, 32, 32),
        out_channels: int = 32,
    ):
        super().__init__()

        self.input_height = input_height
        self.encoder_channels = tuple(encoder_channels)
        self.decoder_channels = tuple(decoder_channels)
        self.out_channels = out_channels

        if len(encoder_channels) != len(decoder_channels):
            raise ValueError(
                f"Mismatched encoder and decoder channel lists. "
                f"The encoder has {len(encoder_channels)} channels "
                f"(implying {len(encoder_channels) - 1} layers), "
                f"while the decoder has {len(decoder_channels)} channels "
                f"(implying {len(decoder_channels) - 1} layers). "
                f"These lengths must be equal."
            )

        self.divide_factor = 2 ** (len(encoder_channels) - 1)
        if self.input_height % self.divide_factor != 0:
            raise ValueError(
                f"Input height ({self.input_height}) must be divisible by "
                f"the divide factor ({self.divide_factor}). "
                f"This ensures proper upscaling after downscaling to recover "
                f"the original input height."
            )

        self.encoder = Encoder(
            channels=encoder_channels,
            input_height=self.input_height,
            layer_type=self.downscaling_layer_type,
        )

        self.conv_same_1 = ConvBlock(
            in_channels=encoder_channels[-1],
            out_channels=bottleneck_channels,
        )

        # bottleneck
        self.conv_vert = VerticalConv(
            in_channels=bottleneck_channels,
            out_channels=bottleneck_channels,
            input_height=self.input_height // (2**self.encoder.depth),
        )

        self.decoder = Decoder(
            channels=decoder_channels,
            input_height=self.input_height,
            layer_type=self.upscaling_layer_type,
        )

        self.conv_same_2 = ConvBlock(
            in_channels=decoder_channels[-1],
            out_channels=out_channels,
        )

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)

        # encoder
        residuals = self.encoder(spec)
        residuals[-1] = self.conv_same_1(residuals[-1])

        # bottleneck
        x = self.conv_vert(residuals[-1])
        x = x.repeat([1, 1, residuals[-1].shape[-2], 1])

        # decoder
        x = self.decoder(x, residuals=residuals)

        # Restore original size
        x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)

        return self.conv_same_2(x)


class Net2DFast(Net2DPlain):
    downscaling_layer_type = "ConvBlockDownCoordF"
    upscaling_layer_type = "ConvBlockUpF"

    def __init__(
        self,
        input_height: int = 128,
        encoder_channels: Sequence[int] = (1, 32, 64, 128),
        bottleneck_channels: int = 256,
        decoder_channels: Sequence[int] = (256, 64, 32, 32),
        out_channels: int = 32,
    ):
        super().__init__(
            input_height=input_height,
            encoder_channels=encoder_channels,
            bottleneck_channels=bottleneck_channels,
            decoder_channels=decoder_channels,
            out_channels=out_channels,
        )

        self.att = SelfAttention(bottleneck_channels, bottleneck_channels)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)

        # encoder
        residuals = self.encoder(spec)
        residuals[-1] = self.conv_same_1(residuals[-1])

        # bottleneck
        x = self.conv_vert(residuals[-1])
        x = self.att(x)
        x = x.repeat([1, 1, residuals[-1].shape[-2], 1])

        # decoder
        x = self.decoder(x, residuals=residuals)

        # Restore original size
        x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)

        return self.conv_same_2(x)


class Net2DFastNoAttn(Net2DPlain):
    downscaling_layer_type = "ConvBlockDownCoordF"
    upscaling_layer_type = "ConvBlockUpF"


class Net2DFastNoCoordConv(Net2DFast):
    downscaling_layer_type = "ConvBlockDownStandard"
    upscaling_layer_type = "ConvBlockUpStandard"


def pad_adjust(
    spec: torch.Tensor,
    factor: int = 32,
) -> Tuple[torch.Tensor, int, int]:
    h, w = spec.shape[2:]
    h_pad = -h % factor
    w_pad = -w % factor
    return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad


def restore_pad(
    x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
) -> torch.Tensor:
    # Restore original size
    if h_pad > 0:
        x = x[:, :, :-h_pad, :]

    if w_pad > 0:
        x = x[:, :, :, :-w_pad]

    return x
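A quick sketch of the padding arithmetic in pad_adjust: -h % factor is the smallest non-negative amount that rounds h up to a multiple of factor.

for h in (96, 128, 130):
    h_pad = -h % 32
    assert (h + h_pad) % 32 == 0
    print(h, "->", h + h_pad)   # 96 -> 96, 128 -> 128, 130 -> 160
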
@ -4,32 +4,18 @@ All these classes are subclasses of `torch.nn.Module` and can be used to build
complex neural network architectures.
"""

import sys
from typing import Iterable, List, Literal, Sequence, Tuple
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn

if sys.version_info >= (3, 10):
    from itertools import pairwise
else:

    def pairwise(iterable: Sequence) -> Iterable:
        for x, y in zip(iterable[:-1], iterable[1:]):
            yield x, y


__all__ = [
    "ConvBlock",
    "SelfAttention",
    "ConvBlockDownCoordF",
    "ConvBlockDownStandard",
    "ConvBlockUpF",
    "ConvBlockUpStandard",
    "SelfAttention",
    "VerticalConv",
    "DownscalingLayer",
    "UpscalingLayer",
]


@ -39,21 +25,16 @@ class SelfAttention(nn.Module):
    This module implements a self-attention mechanism.
    """

    def __init__(
        self,
        in_channels: int,
        attention_channels: int,
        temperature: float = 1.0,
    ):
    def __init__(self, ip_dim: int, att_dim: int):
        super().__init__()

        # Note, does not encode position information (absolute or relative)
        self.temperature = temperature
        self.att_dim = attention_channels
        self.key_fun = nn.Linear(in_channels, attention_channels)
        self.value_fun = nn.Linear(in_channels, attention_channels)
        self.query_fun = nn.Linear(in_channels, attention_channels)
        self.pro_fun = nn.Linear(attention_channels, in_channels)
        # Note, does not encode position information (absolute or relative)
        self.temperature = 1.0
        self.att_dim = att_dim
        self.key_fun = nn.Linear(ip_dim, att_dim)
        self.val_fun = nn.Linear(ip_dim, att_dim)
        self.que_fun = nn.Linear(ip_dim, att_dim)
        self.pro_fun = nn.Linear(att_dim, ip_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.squeeze(2).permute(0, 2, 1)
@ -62,11 +43,11 @@ class SelfAttention(nn.Module):
            x, self.key_fun.weight.T
        ) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
        query = torch.matmul(
            x, self.query_fun.weight.T
        ) + self.query_fun.bias.unsqueeze(0).unsqueeze(0)
            x, self.que_fun.weight.T
        ) + self.que_fun.bias.unsqueeze(0).unsqueeze(0)
        value = torch.matmul(
            x, self.value_fun.weight.T
        ) + self.value_fun.bias.unsqueeze(0).unsqueeze(0)
            x, self.val_fun.weight.T
        ) + self.val_fun.bias.unsqueeze(0).unsqueeze(0)

        kk_qq = torch.bmm(key, query.permute(0, 2, 1)) / (
            self.temperature * self.att_dim
@ -82,66 +63,6 @@ class SelfAttention(nn.Module):
        return op


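A hedged shape-level sketch of the attention step above; the nn.Linear key/query/value maps are stood in by the identity here, so only the bmm/softmax plumbing is shown:

import torch

x = torch.randn(2, 256, 1, 64)                  # (B, C, H=1, W)
xt = x.squeeze(2).permute(0, 2, 1)              # (B, W, C)
key = query = value = xt                        # stand-ins for the linear maps
kk_qq = torch.bmm(key, query.permute(0, 2, 1)) / key.shape[-1]
att = torch.softmax(kk_qq, dim=-1)              # attention over the width axis
op = torch.bmm(att, value)                      # (B, W, C)
print(op.shape)                                 # torch.Size([2, 64, 256])
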
class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        pad_size: int = 1,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=pad_size,
        )
        self.conv_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu_(self.conv_bn(self.conv(x)))


class VerticalConv(nn.Module):
    """Convolutional layer over full height.

    This layer applies a convolution that captures information across the
    entire height of the input image. It uses a kernel with the same height as
    the input, effectively condensing the vertical information into a single
    output row.

    More specifically:

    * **Input:** (B, C, H, W) where B is the batch size, C is the number of
        input channels, H is the image height, and W is the image width.
    * **Kernel:** (C', H, 1) where C' is the number of output channels.
    * **Output:** (B, C', 1, W) - The height dimension is 1 because the
        convolution integrates information from all rows of the input.

    This process effectively extracts features that span the full height of
    the input image.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        input_height: int,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(input_height, 1),
            padding=0,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu_(self.bn(self.conv(x)))


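A shape check for VerticalConv: an (input_height, 1) kernel collapses the height axis to a single row (a hedged sketch, not from the repo):

import torch
from torch import nn

conv = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(16, 1))
x = torch.randn(1, 256, 16, 128)   # (B, C, H, W)
print(conv(x).shape)               # torch.Size([1, 256, 1, 128])
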
class ConvBlockDownCoordF(nn.Module):
    """Convolutional Block with Downsampling and Coord Feature.

@ -151,27 +72,27 @@ class ConvBlockDownCoordF(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        input_height: int,
        kernel_size: int = 3,
        in_chn: int,
        out_chn: int,
        ip_height: int,
        k_size: int = 3,
        pad_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()

        self.coords = nn.Parameter(
            torch.linspace(-1, 1, input_height)[None, None, ..., None],
            torch.linspace(-1, 1, ip_height)[None, None, ..., None],
            requires_grad=False,
        )
        self.conv = nn.Conv2d(
            in_channels + 1,
            out_channels,
            kernel_size=kernel_size,
            in_chn + 1,
            out_chn,
            kernel_size=k_size,
            padding=pad_size,
            stride=stride,
        )
        self.conv_bn = nn.BatchNorm2d(out_channels)
        self.conv_bn = nn.BatchNorm2d(out_chn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
@ -189,28 +110,26 @@ class ConvBlockDownStandard(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        pad_size: int = 1,
        stride: int = 1,
        in_chn,
        out_chn,
        k_size=3,
        pad_size=1,
        stride=1,
    ):
        super(ConvBlockDownStandard, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            in_chn,
            out_chn,
            kernel_size=k_size,
            padding=pad_size,
            stride=stride,
        )
        self.conv_bn = nn.BatchNorm2d(out_channels)
        self.conv_bn = nn.BatchNorm2d(out_chn)

    def forward(self, x):
        x = F.max_pool2d(self.conv(x), 2, 2)
        return F.relu(self.conv_bn(x), inplace=True)


DownscalingLayer = Literal["ConvBlockDownStandard", "ConvBlockDownCoordF"]
        x = F.relu(self.conv_bn(x), inplace=True)
        return x


class ConvBlockUpF(nn.Module):
@ -222,10 +141,10 @@ class ConvBlockUpF(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        input_height: int,
        kernel_size: int = 3,
        in_chn: int,
        out_chn: int,
        ip_height: int,
        k_size: int = 3,
        pad_size: int = 1,
        up_mode: str = "bilinear",
        up_scale: Tuple[int, int] = (2, 2),
@ -235,18 +154,15 @@ class ConvBlockUpF(nn.Module):
        self.up_scale = up_scale
        self.up_mode = up_mode
        self.coords = nn.Parameter(
            torch.linspace(-1, 1, input_height * up_scale[0])[
            torch.linspace(-1, 1, ip_height * up_scale[0])[
                None, None, ..., None
            ],
            requires_grad=False,
        )
        self.conv = nn.Conv2d(
            in_channels + 1,
            out_channels,
            kernel_size=kernel_size,
            padding=pad_size,
            in_chn + 1, out_chn, kernel_size=k_size, padding=pad_size
        )
        self.conv_bn = nn.BatchNorm2d(out_channels)
        self.conv_bn = nn.BatchNorm2d(out_chn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        op = F.interpolate(
@ -273,9 +189,9 @@ class ConvBlockUpStandard(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        in_chn: int,
        out_chn: int,
        k_size: int = 3,
        pad_size: int = 1,
        up_mode: str = "bilinear",
        up_scale: Tuple[int, int] = (2, 2),
@ -284,12 +200,9 @@ class ConvBlockUpStandard(nn.Module):
        self.up_scale = up_scale
        self.up_mode = up_mode
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=pad_size,
            in_chn, out_chn, kernel_size=k_size, padding=pad_size
        )
        self.conv_bn = nn.BatchNorm2d(out_channels)
        self.conv_bn = nn.BatchNorm2d(out_chn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        op = F.interpolate(
@ -304,143 +217,3 @@ class ConvBlockUpStandard(nn.Module):
        op = self.conv(op)
        op = F.relu(self.conv_bn(op), inplace=True)
        return op


UpscalingLayer = Literal["ConvBlockUpStandard", "ConvBlockUpF"]


def build_downscaling_layer(
    in_channels: int,
    out_channels: int,
    input_height: int,
    layer_type: DownscalingLayer,
) -> nn.Module:
    if layer_type == "ConvBlockDownStandard":
        return ConvBlockDownStandard(
            in_channels=in_channels,
            out_channels=out_channels,
        )

    if layer_type == "ConvBlockDownCoordF":
        return ConvBlockDownCoordF(
            in_channels=in_channels,
            out_channels=out_channels,
            input_height=input_height,
        )

    raise ValueError(
        f"Invalid downscaling layer type {layer_type}. "
        f"Valid values: ConvBlockDownCoordF, ConvBlockDownStandard"
    )


class Encoder(nn.Module):
    def __init__(
        self,
        channels: Sequence[int] = (1, 32, 62, 128),
        input_height: int = 128,
        layer_type: Literal[
            "ConvBlockDownStandard", "ConvBlockDownCoordF"
        ] = "ConvBlockDownStandard",
    ):
        super().__init__()

        self.channels = channels
        self.input_height = input_height

        self.layers = nn.ModuleList(
            [
                build_downscaling_layer(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    input_height=input_height // (2**layer_num),
                    layer_type=layer_type,
                )
                for layer_num, (in_channels, out_channels) in enumerate(
                    pairwise(channels)
                )
            ]
        )
        self.depth = len(self.layers)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        outputs = []

        for layer in self.layers:
            x = layer(x)
            outputs.append(x)

        return outputs


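How pairwise() turns the channel list into per-layer (in, out) pairs for the Encoder above (a small sketch):

channels = (1, 32, 64, 128)
pairs = list(zip(channels[:-1], channels[1:]))
print(pairs)   # [(1, 32), (32, 64), (64, 128)] -> three downscaling layers
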
def build_upscaling_layer(
    in_channels: int,
    out_channels: int,
    input_height: int,
    layer_type: UpscalingLayer,
) -> nn.Module:
    if layer_type == "ConvBlockUpStandard":
        return ConvBlockUpStandard(
            in_channels=in_channels,
            out_channels=out_channels,
        )

    if layer_type == "ConvBlockUpF":
        return ConvBlockUpF(
            in_channels=in_channels,
            out_channels=out_channels,
            input_height=input_height,
        )

    raise ValueError(
        f"Invalid upscaling layer type {layer_type}. "
        f"Valid values: ConvBlockUpStandard, ConvBlockUpF"
    )


class Decoder(nn.Module):
    def __init__(
        self,
        channels: Sequence[int] = (256, 62, 32, 32),
        input_height: int = 128,
        layer_type: Literal[
            "ConvBlockUpStandard", "ConvBlockUpF"
        ] = "ConvBlockUpStandard",
    ):
        super().__init__()

        self.channels = channels
        self.input_height = input_height
        self.depth = len(self.channels) - 1

        self.layers = nn.ModuleList(
            [
                build_upscaling_layer(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    input_height=input_height
                    // (2 ** (self.depth - layer_num)),
                    layer_type=layer_type,
                )
                for layer_num, (in_channels, out_channels) in enumerate(
                    pairwise(channels)
                )
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        residuals: List[torch.Tensor],
    ) -> torch.Tensor:
        if len(residuals) != len(self.layers):
            raise ValueError(
                f"Incorrect number of residuals provided. "
                f"Expected {len(self.layers)} (matching the number of layers), "
                f"but got {len(residuals)}."
            )

        for layer, res in zip(self.layers, residuals[::-1]):
            x = layer(x + res)

        return x
@ -1,15 +0,0 @@
import sys
from typing import Iterable, List, Literal, Sequence

import torch
from torch import nn

from batdetect2.models.blocks import ConvBlockUpF, ConvBlockUpStandard

if sys.version_info >= (3, 10):
    from itertools import pairwise
else:

    def pairwise(iterable: Sequence) -> Iterable:
        for x, y in zip(iterable[:-1], iterable[1:]):
            yield x, y
139
batdetect2/models/detectors.py
Normal file
139
batdetect2/models/detectors.py
Normal file
@ -0,0 +1,139 @@
from typing import Type

import pytorch_lightning as L
import torch
import xarray as xr
from soundevent import data
from torch import nn, optim

from batdetect2.data.preprocessing import (
    preprocess_audio_clip,
    PreprocessingConfig,
)
from batdetect2.data.labels import ClassMapper
from batdetect2.models.feature_extractors import Net2DFast
from batdetect2.models.post_process import (
    PostprocessConfig,
    postprocess_model_outputs,
)
from batdetect2.models.typing import FeatureExtractorModel, ModelOutput
from batdetect2.train import losses
from batdetect2.train.dataset import TrainExample


class DetectorModel(L.LightningModule):
    def __init__(
        self,
        class_mapper: ClassMapper,
        feature_extractor_class: Type[FeatureExtractorModel] = Net2DFast,
        learning_rate: float = 1e-3,
        input_height: int = 128,
        num_features: int = 32,
        preprocessing_config: PreprocessingConfig = PreprocessingConfig(),
        postprocessing_config: PostprocessConfig = PostprocessConfig(),
    ):
        super().__init__()

        self.save_hyperparameters()

        self.preprocessing_config = preprocessing_config
        self.postprocessing_config = postprocessing_config
        self.class_mapper = class_mapper
        self.learning_rate = learning_rate
        self.input_height = input_height
        self.num_features = num_features
        self.num_classes = class_mapper.num_classes

        self.feature_extractor = feature_extractor_class(
            input_height=input_height,
            num_features=num_features,
        )

        self.classifier = nn.Conv2d(
            self.feature_extractor.num_features // 4,
            self.num_classes + 1,
            kernel_size=1,
            padding=0,
        )

        self.bbox = nn.Conv2d(
            self.feature_extractor.num_features // 4,
            2,
            kernel_size=1,
            padding=0,
        )

    def forward(self, spec: torch.Tensor) -> ModelOutput:  # type: ignore
        features = self.feature_extractor(spec)
        classification_logits = self.classifier(features)
        classification_probs = torch.softmax(classification_logits, dim=1)
        detection_probs = classification_probs[:, :-1].sum(dim=1, keepdim=True)
        return ModelOutput(
            detection_probs=detection_probs,
            size_preds=self.bbox(features),
            class_probs=classification_probs[:, :-1],
            features=features,
        )

    def compute_spectrogram(self, clip: data.Clip) -> xr.DataArray:
        return preprocess_audio_clip(
            clip,
            config=self.preprocessing_config,
        )

    def compute_clip_features(self, clip: data.Clip) -> torch.Tensor:
        spectrogram = self.compute_spectrogram(clip)
        return self.feature_extractor(
            torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
        )

    def compute_clip_predictions(self, clip: data.Clip) -> data.ClipPrediction:
        spectrogram = self.compute_spectrogram(clip)
        spec_tensor = (
            torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
        )
        outputs = self(spec_tensor)
        return postprocess_model_outputs(
            outputs,
            [clip],
            class_mapper=self.class_mapper,
            config=self.postprocessing_config,
        )[0]

    def compute_loss(
        self,
        outputs: ModelOutput,
        batch: TrainExample,
    ) -> torch.Tensor:
        detection_loss = losses.focal_loss(
            outputs.detection_probs,
            batch.detection_heatmap,
        )

        size_loss = losses.bbox_size_loss(
            outputs.size_preds,
            batch.size_heatmap,
        )

        valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
        classification_loss = losses.focal_loss(
            outputs.class_probs,
            batch.class_heatmap,
            valid_mask=valid_mask,
        )

        return detection_loss + size_loss + classification_loss

    def training_step(  # type: ignore
        self,
        batch: TrainExample,
    ):
        outputs = self.forward(batch.spec)
        loss = self.compute_loss(outputs, batch)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
        return [optimizer], [scheduler]
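A hedged sketch of the detection head above: with a softmax over num_classes + 1 channels (the last being background), the detection probability is the total foreground mass.

import torch

logits = torch.randn(1, 4, 8, 16)                    # 3 classes + background
probs = torch.softmax(logits, dim=1)
detection = probs[:, :-1].sum(dim=1, keepdim=True)
assert torch.allclose(detection, 1 - probs[:, -1:])  # foreground = 1 - background
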
@ -1,15 +0,0 @@
import sys
from typing import Iterable, List, Literal, Sequence

import torch
from torch import nn

from batdetect2.models.blocks import ConvBlockDownCoordF, ConvBlockDownStandard

if sys.version_info >= (3, 10):
    from itertools import pairwise
else:

    def pairwise(iterable: Sequence) -> Iterable:
        for x, y in zip(iterable[:-1], iterable[1:]):
            yield x, y
319
batdetect2/models/feature_extractors.py
Normal file
319
batdetect2/models/feature_extractors.py
Normal file
@ -0,0 +1,319 @@
from typing import Tuple

import torch
import torch.fft
import torch.nn.functional as F
from torch import nn

from batdetect2.models.blocks import (
    ConvBlockDownCoordF,
    ConvBlockDownStandard,
    ConvBlockUpF,
    ConvBlockUpStandard,
    SelfAttention,
)
from batdetect2.models.typing import FeatureExtractorModel

__all__ = [
    "Net2DFast",
    "Net2DFastNoAttn",
    "Net2DFastNoCoordConv",
]


class Net2DFast(FeatureExtractorModel):
    def __init__(
        self,
        num_features: int,
        input_height: int = 128,
    ):
        super().__init__()
        self.num_features = num_features
        self.input_height = input_height
        self.bottleneck_height = self.input_height // 32

        # encoder
        self.conv_dn_0 = ConvBlockDownCoordF(
            1,
            self.num_features // 4,
            self.input_height,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownCoordF(
            self.num_features // 4,
            self.num_features // 2,
            self.input_height // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownCoordF(
            self.num_features // 2,
            self.num_features,
            self.input_height // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(
            self.num_features,
            self.num_features * 2,
            3,
            padding=1,
        )
        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)

        # bottleneck
        self.conv_1d = nn.Conv2d(
            self.num_features * 2,
            self.num_features * 2,
            (self.input_height // 8, 1),
            padding=0,
        )
        self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
        self.att = SelfAttention(self.num_features * 2, self.num_features * 2)

        # decoder
        self.conv_up_2 = ConvBlockUpF(
            self.num_features * 2,
            self.num_features // 2,
            self.input_height // 8,
        )
        self.conv_up_3 = ConvBlockUpF(
            self.num_features // 2,
            self.num_features // 4,
            self.input_height // 4,
        )
        self.conv_up_4 = ConvBlockUpF(
            self.num_features // 4,
            self.num_features // 4,
            self.input_height // 2,
        )

        self.conv_op = nn.Conv2d(
            self.num_features // 4,
            self.num_features // 4,
            kernel_size=3,
            padding=1,
        )
        self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)

    def pad_adjust(self, spec: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        h, w = spec.shape[2:]
        h_pad = (32 - h % 32) % 32
        w_pad = (32 - w % 32) % 32
        return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # encoder
        spec, h_pad, w_pad = self.pad_adjust(spec)

        x1 = self.conv_dn_0(spec)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))

        # bottleneck
        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
        x = self.att(x)
        x = x.repeat([1, 1, self.bottleneck_height * 4, 1])

        # decoder
        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        # Restore original size
        if h_pad > 0:
            x = x[:, :, :-h_pad, :]

        if w_pad > 0:
            x = x[:, :, :, :-w_pad]

        return F.relu_(self.conv_op_bn(self.conv_op(x)))


class Net2DFastNoAttn(FeatureExtractorModel):
    def __init__(
        self,
        num_features: int,
        input_height: int = 128,
    ):
        super().__init__()

        self.num_features = num_features
        self.input_height = input_height
        self.bottleneck_height = self.input_height // 32

        self.conv_dn_0 = ConvBlockDownCoordF(
            1,
            self.num_features // 4,
            self.input_height,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownCoordF(
            self.num_features // 4,
            self.num_features // 2,
            self.input_height // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownCoordF(
            self.num_features // 2,
            self.num_features,
            self.input_height // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(
            self.num_features,
            self.num_features * 2,
            3,
            padding=1,
        )
        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)

        self.conv_1d = nn.Conv2d(
            self.num_features * 2,
            self.num_features * 2,
            (self.input_height // 8, 1),
            padding=0,
        )
        self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)

        self.conv_up_2 = ConvBlockUpF(
            self.num_features * 2,
            self.num_features // 2,
            self.input_height // 8,
        )
        self.conv_up_3 = ConvBlockUpF(
            self.num_features // 2,
            self.num_features // 4,
            self.input_height // 4,
        )
        self.conv_up_4 = ConvBlockUpF(
            self.num_features // 4,
            self.num_features // 4,
            self.input_height // 2,
        )

        self.conv_op = nn.Conv2d(
            self.num_features // 4,
            self.num_features // 4,
            kernel_size=3,
            padding=1,
        )
        self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        x1 = self.conv_dn_0(spec)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))

        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
        x = x.repeat([1, 1, self.bottleneck_height * 4, 1])

        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        return F.relu_(self.conv_op_bn(self.conv_op(x)))


class Net2DFastNoCoordConv(FeatureExtractorModel):
    def __init__(
        self,
        num_features: int,
        input_height: int = 128,
    ):
        super().__init__()

        self.num_features = num_features
        self.input_height = input_height
        self.bottleneck_height = self.input_height // 32

        self.conv_dn_0 = ConvBlockDownStandard(
            1,
            self.num_features // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownStandard(
            self.num_features // 4,
            self.num_features // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownStandard(
            self.num_features // 2,
            self.num_features,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(
            self.num_features,
            self.num_features * 2,
            3,
            padding=1,
        )
        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)

        self.conv_1d = nn.Conv2d(
            self.num_features * 2,
            self.num_features * 2,
            (self.input_height // 8, 1),
            padding=0,
        )
        self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)

        self.att = SelfAttention(self.num_features * 2, self.num_features * 2)

        self.conv_up_2 = ConvBlockUpStandard(
            self.num_features * 2,
            self.num_features // 2,
            self.input_height // 8,
        )
        self.conv_up_3 = ConvBlockUpStandard(
            self.num_features // 2,
            self.num_features // 4,
            self.input_height // 4,
        )
        self.conv_up_4 = ConvBlockUpStandard(
            self.num_features // 4,
            self.num_features // 4,
            self.input_height // 2,
        )

        self.conv_op = nn.Conv2d(
            self.num_features // 4,
            self.num_features // 4,
            kernel_size=3,
            padding=1,
        )
        self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        x1 = self.conv_dn_0(spec)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))

        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
        x = self.att(x)
        x = x.repeat([1, 1, self.bottleneck_height * 4, 1])

        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        return F.relu_(self.conv_op_bn(self.conv_op(x)))
@ -1,51 +0,0 @@
from typing import NamedTuple

import torch
from torch import nn

__all__ = ["ClassifierHead"]


class Output(NamedTuple):
    detection: torch.Tensor
    classification: torch.Tensor


class ClassifierHead(nn.Module):
    def __init__(self, num_classes: int, in_channels: int):
        super().__init__()

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.classifier = nn.Conv2d(
            self.in_channels,
            # Add one to account for the background class
            self.num_classes + 1,
            kernel_size=1,
            padding=0,
        )

    def forward(self, features: torch.Tensor) -> Output:
        logits = self.classifier(features)
        probs = torch.softmax(logits, dim=1)
        detection_probs = probs[:, :-1].sum(dim=1, keepdim=True)
        return Output(
            detection=detection_probs,
            classification=probs[:, :-1],
        )


class BBoxHead(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.bbox = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=2,
            kernel_size=1,
            padding=0,
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.bbox(features)
@ -1,20 +1,19 @@
"""Module for postprocessing model outputs."""

from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
from typing import Callable, List, Tuple, Union
from pydantic import BaseModel, Field

import numpy as np
import torch
from pydantic import Field
from soundevent import data
from torch import nn

from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.labels import ClassMapper
from batdetect2.models.typing import ModelOutput

__all__ = [
    "PostprocessConfig",
    "load_postprocess_config",
    "postprocess_model_outputs",
    "PostprocessConfig",
]

NMS_KERNEL_SIZE = 9
@ -22,7 +21,7 @@ DETECTION_THRESHOLD = 0.01
TOP_K_PER_SEC = 200


class PostprocessConfig(BaseConfig):
class PostprocessConfig(BaseModel):
    """Configuration for postprocessing model outputs."""

    nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
@ -32,29 +31,14 @@ class PostprocessConfig(BaseConfig):
    top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)


def load_postprocess_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> PostprocessConfig:
    return load_config(path, schema=PostprocessConfig, field=field)


class RawPrediction(NamedTuple):
    start_time: float
    end_time: float
    low_freq: float
    high_freq: float
    detection_score: float
    class_scores: Dict[str, float]
    features: np.ndarray
TagFunction = Callable[[int], List[data.Tag]]


def postprocess_model_outputs(
    outputs: ModelOutput,
    clips: List[data.Clip],
    classes: List[str],
    decoder: Callable[[str], List[data.Tag]],
    config: Optional[PostprocessConfig] = None,
    class_mapper: ClassMapper,
    config: PostprocessConfig,
) -> List[data.ClipPrediction]:
    """Postprocesses model outputs to generate clip predictions.

@ -84,9 +68,6 @@ def postprocess_model_outputs(
    ValueError
        If the number of predictions does not match the number of clips.
    """

    config = config or PostprocessConfig()

    num_predictions = len(outputs.detection_probs)

    if num_predictions == 0:
@ -127,8 +108,7 @@
            size_preds,
            class_probs,
            features,
            classes=classes,
            decoder=decoder,
            class_mapper=class_mapper,
            min_freq=config.min_freq,
            max_freq=config.max_freq,
            detection_threshold=config.detection_threshold,
@ -144,82 +124,6 @@
    return predictions


def compute_predictions_from_outputs(
    start: float,
    end: float,
    scores: torch.Tensor,
    y_pos: torch.Tensor,
    x_pos: torch.Tensor,
    size_preds: torch.Tensor,
    class_probs: torch.Tensor,
    features: torch.Tensor,
    classes: List[str],
    min_freq: int = 10000,
    max_freq: int = 120000,
    detection_threshold: float = DETECTION_THRESHOLD,
) -> List[RawPrediction]:
    _, freq_bins, time_bins = size_preds.shape

    sorted_indices = torch.argsort(x_pos)
    valid_indices = sorted_indices[
        scores[sorted_indices] > detection_threshold
    ]

    scores = scores[valid_indices]
    x_pos = x_pos[valid_indices]
    y_pos = y_pos[valid_indices]

    predictions: List[RawPrediction] = []
    for score, x, y in zip(scores, x_pos, y_pos):
        width, height = size_preds[:, y, x]
        class_prob = class_probs[:, y, x].detach().numpy()
        feats = features[:, y, x].detach().numpy()

        start_time = np.interp(
            x.item(),
            [0, time_bins],
            [start, end],
        )

        end_time = np.interp(
            x.item() + width.item(),
            [0, time_bins],
            [start, end],
        )

        low_freq = np.interp(
            y.item(),
            [0, freq_bins],
            [max_freq, min_freq],
        )

        high_freq = np.interp(
            y.item() - height.item(),
            [0, freq_bins],
            [max_freq, min_freq],
        )

        start_time, end_time = sorted([float(start_time), float(end_time)])
        low_freq, high_freq = sorted([float(low_freq), float(high_freq)])

        predictions.append(
            RawPrediction(
                start_time=start_time,
                end_time=end_time,
                low_freq=low_freq,
                high_freq=high_freq,
                detection_score=score.item(),
                features=feats,
                class_scores={
                    class_name: prob
                    for class_name, prob in zip(classes, class_prob)
                },
            )
        )

    return predictions


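A small sketch of the bin-to-unit mapping used above: np.interp linearly maps a peak's column index onto the clip's time span (numbers are illustrative):

import numpy as np

time_bins, start, end = 512, 0.0, 0.5
x = 128                                             # detection at column 128
print(np.interp(x, [0, time_bins], [start, end]))   # 0.125 (seconds)
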
def compute_sound_events_from_outputs(
    clip: data.Clip,
    scores: torch.Tensor,
@ -228,8 +132,7 @@ def compute_sound_events_from_outputs(
    size_preds: torch.Tensor,
    class_probs: torch.Tensor,
    features: torch.Tensor,
    classes: List[str],
    decoder: Callable[[str], List[data.Tag]],
    class_mapper: ClassMapper,
    min_freq: int = 10000,
    max_freq: int = 120000,
    detection_threshold: float = DETECTION_THRESHOLD,
@ -278,13 +181,12 @@
        predicted_tags: List[data.PredictedTag] = []

        for label_id, class_score in enumerate(class_prob):
            class_name = classes[label_id]
            corresponding_tags = decoder(class_name)
            corresponding_tags = class_mapper.inverse_transform(label_id)
            predicted_tags.extend(
                [
                    data.PredictedTag(
                        tag=tag,
                        score=max(min(class_score.item(), 1), 0),
                        score=class_score.item(),
                    )
                    for tag in corresponding_tags
                ]
@ -305,7 +207,7 @@
            ),
            features=[
                data.Feature(
                    term=data.term_from_key(f"batdetect2_{i}"),
                    name=f"batdetect2_{i}",
                    value=value.item(),
                )
                for i, value in enumerate(feature)
@ -315,7 +217,7 @@
        predictions.append(
            data.SoundEventPrediction(
                sound_event=sound_event,
                score=max(min(score.item(), 1), 0),
                score=score.item(),
                tags=predicted_tags,
            )
        )
@ -1,12 +1,12 @@
from abc import ABC, abstractmethod
from typing import NamedTuple, Tuple
from typing import NamedTuple

import torch
import torch.nn as nn

__all__ = [
    "ModelOutput",
    "BackboneModel",
    "FeatureExtractorModel",
]


@ -41,31 +41,16 @@ class ModelOutput(NamedTuple):
    """Tensor with intermediate features."""


class BackboneModel(ABC, nn.Module):
class FeatureExtractorModel(ABC, nn.Module):
    input_height: int
    """Height of the input spectrogram."""

    encoder_channels: Tuple[int, ...]
    """Tuple specifying the number of channels for each convolutional layer
    in the encoder. The length of this tuple determines the number of
    encoder layers."""

    decoder_channels: Tuple[int, ...]
    """Tuple specifying the number of channels for each convolutional layer
    in the decoder. The length of this tuple determines the number of
    decoder layers."""

    bottleneck_channels: int
    """Number of channels in the bottleneck layer, which connects the
    encoder and decoder."""

    out_channels: int
    """Number of channels in the final output feature map produced by the
    backbone model."""
    num_features: int
    """Dimension of the feature tensor."""

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Forward pass of the model."""
        """Forward pass of the encoder model."""


class DetectionModel(ABC, nn.Module):
@ -1,181 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import lightning as L
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from torch.optim.adam import Adam
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.evaluate.evaluate import match_predictions_and_annotations
|
||||
from batdetect2.models import (
|
||||
BBoxHead,
|
||||
ClassifierHead,
|
||||
ModelConfig,
|
||||
build_architecture,
|
||||
)
|
||||
from batdetect2.models.typing import ModelOutput
|
||||
from batdetect2.post_process import (
|
||||
PostprocessConfig,
|
||||
postprocess_model_outputs,
|
||||
)
|
||||
from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||
from batdetect2.train.losses import compute_loss
|
||||
from batdetect2.train.targets import (
|
||||
TargetConfig,
|
||||
build_decoder,
|
||||
build_encoder,
|
||||
get_class_names,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DetectorModel",
|
||||
]
|
||||
|
||||
|
||||
class ModuleConfig(BaseConfig):
|
||||
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||
architecture: ModelConfig = Field(default_factory=ModelConfig)
|
||||
preprocessing: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
postprocessing: PostprocessConfig = Field(
|
||||
default_factory=PostprocessConfig
|
||||
)
|
||||
|
||||
|
||||
class DetectorModel(L.LightningModule):
|
||||
config: ModuleConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[ModuleConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.config = config or ModuleConfig()
|
||||
self.save_hyperparameters()
|
||||
|
||||
self.backbone = build_architecture(self.config.architecture)
|
||||
|
||||
self.classifier = ClassifierHead(
|
||||
num_classes=len(self.config.targets.classes),
|
||||
in_channels=self.backbone.out_channels,
|
||||
)
|
||||
|
||||
self.bbox = BBoxHead(in_channels=self.backbone.out_channels)
|
||||
|
||||
conf = self.config.train.loss.classification
|
||||
self.class_weights = (
|
||||
torch.tensor(conf.class_weights) if conf.class_weights else None
|
||||
)
|
||||
|
||||
# Training targets
|
||||
self.class_names = get_class_names(self.config.targets.classes)
|
||||
self.encoder = build_encoder(
|
||||
self.config.targets.classes,
|
||||
replacement_rules=self.config.targets.replace,
|
||||
)
|
||||
self.decoder = build_decoder(self.config.targets.classes)
|
||||
|
||||
self.validation_predictions = []
|
||||
|
||||
self.example_input_array = torch.randn([1, 1, 128, 512])
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore
|
||||
features = self.backbone(spec)
|
||||
detection_probs, classification_probs = self.classifier(features)
|
||||
size_preds = self.bbox(features)
|
||||
return ModelOutput(
|
||||
detection_probs=detection_probs,
|
||||
size_preds=size_preds,
|
||||
class_probs=classification_probs,
|
||||
features=features,
|
||||
)
|
||||
|
||||
def training_step(self, batch: TrainExample):
|
||||
outputs = self.forward(batch.spec)
|
||||
losses = compute_loss(
|
||||
batch,
|
||||
outputs,
|
||||
conf=self.config.train.loss,
|
||||
class_weights=self.class_weights,
|
||||
)
|
||||
|
||||
self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
|
||||
self.log("train/loss/detection", losses.total, logger=True)
|
||||
self.log("train/loss/size", losses.total, logger=True)
|
||||
self.log("train/loss/classification", losses.total, logger=True)
|
||||
|
||||
return losses.total
|
||||
|
||||
def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
|
||||
outputs = self.forward(batch.spec)
|
||||
|
||||
losses = compute_loss(
|
||||
batch,
|
||||
outputs,
|
||||
conf=self.config.train.loss,
|
||||
class_weights=self.class_weights,
|
||||
)
|
||||
|
||||
self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
|
||||
self.log("val/loss/detection", losses.total, logger=True)
|
||||
self.log("val/loss/size", losses.total, logger=True)
|
||||
self.log("val/loss/classification", losses.total, logger=True)

        dataloaders = self.trainer.val_dataloaders
        assert isinstance(dataloaders, DataLoader)
        dataset = dataloaders.dataset
        assert isinstance(dataset, LabeledDataset)
        clip_annotation = dataset.get_clip_annotation(batch_idx)

        clip_prediction = postprocess_model_outputs(
            outputs,
            clips=[clip_annotation.clip],
            classes=self.class_names,
            decoder=self.decoder,
            config=self.config.postprocessing,
        )[0]

        matches = match_predictions_and_annotations(
            clip_annotation,
            clip_prediction,
        )

        self.validation_predictions.extend(matches)

    def on_validation_epoch_end(self) -> None:
        self.validation_predictions.clear()

    def configure_optimizers(self):
        conf = self.config.train.optimizer
        optimizer = Adam(self.parameters(), lr=conf.learning_rate)
        scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max)
        return [optimizer], [scheduler]

    def process_clip(
        self,
        clip: data.Clip,
        audio_dir: Optional[Path] = None,
    ) -> data.ClipPrediction:
        spec = preprocess_audio_clip(
            clip,
            config=self.config.preprocessing,
            audio_dir=audio_dir,
        )
        tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)
        outputs = self.forward(tensor)
        return postprocess_model_outputs(
            outputs,
            clips=[clip],
            classes=self.class_names,
            decoder=self.decoder,
            config=self.config.postprocessing,
        )[0]
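A minimal usage sketch for the module above (not part of the diff): it assumes the default ModuleConfig builds cleanly and reuses the dummy input shape from `example_input_array`.

    import torch

    model = DetectorModel()  # all-default configuration
    dummy_spec = torch.randn(1, 1, 128, 512)  # (batch, channel, freq, time)
    output = model(dummy_spec)
    print(output.detection_probs.shape)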
@@ -2,10 +2,10 @@

from typing import List, Optional, Tuple, Union, cast

import matplotlib.ticker as tick
import numpy as np
import torch
from matplotlib import axes, patches
import matplotlib.ticker as tick
from matplotlib import pyplot as plt

from batdetect2.detector.parameters import DEFAULT_PROCESSING_CONFIGURATIONS
@@ -102,6 +102,7 @@ def spectrogram(
    return ax


def spectrogram_with_detections(
    spec: Union[torch.Tensor, np.ndarray],
    dets: List[Annotation],
@@ -1,68 +0,0 @@
"""Module containing functions for preprocessing audio clips."""

from typing import Optional

import xarray as xr
from soundevent import data

from batdetect2.preprocess.audio import (
    AudioConfig,
    ResampleConfig,
    load_clip_audio,
)
from batdetect2.preprocess.config import (
    PreprocessingConfig,
    load_preprocessing_config,
)
from batdetect2.preprocess.spectrogram import (
    AmplitudeScaleConfig,
    FrequencyConfig,
    LogScaleConfig,
    PcenScaleConfig,
    Scales,
    SpecSizeConfig,
    SpectrogramConfig,
    STFTConfig,
    compute_spectrogram,
)

__all__ = [
    "AmplitudeScaleConfig",
    "AudioConfig",
    "FrequencyConfig",
    "LogScaleConfig",
    "PcenScaleConfig",
    "PreprocessingConfig",
    "ResampleConfig",
    "STFTConfig",
    "Scales",
    "SpecSizeConfig",
    "SpectrogramConfig",
    "load_preprocessing_config",
    "preprocess_audio_clip",
]


def preprocess_audio_clip(
    clip: data.Clip,
    config: Optional[PreprocessingConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
    """Preprocesses audio clip to generate spectrogram.

    Parameters
    ----------
    clip
        The audio clip to preprocess.
    config
        Configuration for preprocessing.

    Returns
    -------
    xr.DataArray
        Preprocessed spectrogram.

    """
    config = config or PreprocessingConfig()
    wav = load_clip_audio(clip, config=config.audio, audio_dir=audio_dir)
    return compute_spectrogram(wav, config=config.spectrogram)
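A quick sketch of how the removed helper was used (the WAV path is hypothetical):

    from soundevent import data

    recording = data.Recording.from_file("example_data/audio/recording.wav")
    clip = data.Clip(
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )
    spec = preprocess_audio_clip(clip)  # defaults: resample + PCEN spectrogram
    print(spec.dims, spec.shape)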
@@ -1,61 +0,0 @@
import numpy as np


def extend_width(
    array: np.ndarray,
    extra: int,
    axis: int = -1,
    value: float = 0,
) -> np.ndarray:
    dims = len(array.shape)
    axis = axis % dims
    pad = [[0, 0] if index != axis else [0, extra] for index in range(dims)]
    return np.pad(
        array,
        pad,
        mode="constant",
        constant_values=value,
    )


def make_width_divisible(
    array: np.ndarray,
    factor: int,
    axis: int = -1,
    value: float = 0,
) -> np.ndarray:
    width = array.shape[axis]

    if width % factor == 0:
        return array

    extra = (-width) % factor
    return extend_width(array, extra, axis=axis, value=value)


def adjust_width(
    array: np.ndarray,
    width: int,
    axis: int = -1,
    value: float = 0,
) -> np.ndarray:
    dims = len(array.shape)
    axis = axis % dims
    current_width = array.shape[axis]

    if current_width == width:
        return array

    if current_width < width:
        return extend_width(
            array,
            extra=width - current_width,
            axis=axis,
            value=value,
        )

    slices = [
        slice(None, None) if index != axis else slice(None, width)
        for index in range(dims)
    ]
    return array[tuple(slices)]
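The padding/cropping semantics, checked on a toy array (pure NumPy, runs standalone with the functions above):

    import numpy as np

    a = np.ones((2, 5))
    assert adjust_width(a, 8).shape == (2, 8)  # zero-padded on the right
    assert adjust_width(a, 3).shape == (2, 3)  # cropped on the right
    assert make_width_divisible(a, 4).shape == (2, 8)  # 5 -> next multiple of 4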
@@ -1,199 +0,0 @@
from typing import Optional

import numpy as np
import xarray as xr
from numpy.typing import DTypeLike
from pydantic import Field
from scipy.signal import resample, resample_poly
from soundevent import arrays, audio, data
from soundevent.arrays import operations as ops

from batdetect2.configs import BaseConfig

TARGET_SAMPLERATE_HZ = 256_000
SCALE_RAW_AUDIO = False
DEFAULT_DURATION = None


class ResampleConfig(BaseConfig):
    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
    mode: str = "poly"


class AudioConfig(BaseConfig):
    resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
    scale: bool = SCALE_RAW_AUDIO
    center: bool = True
    duration: Optional[float] = DEFAULT_DURATION


def load_file_audio(
    path: data.PathLike,
    config: Optional[AudioConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    recording = data.Recording.from_file(path)
    return load_recording_audio(
        recording,
        config=config,
        dtype=dtype,
        audio_dir=audio_dir,
    )


def load_recording_audio(
    recording: data.Recording,
    config: Optional[AudioConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    clip = data.Clip(
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )
    return load_clip_audio(
        clip,
        config=config,
        dtype=dtype,
        audio_dir=audio_dir,
    )


def load_clip_audio(
    clip: data.Clip,
    config: Optional[AudioConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    config = config or AudioConfig()

    wav = (
        audio.load_clip(clip, audio_dir=audio_dir).sel(channel=0).astype(dtype)
    )

    if config.duration is not None:
        wav = adjust_audio_duration(wav, duration=config.duration)

    if config.resample:
        wav = resample_audio(
            wav,
            samplerate=config.resample.samplerate,
            dtype=dtype,
        )

    if config.center:
        wav = ops.center(wav)

    if config.scale:
        wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))

    return wav.astype(dtype)


def adjust_audio_duration(
    wave: xr.DataArray,
    duration: float,
) -> xr.DataArray:
    start_time, end_time = arrays.get_dim_range(wave, dim="time")
    current_duration = end_time - start_time

    if current_duration == duration:
        return wave

    if current_duration > duration:
        return arrays.crop_dim(
            wave,
            dim="time",
            start=start_time,
            stop=start_time + duration,
        )

    return arrays.extend_dim(
        wave,
        dim="time",
        start=start_time,
        stop=start_time + duration,
    )


def resample_audio(
    wav: xr.DataArray,
    samplerate: int = TARGET_SAMPLERATE_HZ,
    mode: str = "poly",
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    if "time" not in wav.dims:
        raise ValueError("Audio must have a time dimension")

    time_axis: int = wav.get_axis_num("time")  # type: ignore
    step = arrays.get_dim_step(wav, dim="time")
    original_samplerate = int(1 / step)

    if original_samplerate == samplerate:
        return wav.astype(dtype)

    if mode == "poly":
        resampled = resample_audio_poly(
            wav,
            sr_orig=original_samplerate,
            sr_new=samplerate,
            axis=time_axis,
        )
    elif mode == "fourier":
        resampled = resample_audio_fourier(
            wav,
            sr_orig=original_samplerate,
            sr_new=samplerate,
            axis=time_axis,
        )
    else:
        raise NotImplementedError(f"Resampling mode '{mode}' not implemented")

    start, stop = arrays.get_dim_range(wav, dim="time")
    times = np.linspace(
        start,
        stop + step,
        len(resampled),
        endpoint=False,
        dtype=dtype,
    )

    return xr.DataArray(
        data=resampled.astype(dtype),
        dims=wav.dims,
        coords={
            **wav.coords,
            "time": arrays.create_time_dim_from_array(
                times,
                samplerate=samplerate,
            ),
        },
        attrs=wav.attrs,
    )


def resample_audio_poly(
    array: xr.DataArray,
    sr_orig: int,
    sr_new: int,
    axis: int = -1,
) -> np.ndarray:
    gcd = np.gcd(sr_orig, sr_new)
    return resample_poly(
        array.values,
        sr_new // gcd,
        sr_orig // gcd,
        axis=axis,
    )


def resample_audio_fourier(
    array: xr.DataArray,
    sr_orig: int,
    sr_new: int,
    axis: int = -1,
) -> np.ndarray:
    ratio = sr_new / sr_orig
    return resample(array, int(array.shape[axis] * ratio), axis=axis)  # type: ignore
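The polyphase path resamples by the rational factor sr_new/sr_orig after reducing by their greatest common divisor, e.g. going from 441 kHz to the 256 kHz target:

    import numpy as np

    sr_orig, sr_new = 441_000, 256_000
    gcd = np.gcd(sr_orig, sr_new)  # 1000
    print(sr_new // gcd, sr_orig // gcd)  # up=256, down=441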
@@ -1,31 +0,0 @@
from typing import Optional

from pydantic import Field
from soundevent.data import PathLike

from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import (
    AudioConfig,
)
from batdetect2.preprocess.spectrogram import (
    SpectrogramConfig,
)

__all__ = [
    "PreprocessingConfig",
    "load_preprocessing_config",
]


class PreprocessingConfig(BaseConfig):
    """Configuration for preprocessing data."""

    audio: AudioConfig = Field(default_factory=AudioConfig)
    spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)


def load_preprocessing_config(
    path: PathLike,
    field: Optional[str] = None,
) -> PreprocessingConfig:
    return load_config(path, schema=PreprocessingConfig, field=field)
@@ -1,323 +0,0 @@
from typing import Literal, Optional, Union

import librosa
import librosa.core.spectrum
import numpy as np
import xarray as xr
from numpy.typing import DTypeLike
from pydantic import Field
from soundevent import arrays, audio
from soundevent.arrays import operations as ops

from batdetect2.configs import BaseConfig


class STFTConfig(BaseConfig):
    window_duration: float = Field(default=0.002, gt=0)
    window_overlap: float = Field(default=0.75, ge=0, lt=1)
    window_fn: str = "hann"


class FrequencyConfig(BaseConfig):
    max_freq: int = Field(default=120_000, gt=0)
    min_freq: int = Field(default=10_000, gt=0)


class SpecSizeConfig(BaseConfig):
    height: int = 128
    """Height of the spectrogram in pixels. This value determines the
    number of frequency bands and corresponds to the vertical dimension
    of the spectrogram."""

    resize_factor: Optional[float] = 0.5
    """Factor by which to resize the spectrogram along the time axis.
    A value of 0.5 reduces the temporal dimension by half, while a
    value of 2.0 doubles it. If None, no resizing is performed."""


class LogScaleConfig(BaseConfig):
    name: Literal["log"] = "log"


class PcenScaleConfig(BaseConfig):
    name: Literal["pcen"] = "pcen"
    time_constant: float = 0.4
    hop_length: int = 512
    gain: float = 0.98
    bias: float = 2
    power: float = 0.5


class AmplitudeScaleConfig(BaseConfig):
    name: Literal["amplitude"] = "amplitude"


Scales = Union[LogScaleConfig, PcenScaleConfig, AmplitudeScaleConfig]


class SpectrogramConfig(BaseConfig):
    stft: STFTConfig = Field(default_factory=STFTConfig)
    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
    scale: Scales = Field(
        default_factory=PcenScaleConfig,
        discriminator="name",
    )
    size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
    denoise: bool = True
    max_scale: bool = False


def compute_spectrogram(
    wav: xr.DataArray,
    config: Optional[SpectrogramConfig] = None,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    config = config or SpectrogramConfig()

    spec = stft(
        wav,
        window_duration=config.stft.window_duration,
        window_overlap=config.stft.window_overlap,
        window_fn=config.stft.window_fn,
        dtype=dtype,
    )

    spec = crop_spectrogram_frequencies(
        spec,
        min_freq=config.frequencies.min_freq,
        max_freq=config.frequencies.max_freq,
    )

    spec = scale_spectrogram(spec, scale=config.scale)

    if config.denoise:
        spec = denoise_spectrogram(spec)

    if config.size:
        spec = resize_spectrogram(
            spec,
            height=config.size.height,
            resize_factor=config.size.resize_factor,
        )

    if config.max_scale:
        spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))

    return spec.astype(dtype)


def crop_spectrogram_frequencies(
    spec: xr.DataArray,
    min_freq: int = 10_000,
    max_freq: int = 120_000,
) -> xr.DataArray:
    return arrays.crop_dim(
        spec,
        dim="frequency",
        start=min_freq,
        stop=max_freq,
    ).astype(spec.dtype)


def stft(
    wave: xr.DataArray,
    window_duration: float,
    window_overlap: float,
    window_fn: str = "hann",
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    start_time, end_time = arrays.get_dim_range(wave, dim="time")
    step = arrays.get_dim_step(wave, dim="time")
    sampling_rate = 1 / step

    nfft = int(window_duration * sampling_rate)
    noverlap = int(window_overlap * nfft)
    hop_len = nfft - noverlap
    hop_duration = hop_len / sampling_rate

    spec, _ = librosa.core.spectrum._spectrogram(
        y=wave.data.astype(dtype),
        power=1,
        n_fft=nfft,
        hop_length=nfft - noverlap,
        center=False,
        window=window_fn,
    )

    return xr.DataArray(
        data=spec.astype(dtype),
        dims=["frequency", "time"],
        coords={
            "frequency": arrays.create_frequency_dim_from_array(
                np.linspace(
                    0,
                    sampling_rate / 2,
                    spec.shape[0],
                    endpoint=False,
                    dtype=dtype,
                ),
                step=sampling_rate / nfft,
            ),
            "time": arrays.create_time_dim_from_array(
                np.linspace(
                    start_time,
                    end_time - (window_duration - hop_duration),
                    spec.shape[1],
                    endpoint=False,
                    dtype=dtype,
                ),
                step=hop_duration,
            ),
        },
        attrs={
            **wave.attrs,
            "original_samplerate": sampling_rate,
            "nfft": nfft,
            "noverlap": noverlap,
        },
    )


def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
    return xr.DataArray(
        data=(spec - spec.mean("time")).clip(0),
        dims=spec.dims,
        coords=spec.coords,
        attrs=spec.attrs,
    )


def scale_spectrogram(
    spec: xr.DataArray,
    scale: Scales,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    if scale.name == "log":
        return scale_log(spec, dtype=dtype)

    if scale.name == "pcen":
        return scale_pcen(
            spec,
            time_constant=scale.time_constant,
            hop_length=scale.hop_length,
            gain=scale.gain,
            power=scale.power,
            bias=scale.bias,
        )

    return spec


def scale_pcen(
    spec: xr.DataArray,
    time_constant: float = 0.4,
    hop_length: int = 512,
    gain: float = 0.98,
    bias: float = 2,
    power: float = 0.5,
) -> xr.DataArray:
    samplerate = spec.attrs["original_samplerate"]
    t_frames = time_constant * samplerate / (float(hop_length) * 10)
    smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
    return audio.pcen(
        spec * (2**31),
        smooth=smoothing_constant,
        gain=gain,
        bias=bias,
        power=power,
    ).astype(spec.dtype)


def scale_log(
    spec: xr.DataArray,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    samplerate = spec.attrs["original_samplerate"]
    nfft = spec.attrs["nfft"]
    log_scaling = 2 / (samplerate * (np.abs(np.hanning(nfft)) ** 2).sum())
    return xr.DataArray(
        data=np.log1p(log_scaling * spec).astype(dtype),
        dims=spec.dims,
        coords=spec.coords,
        attrs=spec.attrs,
    )


def resize_spectrogram(
    spec: xr.DataArray,
    height: int = 128,
    resize_factor: Optional[float] = 0.5,
) -> xr.DataArray:
    resize_factor = resize_factor or 1
    current_width = spec.sizes["time"]
    return ops.resize(
        spec,
        time=int(resize_factor * current_width),
        frequency=height,
        dtype=np.float32,
    )


def adjust_spectrogram_width(
    spec: xr.DataArray,
    divide_factor: int = 32,
    time_period: float = 0.001,
) -> xr.DataArray:
    time_width = spec.sizes["time"]

    if time_width % divide_factor == 0:
        return spec

    target_size = int(
        np.ceil(spec.sizes["time"] / divide_factor) * divide_factor
    )
    extra_duration = (target_size - time_width) * time_period
    _, stop = arrays.get_dim_range(spec, dim="time")
    resized = ops.extend_dim(
        spec,
        dim="time",
        stop=stop + extra_duration,
    )
    return resized


def duration_to_spec_width(
    duration: float,
    samplerate: int,
    window_duration: float,
    window_overlap: float,
) -> int:
    samples = int(duration * samplerate)
    fft_len = int(window_duration * samplerate)
    fft_overlap = int(window_overlap * fft_len)
    hop_len = fft_len - fft_overlap
    width = (samples - fft_len + hop_len) / hop_len
    return int(np.floor(width))
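A worked example of the width arithmetic: 0.5 s of audio at 256 kHz with a 2 ms window and 75 % overlap gives a hop of 128 samples.

    width = duration_to_spec_width(
        duration=0.5,
        samplerate=256_000,
        window_duration=0.002,
        window_overlap=0.75,
    )
    # samples=128000, fft_len=512, hop=128
    # (128000 - 512 + 128) / 128 = 997
    assert width == 997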


def spec_width_to_samples(
    width: int,
    samplerate: int,
    window_duration: float,
    window_overlap: float,
) -> int:
    fft_len = int(window_duration * samplerate)
    fft_overlap = int(window_overlap * fft_len)
    hop_len = fft_len - fft_overlap
    return width * hop_len + fft_len - hop_len


def get_spectrogram_resolution(
    config: SpectrogramConfig,
) -> tuple[float, float]:
    max_freq = config.frequencies.max_freq
    min_freq = config.frequencies.min_freq
    assert config.size is not None

    spec_height = config.size.height
    resize_factor = config.size.resize_factor or 1
    freq_bin_width = (max_freq - min_freq) / spec_height
    hop_duration = config.stft.window_duration * (
        1 - config.stft.window_overlap
    )
    return freq_bin_width, hop_duration / resize_factor
@@ -1,76 +0,0 @@
from typing import Union

import numpy as np
import torch
from torch.nn import functional as F


def extend_width(
    array: Union[np.ndarray, torch.Tensor],
    extra: int,
    axis: int = -1,
    value: float = 0,
) -> torch.Tensor:
    if not isinstance(array, torch.Tensor):
        array = torch.Tensor(array)

    dims = len(array.shape)
    axis = axis % dims
    pad = [
        [0, 0] if index != axis else [0, extra]
        for index in range(axis, dims)[::-1]
    ]
    return F.pad(
        array,
        [x for y in pad for x in y],
        value=value,
    )


def make_width_divisible(
    array: Union[np.ndarray, torch.Tensor],
    factor: int,
    axis: int = -1,
    value: float = 0,
) -> torch.Tensor:
    if not isinstance(array, torch.Tensor):
        array = torch.Tensor(array)

    width = array.shape[axis]

    if width % factor == 0:
        return array

    extra = (-width) % factor
    return extend_width(array, extra, axis=axis, value=value)


def adjust_width(
    array: Union[np.ndarray, torch.Tensor],
    width: int,
    axis: int = -1,
    value: float = 0,
) -> torch.Tensor:
    if not isinstance(array, torch.Tensor):
        array = torch.Tensor(array)

    dims = len(array.shape)
    axis = axis % dims
    current_width = array.shape[axis]

    if current_width == width:
        return array

    if current_width < width:
        return extend_width(
            array,
            extra=width - current_width,
            axis=axis,
            value=value,
        )

    slices = [
        slice(None, None) if index != axis else slice(None, width)
        for index in range(dims)
    ]
    return array[tuple(slices)]
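Same semantics as the NumPy helpers earlier, checked on a tensor:

    import torch

    t = torch.ones(1, 128, 500)
    out = make_width_divisible(t, 32)
    assert out.shape == (1, 128, 512)  # padded from 500 to the next multiple of 32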
@@ -1,88 +0,0 @@
from inspect import getmembers
from typing import Optional

from pydantic import BaseModel
from soundevent import data, terms

__all__ = [
    "call_type",
    "individual",
    "get_term_from_info",
    "get_tag_from_info",
    "TermInfo",
    "TagInfo",
]


class TermInfo(BaseModel):
    label: Optional[str]
    name: Optional[str]
    uri: Optional[str]


class TagInfo(BaseModel):
    value: str
    term: Optional[TermInfo] = None
    key: Optional[str] = None
    label: Optional[str] = None


call_type = data.Term(
    name="soundevent:call_type",
    label="Call Type",
    definition="A broad categorization of animal vocalizations based on their intended function or purpose (e.g., social, distress, mating, territorial, echolocation).",
)

individual = data.Term(
    name="soundevent:individual",
    label="Individual",
    definition="An id for an individual animal. In the context of bioacoustic annotation, this term is used to label vocalizations that are attributed to a specific individual.",
)


ALL_TERMS = [
    # getmembers returns (name, object) pairs, so keep only the Term objects
    *[term for _, term in getmembers(terms, lambda x: isinstance(x, data.Term))],
    call_type,
    individual,
]


def get_term_from_info(term_info: TermInfo) -> data.Term:
    for term in ALL_TERMS:
        if term_info.name and term_info.name == term.name:
            return term

        if term_info.label and term_info.label == term.label:
            return term

        if term_info.uri and term_info.uri == term.uri:
            return term

    if term_info.name is None:
        if term_info.label is None:
            raise ValueError("At least one of name or label must be provided.")

        term_info.name = (
            f"soundevent:{term_info.label.lower().replace(' ', '_')}"
        )

    if term_info.label is None:
        term_info.label = term_info.name

    return data.Term(
        name=term_info.name,
        label=term_info.label,
        uri=term_info.uri,
        definition="Unknown",
    )


def get_tag_from_info(tag_info: TagInfo) -> data.Tag:
    if tag_info.term:
        term = get_term_from_info(tag_info.term)
    elif tag_info.key:
        term = data.term_from_key(tag_info.key)
    else:
        raise ValueError("Either term or key must be provided in tag info.")

    return data.Tag(term=term, value=tag_info.value)
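A sketch of resolving a tag description into a soundevent Tag. The "Scientific Taxon Name" label is an assumption; an unmatched label simply falls through to a constructed soundevent:* term, so the sketch works either way.

    tag_info = TagInfo(
        value="Myotis myotis",
        term=TermInfo(name=None, label="Scientific Taxon Name", uri=None),
    )
    tag = get_tag_from_info(tag_info)
    print(tag.term.name, tag.value)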
@@ -1,48 +0,0 @@
from batdetect2.train.augmentations import (
    AugmentationsConfig,
    add_echo,
    augment_example,
    load_agumentation_config,
    mask_frequency,
    mask_time,
    mix_examples,
    scale_volume,
    select_subclip,
    warp_spectrogram,
)
from batdetect2.train.config import TrainingConfig, load_train_config
from batdetect2.train.dataset import (
    LabeledDataset,
    SubclipConfig,
    TrainExample,
)
from batdetect2.train.labels import LabelConfig, load_label_config
from batdetect2.train.preprocess import preprocess_annotations
from batdetect2.train.targets import TargetConfig, load_target_config
from batdetect2.train.train import TrainerConfig, load_trainer_config, train

__all__ = [
    "AugmentationsConfig",
    "LabelConfig",
    "LabeledDataset",
    "SubclipConfig",
    "TargetConfig",
    "TrainExample",
    "TrainerConfig",
    "TrainingConfig",
    "add_echo",
    "augment_example",
    "load_agumentation_config",
    "load_label_config",
    "load_target_config",
    "load_train_config",
    "load_trainer_config",
    "mask_frequency",
    "mask_time",
    "mix_examples",
    "preprocess_annotations",
    "scale_volume",
    "select_subclip",
    "train",
    "warp_spectrogram",
]
941 batdetect2/train/audio_dataloader.py Normal file
@@ -0,0 +1,941 @@
"""Functions and dataloaders for training and testing the model."""

import copy
from typing import List, Optional, Tuple

import librosa
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio

import batdetect2.utils.audio_utils as au
from batdetect2.types import (
    Annotation,
    AudioLoaderAnnotationGroup,
    AudioLoaderParameters,
    FileAnnotation,
)


def generate_gt_heatmaps(
    spec_op_shape: Tuple[int, int],
    sampling_rate: float,
    ann: AudioLoaderAnnotationGroup,
    class_names: List[str],
    fft_win_length: float,
    fft_overlap: float,
    max_freq: float,
    min_freq: float,
    resize_factor: float,
    target_sigma: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AudioLoaderAnnotationGroup]:
    """Generate ground truth heatmaps from annotations.

    Parameters
    ----------
    spec_op_shape : Tuple[int, int]
        Shape of the input spectrogram.
    sampling_rate : float
        Sampling rate of the input audio in Hz.
    ann : AudioLoaderAnnotationGroup
        Dictionary containing the annotation information.
    class_names : List[str]
        Names of the classes to generate class heatmaps for.
    fft_win_length, fft_overlap, max_freq, min_freq, resize_factor,
    target_sigma
        Parameters controlling the generation of the heatmaps.

    Returns
    -------
    y_2d_det : np.ndarray
        2D heatmap of the presence of an event.
    y_2d_size : np.ndarray
        2D heatmap of the size of the bounding box associated to event.
    y_2d_classes : np.ndarray
        3D array containing the ground-truth class probabilities for each
        pixel.
    ann_aug : AudioLoaderAnnotationGroup
        A dictionary containing the annotation information of the
        annotations that are within the input spectrogram, augmented with
        the x and y indices of their pixel location in the input spectrogram.
    """
    # spec may be resized on input into the network
    num_classes = len(class_names)
    op_height = spec_op_shape[0]
    op_width = spec_op_shape[1]
    freq_per_bin = (max_freq - min_freq) / op_height

    # start and end times
    x_pos_start = au.time_to_x_coords(
        ann["start_times"],
        sampling_rate,
        fft_win_length,
        fft_overlap,
    )
    x_pos_start = (resize_factor * x_pos_start).astype(np.int32)
    x_pos_end = au.time_to_x_coords(
        ann["end_times"],
        sampling_rate,
        fft_win_length,
        fft_overlap,
    )
    x_pos_end = (resize_factor * x_pos_end).astype(np.int32)

    # location on y axis i.e. frequency
    y_pos_low = (ann["low_freqs"] - min_freq) / freq_per_bin
    y_pos_low = (op_height - y_pos_low).astype(np.int32)
    y_pos_high = (ann["high_freqs"] - min_freq) / freq_per_bin
    y_pos_high = (op_height - y_pos_high).astype(np.int32)
    bb_widths = x_pos_end - x_pos_start
    bb_heights = y_pos_low - y_pos_high

    # Only include annotations that are within the input spectrogram
    valid_inds = np.where(
        (x_pos_start >= 0)
        & (x_pos_start < op_width)
        & (y_pos_low >= 0)
        & (y_pos_low < (op_height - 1))
    )[0]

    ann_aug: AudioLoaderAnnotationGroup = {
        **ann,
        "start_times": ann["start_times"][valid_inds],
        "end_times": ann["end_times"][valid_inds],
        "high_freqs": ann["high_freqs"][valid_inds],
        "low_freqs": ann["low_freqs"][valid_inds],
        "class_ids": ann["class_ids"][valid_inds],
        "individual_ids": ann["individual_ids"][valid_inds],
        "x_inds": x_pos_start[valid_inds],
        "y_inds": y_pos_low[valid_inds],
    }

    # if the number of calls is only 1, then it is unique
    # TODO would be better if we found these unique calls at the merging stage
    if len(ann_aug["individual_ids"]) == 1:
        ann_aug["individual_ids"][0] = 0

    y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
    y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)

    # num classes and "background" class
    y_2d_classes: np.ndarray = np.zeros(
        (num_classes + 1, op_height, op_width), dtype=np.float32
    )

    # create 2D ground truth heatmaps
    for ii in valid_inds:
        draw_gaussian(
            y_2d_det[0, :],
            (x_pos_start[ii], y_pos_low[ii]),
            target_sigma,
        )
        y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
        y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]

        cls_id = ann["class_ids"][ii]
        if cls_id > -1:
            draw_gaussian(
                y_2d_classes[cls_id, :],
                (x_pos_start[ii], y_pos_low[ii]),
                target_sigma,
            )

    # be careful: this will have 1.0 in places where there is an event but
    # the gt class is unknown; this will be masked in training anyway
    y_2d_classes[num_classes, :] = 1.0 - y_2d_classes.sum(0)
    y_2d_classes = y_2d_classes / y_2d_classes.sum(0)[np.newaxis, ...]
    y_2d_classes[np.isnan(y_2d_classes)] = 0.0

    return y_2d_det, y_2d_size, y_2d_classes, ann_aug


def draw_gaussian(
    heatmap: np.ndarray,
    center: Tuple[int, int],
    sigmax: float,
    sigmay: Optional[float] = None,
) -> bool:
    """Draw a 2D gaussian into the heatmap.

    If the gaussian center is outside the heatmap, then the gaussian is not
    drawn.

    Parameters
    ----------
    heatmap : np.ndarray
        The heatmap to draw into. Should be of shape (height, width).
    center : Tuple[int, int]
        The center of the gaussian in (x, y) format.
    sigmax : float
        The standard deviation of the gaussian in the x direction.
    sigmay : Optional[float], optional
        The standard deviation of the gaussian in the y direction. If None,
        then sigmay = sigmax, by default None.

    Returns
    -------
    bool
        True if the gaussian was drawn, False if it was not (because
        the center was outside the heatmap).
    """
    # center is (x, y); this edits the heatmap in place
    if sigmay is None:
        sigmay = sigmax
    tmp_size = np.maximum(sigmax, sigmay) * 3
    mu_x = int(center[0] + 0.5)
    mu_y = int(center[1] + 0.5)
    # heatmap is (height, width)
    height, width = heatmap.shape[0], heatmap.shape[1]
    ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
    br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]

    if ul[0] >= width or ul[1] >= height or br[0] < 0 or br[1] < 0:
        return False

    size = 2 * tmp_size + 1
    x = np.arange(0, size, 1, np.float32)
    y = x[:, np.newaxis]
    x0 = y0 = size // 2
    # g = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
    g = np.exp(
        -((x - x0) ** 2) / (2 * sigmax**2) - ((y - y0) ** 2) / (2 * sigmay**2)
    )
    g_x = max(0, -ul[0]), min(br[0], width) - ul[0]
    g_y = max(0, -ul[1]), min(br[1], height) - ul[1]
    img_x = max(0, ul[0]), min(br[0], width)
    img_y = max(0, ul[1]), min(br[1], height)
    heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]] = np.maximum(
        heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]],
        g[g_y[0] : g_y[1], g_x[0] : g_x[1]],
    )
    return True
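A tiny check of the stamping behaviour (pure NumPy): one gaussian with peak 1.0 lands at the requested (x, y) position.

    import numpy as np

    hm = np.zeros((64, 128), dtype=np.float32)  # (height, width)
    drew = draw_gaussian(hm, center=(40, 10), sigmax=2.0)
    assert drew and hm[10, 40] == 1.0  # peak at row 10, column 40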


def pad_aray(ip_array: np.ndarray, pad_size: int) -> np.ndarray:
    """Pad array with -1s."""
    return np.hstack((ip_array, np.ones(pad_size, dtype=np.int32) * -1))


def warp_spec_aug(
    spec: torch.Tensor,
    ann: AudioLoaderAnnotationGroup,
    stretch_squeeze_delta: float,
) -> torch.Tensor:
    """Warp spectrogram by randomly stretching and squeezing.

    Parameters
    ----------
    spec: torch.Tensor
        Spectrogram to warp.
    ann: AudioLoaderAnnotationGroup
        Annotation group for the spectrogram. Must be provided to sync
        the start and stop times with the spectrogram after warping.
    stretch_squeeze_delta: float
        Maximum amount to stretch or squeeze the spectrogram.

    Returns
    -------
    torch.Tensor
        Warped spectrogram.

    Notes
    -----
    This function modifies the annotation group in place.
    """
    # Augment spectrogram by randomly stretching and squeezing.
    # NOTE: this also changes the start and stop times in place.
    delta = stretch_squeeze_delta
    op_size = (spec.shape[1], spec.shape[2])
    resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0
    resize_amt = int(spec.shape[2] * resize_fract_r)

    if resize_amt >= spec.shape[2]:
        spec_r = torch.cat(
            (
                spec,
                torch.zeros(
                    (1, spec.shape[1], resize_amt - spec.shape[2]),
                    dtype=spec.dtype,
                ),
            ),
            dim=2,
        )
    else:
        spec_r = spec[:, :, :resize_amt]

    # Resize the spectrogram
    spec = F.interpolate(
        spec_r.unsqueeze(0),
        size=op_size,
        mode="bilinear",
        align_corners=False,
    ).squeeze(0)

    # Update the start and stop times
    ann["start_times"] *= 1.0 / resize_fract_r
    ann["end_times"] *= 1.0 / resize_fract_r

    return spec


def mask_time_aug(
    spec: torch.Tensor,
    mask_max_time_perc: float,
) -> torch.Tensor:
    """Mask out random blocks of time.

    Will randomly mask out a block of time in the spectrogram. The block
    will be between 0.0 and `mask_max_time_perc` of the total time.
    A random number of blocks, between 1 and 3, will be masked out.

    Parameters
    ----------
    spec: torch.Tensor
        Spectrogram to mask.
    mask_max_time_perc: float
        Maximum percentage of time to mask out.

    Returns
    -------
    torch.Tensor
        Spectrogram with masked out time blocks.

    Notes
    -----
    This function is based on the implementation in::

        SpecAugment: A Simple Data Augmentation Method for Automatic Speech
        Recognition
    """
    fm = torchaudio.transforms.TimeMasking(
        int(spec.shape[1] * mask_max_time_perc)
    )
    for _ in range(np.random.randint(1, 4)):
        spec = fm(spec)
    return spec


def mask_freq_aug(
    spec: torch.Tensor,
    mask_max_freq_perc: float,
) -> torch.Tensor:
    """Mask out random blocks of frequency.

    Will randomly mask out a block of frequency in the spectrogram. The block
    will be between 0.0 and `mask_max_freq_perc` of the total frequency.
    A random number of blocks, between 1 and 3, will be masked out.

    Parameters
    ----------
    spec: torch.Tensor
        Spectrogram to mask.
    mask_max_freq_perc: float
        Maximum percentage of frequency to mask out.

    Returns
    -------
    torch.Tensor
        Spectrogram with masked out frequency blocks.

    Notes
    -----
    This function is based on the implementation in::

        SpecAugment: A Simple Data Augmentation Method for Automatic Speech
        Recognition
    """
    fm = torchaudio.transforms.FrequencyMasking(
        int(spec.shape[1] * mask_max_freq_perc)
    )
    for _ in range(np.random.randint(1, 4)):
        spec = fm(spec)
    return spec


def scale_vol_aug(
    spec: torch.Tensor,
    spec_amp_scaling: float,
) -> torch.Tensor:
    """Scale the volume of the spectrogram.

    Parameters
    ----------
    spec: torch.Tensor
        Spectrogram to scale.
    spec_amp_scaling: float
        Maximum scaling factor.

    Returns
    -------
    torch.Tensor
        Scaled spectrogram.
    """
    return spec * np.random.random() * spec_amp_scaling


def echo_aug(
    audio: np.ndarray,
    sampling_rate: float,
    echo_max_delay: float,
) -> np.ndarray:
    """Add echo to audio.

    Parameters
    ----------
    audio: np.ndarray
        Audio to add echo to.
    sampling_rate: float
        Sampling rate of the audio.
    echo_max_delay: float
        Maximum delay of the echo in seconds.

    Returns
    -------
    np.ndarray
        Audio with echo added.
    """
    sample_offset = (
        int(echo_max_delay * np.random.random() * sampling_rate) + 1
    )
    # NOTE: This seems to be wrong, as the echo should be added to the
    # end of the audio, not the beginning.
    audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
    return audio
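The NOTE above can be seen with a unit impulse (the offset is random in the augmentation; fixed here for illustration):

    import numpy as np

    audio = np.zeros(100, dtype=np.float32)
    audio[50] = 1.0
    offset = 25
    audio[:-offset] += 0.5 * audio[offset:]
    print(np.nonzero(audio)[0])  # [25 50]: the copy lands before the original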


def resample_aug(
    audio: np.ndarray,
    sampling_rate: float,
    fft_win_length: float,
    fft_overlap: float,
    resize_factor: float,
    spec_divide_factor: float,
    spec_train_width: int,
    aug_sampling_rates: List[int],
) -> Tuple[np.ndarray, float, float]:
    """Resample audio augmentation.

    Will resample the audio to a random sampling rate from the list of
    sampling rates in `aug_sampling_rates`.

    Parameters
    ----------
    audio: np.ndarray
        Audio to resample.
    sampling_rate: float
        Original sampling rate of the audio.
    fft_win_length: float
        Length of the FFT window in seconds.
    fft_overlap: float
        Amount of overlap between FFT windows.
    resize_factor: float
        Factor to resize the spectrogram by.
    spec_divide_factor: float
        Factor to divide the spectrogram by.
    spec_train_width: int
        Width of the spectrogram.
    aug_sampling_rates: List[int]
        List of sampling rates to resample to.

    Returns
    -------
    audio : np.ndarray
        Resampled audio.
    sampling_rate : float
        New sampling rate.
    duration : float
        Duration of the audio in seconds.
    """
    sampling_rate_old = sampling_rate
    sampling_rate = np.random.choice(aug_sampling_rates)
    audio = librosa.resample(
        audio,
        orig_sr=sampling_rate_old,
        target_sr=sampling_rate,
        res_type="polyphase",
    )

    audio = au.pad_audio(
        audio,
        sampling_rate,
        fft_win_length,
        fft_overlap,
        resize_factor,
        spec_divide_factor,
        spec_train_width,
    )
    duration = audio.shape[0] / float(sampling_rate)
    return audio, sampling_rate, duration


def resample_audio(
    num_samples: int,
    sampling_rate: float,
    audio2: np.ndarray,
    sampling_rate2: float,
) -> Tuple[np.ndarray, float]:
    """Resample audio.

    Parameters
    ----------
    num_samples: int
        Expected number of samples for the output audio.
    sampling_rate: float
        Target sampling rate for the audio.
    audio2: np.ndarray
        Audio to resample.
    sampling_rate2: float
        Original sampling rate of the audio.

    Returns
    -------
    audio2 : np.ndarray
        Resampled audio.
    sampling_rate2 : float
        New sampling rate.
    """
    # resample to target sampling rate
    if sampling_rate != sampling_rate2:
        audio2 = librosa.resample(
            audio2,
            orig_sr=sampling_rate2,
            target_sr=sampling_rate,
            res_type="polyphase",
        )
        sampling_rate2 = sampling_rate

    # pad or trim to the correct length
    if audio2.shape[0] < num_samples:
        audio2 = np.hstack(
            (
                audio2,
                np.zeros((num_samples - audio2.shape[0]), dtype=audio2.dtype),
            )
        )
    elif audio2.shape[0] > num_samples:
        audio2 = audio2[:num_samples]

    return audio2, sampling_rate2


def combine_audio_aug(
    audio: np.ndarray,
    sampling_rate: float,
    ann: AudioLoaderAnnotationGroup,
    audio2: np.ndarray,
    sampling_rate2: float,
    ann2: AudioLoaderAnnotationGroup,
) -> Tuple[np.ndarray, AudioLoaderAnnotationGroup]:
    """Combine two audio files.

    Will combine two audio files by resampling them to the same sampling rate
    and then mixing them with a random weight. The annotations are combined
    by taking the union of the two sets of annotations.

    Parameters
    ----------
    audio: np.ndarray
        First audio to combine.
    sampling_rate: float
        Sampling rate of the first audio.
    ann: AudioLoaderAnnotationGroup
        Annotations for the first audio.
    audio2: np.ndarray
        Second audio to combine.
    sampling_rate2: float
        Sampling rate of the second audio.
    ann2: AudioLoaderAnnotationGroup
        Annotations for the second audio.

    Returns
    -------
    audio : np.ndarray
        Combined audio.
    ann : AudioLoaderAnnotationGroup
        Combined annotations.
    """
    # resample so they are the same
    audio2, sampling_rate2 = resample_audio(
        audio.shape[0],
        sampling_rate,
        audio2,
        sampling_rate2,
    )

    # # set mean and std to be the same
    # audio2 = (audio2 - audio2.mean())
    # audio2 = (audio2/audio2.std())*audio.std()
    # audio2 = audio2 + audio.mean()

    if (
        ann.get("annotated", False)
        and (ann2.get("annotated", False))
        and (sampling_rate2 == sampling_rate)
        and (audio.shape[0] == audio2.shape[0])
    ):
        comb_weight = 0.3 + np.random.random() * 0.4
        audio = comb_weight * audio + (1 - comb_weight) * audio2
        inds = np.argsort(np.hstack((ann["start_times"], ann2["start_times"])))
        for kk in ann.keys():
            # when combining calls from different files, assume they come
            # from different individuals
            if kk == "individual_ids":
                if (ann[kk] > -1).sum() > 0:
                    ann2[kk][ann2[kk] > -1] += (
                        np.max(ann[kk][ann[kk] > -1]) + 1
                    )

            if (kk != "class_id_file") and (kk != "annotated"):
                ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds]

    return audio, ann
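The mixing weight is drawn from U(0.3, 0.7), so neither clip fully dominates; a standalone sketch of just the mixing step:

    import numpy as np

    a = np.random.randn(1000).astype(np.float32)
    b = np.random.randn(1000).astype(np.float32)
    comb_weight = 0.3 + np.random.random() * 0.4
    mixed = comb_weight * a + (1 - comb_weight) * b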


def _prepare_annotation(
    annotation: Annotation,
    class_names: List[str],
) -> Annotation:
    try:
        class_id = class_names.index(annotation["class"])
    except ValueError:
        class_id = -1

    ann: Annotation = {
        **annotation,
        "class_id": class_id,
    }

    if "individual" in ann:
        ann["individual"] = int(ann["individual"])  # type: ignore

    return ann


def _prepare_file_annotation(
    annotation: FileAnnotation,
    class_names: List[str],
    classes_to_ignore: List[str],
) -> AudioLoaderAnnotationGroup:
    annotations = [
        _prepare_annotation(ann, class_names)
        for ann in annotation["annotation"]
        if ann["class"] not in classes_to_ignore
    ]

    try:
        class_id_file = class_names.index(annotation["class_name"])
    except ValueError:
        class_id_file = -1

    ret: AudioLoaderAnnotationGroup = {
        "id": annotation["id"],
        "annotated": annotation["annotated"],
        "duration": annotation["duration"],
        "issues": annotation["issues"],
        "time_exp": annotation["time_exp"],
        "class_name": annotation["class_name"],
        "notes": annotation["notes"],
        "annotation": annotations,
        "start_times": np.array([ann["start_time"] for ann in annotations]),
        "end_times": np.array([ann["end_time"] for ann in annotations]),
        "high_freqs": np.array([ann["high_freq"] for ann in annotations]),
        "low_freqs": np.array([ann["low_freq"] for ann in annotations]),
        "class_ids": np.array(
            [ann.get("class_id", -1) for ann in annotations]
        ),
        "individual_ids": np.array([ann["individual"] for ann in annotations]),
        "class_id_file": class_id_file,
    }

    return ret


class AudioLoader(torch.utils.data.Dataset):
    """Main AudioLoader for training and testing."""

    def __init__(
        self,
        data_anns_ip: List[FileAnnotation],
        params: AudioLoaderParameters,
        dataset_name: Optional[str] = None,
        is_train: bool = False,
        return_spec_for_viz: bool = False,
    ):
        self.is_train = is_train
        self.params = params
        self.return_spec_for_viz = return_spec_for_viz
        self.data_anns: List[AudioLoaderAnnotationGroup] = [
            _prepare_file_annotation(
                ann,
                params["class_names"],
                params["classes_to_ignore"],
            )
            for ann in data_anns_ip
        ]

        ann_cnt = [len(aa["annotation"]) for aa in self.data_anns]
        self.max_num_anns = 2 * np.max(
            ann_cnt
        )  # x2 because we may be combining files during training

        print("\n")
        if dataset_name is not None:
            print("Dataset : " + dataset_name)
        if self.is_train:
            print("Split type : train")
        else:
            print("Split type : test")
        print("Num files : " + str(len(self.data_anns)))
        print("Num calls : " + str(np.sum(ann_cnt)))

    def get_file_and_anns(
        self,
        index: Optional[int] = None,
    ) -> Tuple[np.ndarray, float, float, AudioLoaderAnnotationGroup]:
        """Get an audio file and its annotations.

        Parameters
        ----------
        index : int, optional
            Index of the file to be loaded. If None, a random file is chosen.

        Returns
        -------
        audio_raw : np.ndarray
            Loaded audio file.
        sampling_rate : float
            Sampling rate of the audio file.
        duration : float
            Duration of the audio file in seconds.
        ann : AudioLoaderAnnotationGroup
            AnnotationGroup object containing the annotations for the audio
            file.
        """
        # if no file specified, choose random one
        if index is None:
            index = np.random.randint(0, len(self.data_anns))

        audio_file = self.data_anns[index]["file_path"]
        sampling_rate, audio_raw = au.load_audio(
            audio_file,
            self.data_anns[index]["time_exp"],
            self.params["target_samp_rate"],
            self.params["scale_raw_audio"],
        )

        # copy annotation
        ann = copy.deepcopy(self.data_anns[index])
        # ann["annotated"] = self.data_anns[index]["annotated"]
        # ann["class_id_file"] = self.data_anns[index]["class_id_file"]
        # keys = [
        #     "start_times",
        #     "end_times",
        #     "high_freqs",
        #     "low_freqs",
        #     "class_ids",
        #     "individual_ids",
        # ]
        # for kk in keys:
        #     ann[kk] = self.data_anns[index][kk].copy()

        # if train then grab a random crop
        if self.is_train:
            nfft = int(self.params["fft_win_length"] * sampling_rate)
            noverlap = int(self.params["fft_overlap"] * nfft)
            length_samples = (
                self.params["spec_train_width"] * (nfft - noverlap) + noverlap
            )

            if audio_raw.shape[0] - length_samples > 0:
                sample_crop = np.random.randint(
                    audio_raw.shape[0] - length_samples
                )
            else:
                sample_crop = 0
            audio_raw = audio_raw[sample_crop : sample_crop + length_samples]
            ann["start_times"] = ann["start_times"] - sample_crop / float(
                sampling_rate
            )
            ann["end_times"] = ann["end_times"] - sample_crop / float(
                sampling_rate
            )

        # pad audio
        if self.is_train:
            op_spec_target_size = self.params["spec_train_width"]
        else:
            op_spec_target_size = None
        audio_raw = au.pad_audio(
            audio_raw,
            sampling_rate,
            self.params["fft_win_length"],
            self.params["fft_overlap"],
            self.params["resize_factor"],
            self.params["spec_divide_factor"],
            op_spec_target_size,
        )
        duration = audio_raw.shape[0] / float(sampling_rate)

        # sort based on time
        inds = np.argsort(ann["start_times"])
        for kk in ann.keys():
            if (kk != "class_id_file") and (kk != "annotated"):
                ann[kk] = ann[kk][inds]

        return audio_raw, sampling_rate, duration, ann

    def __getitem__(self, index):
        """Get an item from the dataset."""
        # load audio file
        audio, sampling_rate, duration, ann = self.get_file_and_anns(index)

        # augment on raw audio
        if self.is_train and self.params["augment_at_train"]:
            # augment - combine with random audio file
            if (
                self.params["augment_at_train_combine"]
                and np.random.random() < self.params["aug_prob"]
            ):
                (
                    audio2,
                    sampling_rate2,
                    _,
                    ann2,
                ) = self.get_file_and_anns()
                audio, ann = combine_audio_aug(
                    audio, sampling_rate, ann, audio2, sampling_rate2, ann2
                )

            # simulate echo by adding delayed copy of the file
            if np.random.random() < self.params["aug_prob"]:
                audio = echo_aug(
                    audio,
                    sampling_rate,
                    echo_max_delay=self.params["echo_max_delay"],
                )

            # resample the audio
            # if np.random.random() < self.params["aug_prob"]:
            #     audio, sampling_rate, duration = resample_aug(
            #         audio, sampling_rate, self.params
            #     )

        # create spectrogram
        # NOTE: capture the second return value so that `spec_for_viz` is
        # defined when `return_spec_for_viz` is requested below.
        spec, spec_for_viz = au.generate_spectrogram(
            audio,
            sampling_rate,
            params=dict(
                fft_win_length=self.params["fft_win_length"],
                fft_overlap=self.params["fft_overlap"],
                max_freq=self.params["max_freq"],
                min_freq=self.params["min_freq"],
                spec_scale=self.params["spec_scale"],
                denoise_spec_avg=self.params["denoise_spec_avg"],
                max_scale_spec=self.params["max_scale_spec"],
            ),
        )
        rsf = self.params["resize_factor"]
        spec_op_shape = (
            int(self.params["spec_height"] * rsf),
            int(spec.shape[1] * rsf),
        )

        # resize the spec
        spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0)
        spec = F.interpolate(
            spec,
            size=spec_op_shape,
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)

        # augment spectrogram
        if self.is_train and self.params["augment_at_train"]:
            if np.random.random() < self.params["aug_prob"]:
                spec = scale_vol_aug(
                    spec,
                    spec_amp_scaling=self.params["spec_amp_scaling"],
                )

            if np.random.random() < self.params["aug_prob"]:
                spec = warp_spec_aug(
                    spec,
                    ann,
                    stretch_squeeze_delta=self.params["stretch_squeeze_delta"],
                )

            if np.random.random() < self.params["aug_prob"]:
                spec = mask_time_aug(
                    spec,
                    mask_max_time_perc=self.params["mask_max_time_perc"],
                )

            if np.random.random() < self.params["aug_prob"]:
                spec = mask_freq_aug(
                    spec,
                    mask_max_freq_perc=self.params["mask_max_freq_perc"],
                )

        outputs = {}
        outputs["spec"] = spec
        if self.return_spec_for_viz:
            outputs["spec_for_viz"] = torch.from_numpy(spec_for_viz).unsqueeze(
                0
            )

        # create ground truth heatmaps
        (
            outputs["y_2d_det"],
            outputs["y_2d_size"],
            outputs["y_2d_classes"],
            ann_aug,
        ) = generate_gt_heatmaps(
            spec_op_shape,
            sampling_rate,
            ann,
            class_names=self.params["class_names"],
            fft_win_length=self.params["fft_win_length"],
            fft_overlap=self.params["fft_overlap"],
            max_freq=self.params["max_freq"],
            min_freq=self.params["min_freq"],
            resize_factor=self.params["resize_factor"],
            target_sigma=self.params["target_sigma"],
        )

        # hack to get around requirement that all vectors are the same length
        # in the output batch
        pad_size = self.max_num_anns - len(ann_aug["individual_ids"])
        outputs["is_valid"] = pad_aray(
            np.ones(len(ann_aug["individual_ids"])), pad_size
        )
        keys = [
            "class_ids",
            "individual_ids",
            "x_inds",
            "y_inds",
            "start_times",
            "end_times",
            "low_freqs",
            "high_freqs",
        ]
        for kk in keys:
            outputs[kk] = pad_aray(ann_aug[kk], pad_size)

        # convert to pytorch
        for kk in outputs.keys():
            if type(outputs[kk]) != torch.Tensor:
                outputs[kk] = torch.from_numpy(outputs[kk])

        # scalars
        outputs["class_id_file"] = ann["class_id_file"]
        outputs["annotated"] = ann["annotated"]
        outputs["duration"] = duration
        outputs["sampling_rate"] = sampling_rate
        outputs["file_id"] = index

        return outputs

    def __len__(self):
        """Denotes the total number of samples."""
        return len(self.data_anns)
@ -1,214 +1,136 @@
from typing import Callable, Optional, Union
from functools import wraps
from typing import Callable, List, Optional, Tuple

import numpy as np
import xarray as xr
from pydantic import Field
from soundevent import arrays, data
from soundevent import data
from soundevent.geometry import compute_bounds

from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram
from batdetect2.preprocess.arrays import adjust_width

Augmentation = Callable[[xr.Dataset], xr.Dataset]


__all__ = [
    "AugmentationsConfig",
    "load_agumentation_config",
    "select_subclip",
    "mix_examples",
    "add_echo",
    "scale_volume",
    "warp_spectrogram",
    "mask_time",
    "mask_frequency",
    "augment_example",
]
AUGMENTATION_PROBABILITY = 0.2
MAX_DELAY = 0.005
STRETCH_SQUEEZE_DELTA = 0.04
MASK_MAX_TIME_PERC: float = 0.05
MASK_MAX_FREQ_PERC: float = 0.10


class BaseAugmentationConfig(BaseConfig):
    enable: bool = True
    probability: float = 0.2
def maybe_apply(
    augmentation: Callable,
    prob: float = AUGMENTATION_PROBABILITY,
) -> Callable:
    """Apply an augmentation with a given probability."""

    @wraps(augmentation)
    def _augmentation(x):
        if np.random.rand() > prob:
            return x
        return augmentation(x)

    return _augmentation


def select_subclip(
    example: xr.Dataset,
    start_time: Optional[float] = None,
def select_random_subclip(
    train_example: xr.Dataset,
    duration: Optional[float] = None,
    width: Optional[int] = None,
    random: bool = False,
    proportion: float = 0.9,
) -> xr.Dataset:
    """Select a random subclip from a clip."""
    step = arrays.get_dim_step(example, "time")  # type: ignore
    start, stop = arrays.get_dim_range(example, "time")  # type: ignore

    if width is None:
        if duration is None:
            raise ValueError("Either duration or width must be provided")
    time_coords = train_example.coords["time"]

        width = int(np.floor(duration / step))
    start_time = time_coords.attrs.get("min", time_coords.min())
    end_time = time_coords.attrs.get("max", time_coords.max())

    if duration is None:
        duration = width * step
    duration = (end_time - start_time) * proportion

    if start_time is None:
        if random:
            start_time = np.random.uniform(start, max(stop - duration, start))
        else:
            start_time = start

    if start_time + duration > stop:
        return example

    start_index = arrays.get_coord_index(
        example,  # type: ignore
        "time",
        start_time,
    )

    end_index = start_index + width - 1

    start_time = example.time.values[start_index]
    end_time = example.time.values[end_index]

    return example.sel(
        time=slice(start_time, end_time),
        audio_time=slice(start_time, end_time + step),
    )
    start_time = np.random.uniform(start_time, end_time - duration)
    return train_example.sel(time=slice(start_time, start_time + duration))


class MixAugmentationConfig(BaseAugmentationConfig):
    min_weight: float = 0.3
    max_weight: float = 0.7


def mix_examples(
    example: xr.Dataset,
    other: xr.Dataset,
    weight: Optional[float] = None,
    min_weight: float = 0.3,
    max_weight: float = 0.7,
    config: Optional[PreprocessingConfig] = None,
) -> xr.Dataset:
def combine_audio(
    audio1: xr.DataArray,
    audio2: xr.DataArray,
    alpha: Optional[float] = None,
    min_alpha: float = 0.3,
    max_alpha: float = 0.7,
) -> xr.DataArray:
    """Combine two audio clips."""
    config = config or PreprocessingConfig()

    if weight is None:
        weight = np.random.uniform(min_weight, max_weight)
    if alpha is None:
        alpha = np.random.uniform(min_alpha, max_alpha)

    audio1 = example["audio"]
    audio2 = adjust_width(other["audio"].values, len(audio1))

    combined = weight * audio1 + (1 - weight) * audio2

    spectrogram = compute_spectrogram(
        combined.rename({"audio_time": "time"}),
        config=config.spectrogram,
    ).data

    # NOTE: The subclip's spectrogram might be slightly longer than the
    # spectrogram computed from the subclip's audio. This is due to a
    # simplification in the subclip process: It doesn't account for the
    # spectrogram parameters to precisely determine the corresponding audio
    # samples. To work around this, we pad the computed spectrogram with zeros
    # as needed.
    previous_width = len(example["time"])
    spectrogram = adjust_width(spectrogram, previous_width)

    detection_heatmap = xr.apply_ufunc(
        np.maximum,
        example["detection"],
        adjust_width(other["detection"].values, previous_width),
    )

    class_heatmap = xr.apply_ufunc(
        np.maximum,
        example["class"],
        adjust_width(other["class"].values, previous_width),
    )

    size_heatmap = example["size"] + adjust_width(
        other["size"].values, previous_width
    )

    return xr.Dataset(
        {
            "audio": combined,
            "spectrogram": xr.DataArray(
                data=spectrogram,
                dims=example["spectrogram"].dims,
                coords=example["spectrogram"].coords,
            ),
            "detection": detection_heatmap,
            "class": class_heatmap,
            "size": size_heatmap,
        },
        attrs=example.attrs,
    )
    return alpha * audio1 + (1 - alpha) * audio2.data


class EchoAugmentationConfig(BaseAugmentationConfig):
    max_delay: float = 0.005
    min_weight: float = 0.0
    max_weight: float = 1.0
# def random_mix(
#     audio: xr.DataArray,
#     clip: data.ClipAnnotation,
#     provider: Optional[ClipProvider] = None,
#     alpha: Optional[float] = None,
#     min_alpha: float = 0.3,
#     max_alpha: float = 0.7,
#     join_annotations: bool = True,
# ) -> Tuple[xr.DataArray, data.ClipAnnotation]:
#     """Mix two audio clips."""
#     if provider is None:
#         raise ValueError("No audio provider given.")
#
#     try:
#         other_audio, other_clip = provider(clip)
#     except (StopIteration, ValueError):
#         raise ValueError("No more audio sources available.")
#
#     new_audio = combine_audio(
#         audio,
#         other_audio,
#         alpha=alpha,
#         min_alpha=min_alpha,
#         max_alpha=max_alpha,
#     )
#
#     if join_annotations:
#         clip = clip.model_copy(
#             update=dict(
#                 sound_events=clip.sound_events + other_clip.sound_events,
#             )
#         )
#
#     return new_audio, clip


def add_echo(
    example: xr.Dataset,
    train_example: xr.Dataset,
    delay: Optional[float] = None,
    weight: Optional[float] = None,
    min_weight: float = 0.1,
    max_weight: float = 1.0,
    max_delay: float = 0.005,
    config: Optional[PreprocessingConfig] = None,
    alpha: Optional[float] = None,
    min_alpha: float = 0.0,
    max_alpha: float = 1.0,
    max_delay: float = MAX_DELAY,
) -> xr.Dataset:
    """Add a delay to the audio."""
    config = config or PreprocessingConfig()

    if delay is None:
        delay = np.random.uniform(0, max_delay)

    if weight is None:
        weight = np.random.uniform(min_weight, max_weight)
    if alpha is None:
        alpha = np.random.uniform(min_alpha, max_alpha)

    audio = example["audio"]
    step = arrays.get_dim_step(audio, "audio_time")
    audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0)
    audio = audio + weight * audio_delay
    spec = train_example["spectrogram"]

    spectrogram = compute_spectrogram(
        audio.rename({"audio_time": "time"}),
        config=config.spectrogram,
    ).data
    time_coords = spec.coords["time"]
    start_time = time_coords.attrs["min"]
    end_time = time_coords.attrs["max"]
    step = (end_time - start_time) / time_coords.size

    # NOTE: The subclip's spectrogram might be slightly longer than the
    # spectrogram computed from the subclip's audio. This is due to a
    # simplification in the subclip process: It doesn't account for the
    # spectrogram parameters to precisely determine the corresponding audio
    # samples. To work around this, we pad the computed spectrogram with zeros
    # as needed.
    spectrogram = adjust_width(
        spectrogram,
        example["spectrogram"].sizes["time"],
    )
    spec_delay = spec.shift(time=int(delay / step), fill_value=0)

    return example.assign(
        audio=audio,
        spectrogram=xr.DataArray(
            data=spectrogram,
            dims=example["spectrogram"].dims,
            coords=example["spectrogram"].coords,
        ),
    )


class VolumeAugmentationConfig(BaseAugmentationConfig):
    min_scaling: float = 0.0
    max_scaling: float = 2.0
    return train_example.assign(spectrogram=spec + alpha * spec_delay)


def scale_volume(
    example: xr.Dataset,
    train_example: xr.Dataset,
    factor: Optional[float] = None,
    max_scaling: float = 2,
    min_scaling: float = 0,
@ -217,227 +139,106 @@ def scale_volume(
    if factor is None:
        factor = np.random.uniform(min_scaling, max_scaling)

    return example.assign(spectrogram=example["spectrogram"] * factor)


class WarpAugmentationConfig(BaseAugmentationConfig):
    delta: float = 0.04
    return train_example.assign(
        spectrogram=train_example["spectrogram"] * factor
    )


def warp_spectrogram(
    example: xr.Dataset,
    train_example: xr.Dataset,
    factor: Optional[float] = None,
    delta: float = 0.04,
    delta: float = STRETCH_SQUEEZE_DELTA,
) -> xr.Dataset:
    """Warp a spectrogram."""
    if factor is None:
        factor = np.random.uniform(1 - delta, 1 + delta)

    start_time, end_time = arrays.get_dim_range(example, "time")  # type: ignore
    time_coords = train_example.coords["time"]
    start_time = time_coords.attrs["min"]
    end_time = time_coords.attrs["max"]
    duration = end_time - start_time

    new_time = np.linspace(
        start_time,
        start_time + duration * factor,
        example.time.size,
        train_example.time.size,
    )

    spectrogram = (
        example["spectrogram"]
        .interp(
            coords={"time": new_time},
            method="linear",
            kwargs=dict(
                fill_value=0,
            ),
        )
        .clip(min=0)
    )

    detection = example["detection"].interp(
        time=new_time,
        method="nearest",
        kwargs=dict(
            fill_value=0,
        ),
    )

    classification = example["class"].interp(
        time=new_time,
        method="nearest",
        kwargs=dict(
            fill_value=0,
        ),
    )

    size = example["size"].interp(
        time=new_time,
        method="nearest",
        kwargs=dict(
            fill_value=0,
        ),
    )

    return example.assign(
        {
            "time": new_time,
            "spectrogram": spectrogram,
            "detection": detection,
            "class": classification,
            "size": size,
        }
    )
    return train_example.interp(time=new_time)


def mask_axis(
    array: xr.DataArray,
    train_example: xr.Dataset,
    dim: str,
    start: float,
    end: float,
    mask_value: Union[float, Callable[[xr.DataArray], float]] = np.mean,
) -> xr.DataArray:
    if dim not in array.dims:
    mask_all: bool = False,
    mask_value: float = 0,
) -> xr.Dataset:
    if dim not in train_example.dims:
        raise ValueError(f"Axis {dim} not found in array")

    coord = array.coords[dim]
    coord = train_example.coords[dim]
    condition = (coord < start) | (coord > end)

    if callable(mask_value):
        mask_value = mask_value(array)
    if mask_all:
        return train_example.where(condition, other=mask_value)

    return array.where(condition, other=mask_value)


class TimeMaskAugmentationConfig(BaseAugmentationConfig):
    max_perc: float = 0.05
    max_masks: int = 3
    return train_example.assign(
        spectrogram=train_example.spectrogram.where(
            condition, other=mask_value
        )
    )


def mask_time(
    example: xr.Dataset,
    max_perc: float = 0.05,
    max_mask: int = 3,
    train_example: xr.Dataset,
    max_time_mask: float = MASK_MAX_TIME_PERC,
    max_num_masks: int = 3,
) -> xr.Dataset:
    """Mask a random section of the time axis."""
    num_masks = np.random.randint(1, max_mask + 1)
    start_time, end_time = arrays.get_dim_range(example, "time")  # type: ignore

    spectrogram = example["spectrogram"]
    num_masks = np.random.randint(1, max_num_masks + 1)

    time_coord = train_example.coords["time"]
    start_time = time_coord.attrs.get("min", time_coord.min())
    end_time = time_coord.attrs.get("max", time_coord.max())

    for _ in range(num_masks):
        mask_size = np.random.uniform(0, max_perc) * (end_time - start_time)
        mask_size = np.random.uniform(0, max_time_mask)
        start = np.random.uniform(start_time, end_time - mask_size)
        end = start + mask_size
        spectrogram = mask_axis(spectrogram, "time", start, end)
        train_example = mask_axis(train_example, "time", start, end)

    return example.assign(spectrogram=spectrogram)


class FrequencyMaskAugmentationConfig(BaseAugmentationConfig):
    max_perc: float = 0.10
    max_masks: int = 3
    return train_example


def mask_frequency(
    example: xr.Dataset,
    max_perc: float = 0.10,
    max_masks: int = 3,
    train_example: xr.Dataset,
    max_freq_mask: float = MASK_MAX_FREQ_PERC,
    max_num_masks: int = 3,
) -> xr.Dataset:
    """Mask a random section of the frequency axis."""
    num_masks = np.random.randint(1, max_masks + 1)
    min_freq, max_freq = arrays.get_dim_range(example, "frequency")  # type: ignore

    spectrogram = example["spectrogram"]
    num_masks = np.random.randint(1, max_num_masks + 1)

    freq_coord = train_example.coords["frequency"]
    min_freq = freq_coord.min()
    max_freq = freq_coord.max()

    for _ in range(num_masks):
        mask_size = np.random.uniform(0, max_perc) * (max_freq - min_freq)
        mask_size = np.random.uniform(0, max_freq_mask)
        start = np.random.uniform(min_freq, max_freq - mask_size)
        end = start + mask_size
        spectrogram = mask_axis(spectrogram, "frequency", start, end)
        train_example = mask_axis(train_example, "frequency", start, end)

    return example.assign(spectrogram=spectrogram)
    return train_example


class AugmentationsConfig(BaseConfig):
    mix: MixAugmentationConfig = Field(default_factory=MixAugmentationConfig)
    echo: EchoAugmentationConfig = Field(
        default_factory=EchoAugmentationConfig
    )
    volume: VolumeAugmentationConfig = Field(
        default_factory=VolumeAugmentationConfig
    )
    warp: WarpAugmentationConfig = Field(
        default_factory=WarpAugmentationConfig
    )
    time_mask: TimeMaskAugmentationConfig = Field(
        default_factory=TimeMaskAugmentationConfig
    )
    frequency_mask: FrequencyMaskAugmentationConfig = Field(
        default_factory=FrequencyMaskAugmentationConfig
    )


def load_agumentation_config(
    path: data.PathLike, field: Optional[str] = None
) -> AugmentationsConfig:
    return load_config(path, schema=AugmentationsConfig, field=field)


def should_apply(config: BaseAugmentationConfig) -> bool:
    if not config.enable:
        return False

    return np.random.uniform() < config.probability


def augment_example(
    example: xr.Dataset,
    config: AugmentationsConfig,
    preprocessing_config: Optional[PreprocessingConfig] = None,
    others: Optional[Callable[[], xr.Dataset]] = None,
) -> xr.Dataset:
    if should_apply(config.mix) and (others is not None):
        other = others()
        example = mix_examples(
            example,
            other,
            min_weight=config.mix.min_weight,
            max_weight=config.mix.max_weight,
            config=preprocessing_config,
        )

    if should_apply(config.echo):
        example = add_echo(
            example,
            max_delay=config.echo.max_delay,
            min_weight=config.echo.min_weight,
            max_weight=config.echo.max_weight,
            config=preprocessing_config,
        )

    if should_apply(config.volume):
        example = scale_volume(
            example,
            max_scaling=config.volume.max_scaling,
            min_scaling=config.volume.min_scaling,
        )

    if should_apply(config.warp):
        example = warp_spectrogram(
            example,
            delta=config.warp.delta,
        )

    if should_apply(config.time_mask):
        example = mask_time(
            example,
            max_perc=config.time_mask.max_perc,
            max_mask=config.time_mask.max_masks,
        )

    if should_apply(config.frequency_mask):
        example = mask_frequency(
            example,
            max_perc=config.frequency_mask.max_perc,
            max_masks=config.frequency_mask.max_masks,
        )

    return example
AUGMENTATIONS: List[Augmentation] = [
    select_random_subclip,
    add_echo,
    scale_volume,
    mask_time,
    mask_frequency,
]
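One side of this diff drops the config-driven should_apply dispatch in favour of plain callables wrapped with maybe_apply. A minimal sketch of how the AUGMENTATIONS list composes under that design; the build_pipeline helper is illustrative, not part of the module:

import xarray as xr

from batdetect2.train.augmentations import AUGMENTATIONS, maybe_apply


def build_pipeline(prob: float = 0.2):
    # Wrap each augmentation so it fires only with probability `prob`,
    # then chain them into a single Dataset -> Dataset callable.
    wrapped = [maybe_apply(fn, prob) for fn in AUGMENTATIONS]

    def pipeline(example: xr.Dataset) -> xr.Dataset:
        for fn in wrapped:
            example = fn(example)
        return example

    return pipeline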
@ -1,31 +0,0 @@
from typing import Optional

from pydantic import Field
from soundevent.data import PathLike

from batdetect2.configs import BaseConfig, load_config
from batdetect2.train.losses import LossConfig

__all__ = [
    "OptimizerConfig",
    "TrainingConfig",
    "load_train_config",
]


class OptimizerConfig(BaseConfig):
    learning_rate: float = 1e-3
    t_max: int = 100


class TrainingConfig(BaseConfig):
    batch_size: int = 32
    loss: LossConfig = Field(default_factory=LossConfig)
    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)


def load_train_config(
    path: PathLike,
    field: Optional[str] = None,
) -> TrainingConfig:
    return load_config(path, schema=TrainingConfig, field=field)
@ -1,21 +1,12 @@
import os
from pathlib import Path
from typing import NamedTuple, Optional, Sequence, Union
from typing import Callable, Dict, NamedTuple, Optional, Sequence, Union

import numpy as np
import torch
import xarray as xr
from pydantic import Field
from soundevent import data
from torch.utils.data import Dataset

from batdetect2.configs import BaseConfig
from batdetect2.preprocess.tensors import adjust_width
from batdetect2.train.augmentations import (
    AugmentationsConfig,
    augment_example,
    select_subclip,
)
from batdetect2.train.preprocess import PreprocessingConfig

__all__ = [
@ -35,113 +26,74 @@ class TrainExample(NamedTuple):
    idx: torch.Tensor


class SubclipConfig(BaseConfig):
    duration: Optional[float] = None
    width: int = 512
    random: bool = False


class DatasetConfig(BaseConfig):
    subclip: SubclipConfig = Field(default_factory=SubclipConfig)
    augmentation: AugmentationsConfig = Field(
        default_factory=AugmentationsConfig
    )
def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
    return list(Path(directory).glob(f"*{extension}"))


class LabeledDataset(Dataset):
    def __init__(
        self,
        filenames: Sequence[PathLike],
        subclip: Optional[SubclipConfig] = None,
        augmentation: Optional[AugmentationsConfig] = None,
        preprocessing: Optional[PreprocessingConfig] = None,
        transform: Optional[Callable[[xr.Dataset], xr.Dataset]] = None,
    ):
        self.filenames = filenames
        self.subclip = subclip
        self.augmentation = augmentation
        self.preprocessing = preprocessing or PreprocessingConfig()
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx) -> TrainExample:
        dataset = self.get_dataset(idx)

        if self.subclip:
            dataset = select_subclip(
                dataset,
                duration=self.subclip.duration,
                width=self.subclip.width,
                random=self.subclip.random,
            )

        if self.augmentation:
            dataset = augment_example(
                dataset,
                self.augmentation,
                preprocessing_config=self.preprocessing,
                others=self.get_random_example,
            )

        data = self.load(idx)
        return TrainExample(
            spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0),
            detection_heatmap=self.to_tensor(dataset["detection"]),
            class_heatmap=self.to_tensor(dataset["class"]),
            size_heatmap=self.to_tensor(dataset["size"]),
            spec=data["spectrogram"],
            detection_heatmap=data["detection"],
            class_heatmap=data["class"],
            size_heatmap=data["size"],
            idx=torch.tensor(idx),
        )

    @classmethod
    def from_directory(
        cls,
        directory: PathLike,
        extension: str = ".nc",
        subclip: Optional[SubclipConfig] = None,
        augmentation: Optional[AugmentationsConfig] = None,
        preprocessing: Optional[PreprocessingConfig] = None,
    ):
        return cls(
            get_files(directory, extension),
            subclip=subclip,
            augmentation=augmentation,
            preprocessing=preprocessing,
        )
    def from_directory(cls, directory: PathLike, extension: str = ".nc"):
        return cls(get_files(directory, extension))

    def get_random_example(self) -> xr.Dataset:
        idx = np.random.randint(0, len(self))
    def load(self, idx) -> Dict[str, torch.Tensor]:
        dataset = self.get_dataset(idx)
        return {
            "spectrogram": torch.tensor(
                dataset["spectrogram"].values
            ).unsqueeze(0),
            "detection": torch.tensor(dataset["detection"].values),
            "class": torch.tensor(dataset["class"].values),
            "size": torch.tensor(dataset["size"].values),
        }

        if self.subclip:
            dataset = select_subclip(
                dataset,
                duration=self.subclip.duration,
                width=self.subclip.width,
                random=self.subclip.random,
            )
    def apply_augmentation(self, dataset: xr.Dataset) -> xr.Dataset:
        if self.transform is not None:
            return self.transform(dataset)

        return dataset

    def get_dataset(self, idx) -> xr.Dataset:
    def get_dataset(self, idx):
        return xr.open_dataset(self.filenames[idx])

    def get_clip_annotation(self, idx) -> data.ClipAnnotation:
        return data.ClipAnnotation.model_validate_json(
            self.get_dataset(idx).attrs["clip_annotation"]
        )
    def get_spectrogram(self, idx):
        return xr.open_dataset(self.filenames[idx])["spectrogram"]

    def to_tensor(
        self,
        array: xr.DataArray,
        dtype=np.float32,
    ) -> torch.Tensor:
        tensor = torch.tensor(array.values.astype(dtype))
    def get_detection_mask(self, idx):
        return xr.open_dataset(self.filenames[idx])["detection"]

        if not self.subclip:
            return tensor
    def get_class_mask(self, idx):
        return xr.open_dataset(self.filenames[idx])["class"]

        width = self.subclip.width
        return adjust_width(tensor, width)
    def get_size_mask(self, idx):
        return xr.open_dataset(self.filenames[idx])["size"]

    def get_clip_annotation(self, idx):
        filename = self.filenames[idx]
        dataset = xr.open_dataset(filename)
        clip_annotation = dataset.attrs["clip_annotation"]
        return data.ClipAnnotation.model_validate_json(clip_annotation)

def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
    return list(Path(directory).glob(f"*{extension}"))
    def get_preprocessing_configuration(self, idx):
        config = xr.open_dataset(self.filenames[idx]).attrs["configuration"]
        return PreprocessingConfig.model_validate_json(config)
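For context, a hedged sketch of how this dataset class plugs into a standard PyTorch loader; the preprocessed-examples directory is a placeholder path:

from torch.utils.data import DataLoader

from batdetect2.train.dataset import LabeledDataset

# Directory of ".nc" training examples produced by the preprocessing step.
dataset = LabeledDataset.from_directory("example_data/preprocessed")
loader = DataLoader(dataset, batch_size=32, shuffle=True)

batch = next(iter(loader))  # a TrainExample of batched tensors
print(batch.spec.shape, batch.detection_heatmap.shape)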
@ -1,66 +1,16 @@
from typing import List

import numpy as np
import pandas as pd
from sklearn.metrics import auc, roc_curve
from soundevent import data
from soundevent.evaluation import match_geometries

from batdetect2.train.targets import build_encoder, get_class_names


def match_predictions_and_annotations(
    clip_annotation: data.ClipAnnotation,
    clip_prediction: data.ClipPrediction,
) -> List[data.Match]:
    annotated_sound_events = [
        sound_event_annotation
        for sound_event_annotation in clip_annotation.sound_events
        if sound_event_annotation.sound_event.geometry is not None
    ]

    predicted_sound_events = [
        sound_event_prediction
        for sound_event_prediction in clip_prediction.sound_events
        if sound_event_prediction.sound_event.geometry is not None
    ]

    annotated_geometries: List[data.Geometry] = [
        sound_event.sound_event.geometry
        for sound_event in annotated_sound_events
        if sound_event.sound_event.geometry is not None
    ]

    predicted_geometries: List[data.Geometry] = [
        sound_event.sound_event.geometry
        for sound_event in predicted_sound_events
        if sound_event.sound_event.geometry is not None
    ]

    matches = []
    for id1, id2, affinity in match_geometries(
        annotated_geometries,
        predicted_geometries,
    ):
        target = annotated_sound_events[id1] if id1 is not None else None
        source = predicted_sound_events[id2] if id2 is not None else None
        matches.append(
            data.Match(source=source, target=target, affinity=affinity)
        )

    return matches


def build_evaluation_dataframe(matches: List[data.Match]) -> pd.DataFrame:
    ret = []

    for match in matches:
        pass
from sklearn.metrics import (
    accuracy_score,
    auc,
    balanced_accuracy_score,
    roc_curve,
)


def compute_error_auc(op_str, gt, pred, prob):

    # classification error
    pred_int = (pred > prob).astype(np.int32)
    pred_int = (pred > prob).astype(np.int)
    class_acc = (pred_int == gt).mean() * 100.0

    # ROC - area under curve
@ -75,6 +25,7 @@ def compute_error_auc(op_str, gt, pred, prob):


def calc_average_precision(recall, precision):

    precision[np.isnan(precision)] = 0
    recall[np.isnan(recall)] = 0

@ -140,6 +91,7 @@ def compute_pre_rec(
    pred_class = []
    file_ids = []
    for pid, pp in enumerate(preds):

        # filter predicted calls that are too near the start or end of the file
        file_dur = gts[pid]["duration"]
        valid_inds = (pp["start_times"] >= ignore_start_end) & (
@ -189,6 +141,7 @@ def compute_pre_rec(
    gt_generic_class = []
    num_positives = 0
    for gg in gts:

        # filter ground truth calls that are too near the start or end of the file
        file_dur = gg["duration"]
        valid_inds = (gg["start_times"] >= ignore_start_end) & (
@ -252,6 +205,7 @@ def compute_pre_rec(

        # valid detection that has not already been assigned
        if valid_det and (gt_assigned[gt_id][det_ind] == 0):

            count_as_true_pos = True
            if eval_mode == "top_class" and (
                gt_class[gt_id][det_ind] != pred_class[ind]
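The ROC bookkeeping in compute_error_auc can be checked in isolation; a small self-contained example of the same sklearn calls, on toy data:

import numpy as np
from sklearn.metrics import auc, roc_curve

# Toy scores: ground truth labels and predicted probabilities.
gt = np.array([0, 0, 1, 1])
pred = np.array([0.1, 0.4, 0.35, 0.8])

fpr, tpr, _ = roc_curve(gt, pred)
roc_auc = auc(fpr, tpr)

# The thresholded classification accuracy computed alongside AUC above:
pred_int = (pred > 0.5).astype(np.int32)
class_acc = (pred_int == gt).mean() * 100.0
print(roc_auc, class_acc)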
@ -1,82 +0,0 @@
from typing import Callable, NamedTuple, Optional

import torch
from soundevent import data
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from batdetect2.models.typing import DetectionModel
from batdetect2.train.dataset import LabeledDataset


class TrainInputs(NamedTuple):
    spec: torch.Tensor
    detection_heatmap: torch.Tensor
    class_heatmap: torch.Tensor
    size_heatmap: torch.Tensor


def train_loop(
    model: DetectionModel,
    train_dataset: LabeledDataset[TrainInputs],
    validation_dataset: LabeledDataset[TrainInputs],
    device: Optional[torch.device] = None,
    num_epochs: int = 100,
    learning_rate: float = 1e-4,
):
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    validation_loader = DataLoader(validation_dataset, batch_size=32)

    model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = CosineAnnealingLR(
        optimizer,
        num_epochs * len(train_loader),
    )

    for epoch in range(num_epochs):
        train_loss = train_single_epoch(
            model,
            train_loader,
            optimizer,
            device,
            scheduler,
        )


def train_single_epoch(
    model: DetectionModel,
    train_loader: DataLoader,
    optimizer: Adam,
    device: torch.device,
    scheduler: CosineAnnealingLR,
):
    model.train()
    train_loss = tu.AverageMeter()

    for batch in train_loader:
        optimizer.zero_grad()

        spec = batch.spec.to(device)
        detection_heatmap = batch.detection_heatmap.to(device)
        class_heatmap = batch.class_heatmap.to(device)
        size_heatmap = batch.size_heatmap.to(device)

        outputs = model(spec)

        loss = loss_fun(
            outputs,
            gt_det,
            gt_size,
            gt_class,
            det_criterion,
            params,
            class_inv_freq,
        )

        train_loss.update(loss.item(), data.shape[0])
        loss.backward()
        optimizer.step()
        scheduler.step()
56 batdetect2/train/light.py Normal file
@ -0,0 +1,56 @@
import pytorch_lightning as L
from torch import Tensor, optim

from batdetect2.models.typing import DetectionModel, ModelOutput
from batdetect2.train import losses

from batdetect2.train.dataset import TrainExample


__all__ = [
    "LitDetectorModel",
]


class LitDetectorModel(L.LightningModule):
    model: DetectionModel

    def __init__(self, model: DetectionModel, learning_rate: float = 1e-3):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate

    def compute_loss(
        self,
        outputs: ModelOutput,
        batch: TrainExample,
    ) -> Tensor:
        detection_loss = losses.focal_loss(
            outputs.detection_probs,
            batch.detection_heatmap,
        )

        size_loss = losses.bbox_size_loss(
            outputs.size_preds,
            batch.size_heatmap,
        )

        valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
        classification_loss = losses.focal_loss(
            outputs.class_probs,
            batch.class_heatmap,
            valid_mask=valid_mask,
        )

        return detection_loss + size_loss + classification_loss

    def training_step(self, batch: TrainExample, batch_idx: int):  # type: ignore
        outputs: ModelOutput = self.model(batch.spec)
        loss = self.compute_loss(outputs, batch)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
        return [optimizer], [scheduler]
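A hedged sketch of driving this module with a Lightning Trainer; the detector and dataset are placeholders, and any DetectionModel plus a loader yielding TrainExample batches would do:

import pytorch_lightning as L
from torch.utils.data import DataLoader

from batdetect2.train.light import LitDetectorModel

# `detector` and `train_dataset` are assumed to be constructed elsewhere.
module = LitDetectorModel(detector, learning_rate=1e-3)
trainer = L.Trainer(max_epochs=100)
trainer.fit(module, DataLoader(train_dataset, batch_size=32, shuffle=True))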
@ -1,23 +1,7 @@
from typing import NamedTuple, Optional
from typing import Optional

import torch
import torch.nn.functional as F
from pydantic import Field

from batdetect2.configs import BaseConfig
from batdetect2.models.typing import ModelOutput
from batdetect2.train.dataset import TrainExample

__all__ = [
    "bbox_size_loss",
    "compute_loss",
    "focal_loss",
    "mse_loss",
]


class SizeLossConfig(BaseConfig):
    weight: float = 0.1


def bbox_size_loss(
@ -33,11 +17,6 @@ def bbox_size_loss(
    )


class FocalLossConfig(BaseConfig):
    beta: float = 4
    alpha: float = 2


def focal_loss(
    pred: torch.Tensor,
    gt: torch.Tensor,
@ -65,7 +44,7 @@ def focal_loss(
    )

    if weights is not None:
        pos_loss = pos_loss * torch.tensor(weights)
        pos_loss = pos_loss * weights
        # neg_loss = neg_loss*weights

    if valid_mask is not None:
@ -96,71 +75,3 @@ def mse_loss(
    else:
        op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
    return op


class DetectionLossConfig(BaseConfig):
    weight: float = 1.0
    focal: FocalLossConfig = Field(default_factory=FocalLossConfig)


class ClassificationLossConfig(BaseConfig):
    weight: float = 2.0
    focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
    class_weights: Optional[list[float]] = None


class LossConfig(BaseConfig):
    detection: DetectionLossConfig = Field(default_factory=DetectionLossConfig)
    size: SizeLossConfig = Field(default_factory=SizeLossConfig)
    classification: ClassificationLossConfig = Field(
        default_factory=ClassificationLossConfig
    )


class Losses(NamedTuple):
    detection: torch.Tensor
    size: torch.Tensor
    classification: torch.Tensor
    total: torch.Tensor


def compute_loss(
    batch: TrainExample,
    outputs: ModelOutput,
    conf: LossConfig,
    class_weights: Optional[torch.Tensor] = None,
) -> Losses:
    detection_loss = focal_loss(
        outputs.detection_probs,
        batch.detection_heatmap,
        beta=conf.detection.focal.beta,
        alpha=conf.detection.focal.alpha,
    )

    size_loss = bbox_size_loss(
        outputs.size_preds,
        batch.size_heatmap,
    )

    valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
    classification_loss = focal_loss(
        outputs.class_probs,
        batch.class_heatmap,
        weights=class_weights,
        valid_mask=valid_mask,
        beta=conf.classification.focal.beta,
        alpha=conf.classification.focal.alpha,
    )

    total = (
        detection_loss * conf.detection.weight
        + size_loss * conf.size.weight
        + classification_loss * conf.classification.weight
    )

    return Losses(
        detection=detection_loss,
        size=size_loss,
        classification=classification_loss,
        total=total,
    )
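With the default configs above, the combined objective is detection * 1.0 + size * 0.1 + classification * 2.0. To see the focal loss in isolation, a toy example; the default beta/alpha from FocalLossConfig are assumed, and the tensors are synthetic:

import torch

from batdetect2.train.losses import focal_loss

# Toy 1x1x8x8 detection heatmap: a single positive at the centre.
gt = torch.zeros(1, 1, 8, 8)
gt[0, 0, 4, 4] = 1.0
pred = torch.full((1, 1, 8, 8), 0.1)
pred[0, 0, 4, 4] = 0.9

print(float(focal_loss(pred, gt)))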
@ -1,28 +1,20 @@
"""Module for preprocessing data for training."""

import os
import warnings
from functools import partial
from multiprocessing import Pool
from pathlib import Path
from typing import Callable, Optional, Sequence, Union
from tqdm.auto import tqdm
from multiprocessing import Pool

import xarray as xr
from pydantic import Field
from soundevent import data
from tqdm.auto import tqdm

from batdetect2.configs import BaseConfig
from batdetect2.preprocess import (
from batdetect2.data.labels import TARGET_SIGMA, ClassMapper, generate_heatmaps
from batdetect2.data.preprocessing import (
    preprocess_audio_clip,
    PreprocessingConfig,
    compute_spectrogram,
    load_clip_audio,
)
from batdetect2.train.labels import LabelConfig, generate_heatmaps
from batdetect2.train.targets import (
    TargetConfig,
    build_encoder,
    build_sound_event_filter,
    get_class_names,
)

PathLike = Union[Path, str, os.PathLike]
@ -30,76 +22,31 @@ FilenameFn = Callable[[data.ClipAnnotation], str]

__all__ = [
    "preprocess_annotations",
    "preprocess_single_annotation",
    "generate_train_example",
    "TrainPreprocessingConfig",
]


class TrainPreprocessingConfig(BaseConfig):
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
    target: TargetConfig = Field(default_factory=TargetConfig)
    labels: LabelConfig = Field(default_factory=LabelConfig)


def generate_train_example(
    clip_annotation: data.ClipAnnotation,
    preprocessing_config: Optional[PreprocessingConfig] = None,
    target_config: Optional[TargetConfig] = None,
    label_config: Optional[LabelConfig] = None,
    class_mapper: ClassMapper,
    preprocessing_config: PreprocessingConfig = PreprocessingConfig(),
    target_sigma: float = TARGET_SIGMA,
) -> xr.Dataset:
    """Generate a training example."""
    config = TrainPreprocessingConfig(
        preprocessing=preprocessing_config or PreprocessingConfig(),
        target=target_config or TargetConfig(),
        labels=label_config or LabelConfig(),
    )

    wave = load_clip_audio(
    spectrogram = preprocess_audio_clip(
        clip_annotation.clip,
        config=config.preprocessing.audio,
        config=preprocessing_config,
    )

    spectrogram = compute_spectrogram(
        wave,
        config=config.preprocessing.spectrogram,
    )

    filter_fn = build_sound_event_filter(
        include=config.target.include,
        exclude=config.target.exclude,
    )

    selected_events = [
        event for event in clip_annotation.sound_events if filter_fn(event)
    ]

    encoder = build_encoder(
        config.target.classes,
        replacement_rules=config.target.replace,
    )
    class_names = get_class_names(config.target.classes)

    detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
        selected_events,
        clip_annotation,
        spectrogram,
        class_names,
        encoder,
        target_sigma=config.labels.heatmaps.sigma,
        position=config.labels.heatmaps.position,
        time_scale=config.labels.heatmaps.time_scale,
        frequency_scale=config.labels.heatmaps.frequency_scale,
        class_mapper,
        target_sigma=target_sigma,
    )

    dataset = xr.Dataset(
        {
            # NOTE: Need to rename the time dimension to avoid conflicts with
            # the spectrogram time dimension, otherwise xarray will interpolate
            # the spectrogram and the heatmaps to the same temporal resolution
            # as the waveform.
            "audio": wave.rename({"time": "audio_time"}),
            "spectrogram": spectrogram,
            "detection": detection_heatmap,
            "class": class_heatmap,
@ -109,12 +56,9 @@ def generate_train_example(

    return dataset.assign_attrs(
        title=f"Training example for {clip_annotation.uuid}",
        config=config.model_dump_json(),
        clip_annotation=clip_annotation.model_dump_json(
            exclude_none=True,
            exclude_defaults=True,
            exclude_unset=True,
        ),
        preprocessing_configuration=preprocessing_config.model_dump_json(),
        target_sigma=target_sigma,
        clip_annotation=clip_annotation.model_dump_json(),
    )


@ -133,54 +77,36 @@ def save_to_file(
    )


def load_config(path: PathLike, **kwargs) -> PreprocessingConfig:
    """Load configuration from file."""

    path = Path(path)

    if not path.is_file():
        warnings.warn(f"Config file not found: {path}. Using default config.")
        return PreprocessingConfig(**kwargs)

    try:
        return PreprocessingConfig.model_validate_json(path.read_text())
    except ValueError as e:
        warnings.warn(
            f"Failed to load config file: {e}. Using default config."
        )
        return PreprocessingConfig(**kwargs)


def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
    return f"{clip_annotation.uuid}.nc"


def preprocess_annotations(
    clip_annotations: Sequence[data.ClipAnnotation],
    output_dir: PathLike,
    filename_fn: FilenameFn = _get_filename,
    replace: bool = False,
    preprocessing_config: Optional[PreprocessingConfig] = None,
    target_config: Optional[TargetConfig] = None,
    label_config: Optional[LabelConfig] = None,
    max_workers: Optional[int] = None,
) -> None:
    """Preprocess annotations and save to disk."""
    output_dir = Path(output_dir)

    if not output_dir.is_dir():
        output_dir.mkdir(parents=True)

    with Pool(max_workers) as pool:
        list(
            tqdm(
                pool.imap_unordered(
                    partial(
                        preprocess_single_annotation,
                        output_dir=output_dir,
                        filename_fn=filename_fn,
                        replace=replace,
                        preprocessing_config=preprocessing_config,
                        target_config=target_config,
                        label_config=label_config,
                    ),
                    clip_annotations,
                ),
                total=len(clip_annotations),
            )
        )


def preprocess_single_annotation(
    clip_annotation: data.ClipAnnotation,
    output_dir: PathLike,
    preprocessing_config: Optional[PreprocessingConfig] = None,
    target_config: Optional[TargetConfig] = None,
    label_config: Optional[LabelConfig] = None,
    config: PreprocessingConfig,
    class_mapper: ClassMapper,
    filename_fn: FilenameFn = _get_filename,
    replace: bool = False,
    target_sigma: float = TARGET_SIGMA,
) -> None:
    output_dir = Path(output_dir)

@ -193,16 +119,50 @@ def preprocess_single_annotation(
    if path.is_file() and replace:
        path.unlink()

    try:
        sample = generate_train_example(
            clip_annotation,
            preprocessing_config=preprocessing_config,
            target_config=target_config,
            label_config=label_config,
            class_mapper,
            preprocessing_config=config,
            target_sigma=target_sigma,
        )
    except Exception as error:
        raise RuntimeError(
            f"Failed to process annotation: {clip_annotation.uuid}"
        ) from error

    save_to_file(sample, path)


def preprocess_annotations(
    clip_annotations: Sequence[data.ClipAnnotation],
    output_dir: PathLike,
    class_mapper: ClassMapper,
    target_sigma: float = TARGET_SIGMA,
    filename_fn: FilenameFn = _get_filename,
    replace: bool = False,
    config: Optional[PreprocessingConfig] = None,
    max_workers: Optional[int] = None,
) -> None:
    """Preprocess annotations and save to disk."""
    output_dir = Path(output_dir)

    if config is None:
        config = PreprocessingConfig()

    if not output_dir.is_dir():
        output_dir.mkdir(parents=True)

    with Pool(max_workers) as pool:
        list(
            tqdm(
                pool.imap_unordered(
                    partial(
                        preprocess_single_annotation,
                        output_dir=output_dir,
                        config=config,
                        class_mapper=class_mapper,
                        filename_fn=filename_fn,
                        replace=replace,
                        target_sigma=target_sigma,
                    ),
                    clip_annotations,
                ),
                total=len(clip_annotations),
            )
        )
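The parallel-preprocessing idiom used above, a Pool feeding imap_unordered through tqdm with list() forcing evaluation, can be illustrated in isolation; the work function here is a stand-in:

from functools import partial
from multiprocessing import Pool

from tqdm.auto import tqdm


def work(item: int, scale: int = 2) -> int:
    # Stand-in for preprocess_single_annotation.
    return item * scale


if __name__ == "__main__":
    items = list(range(100))
    with Pool(4) as pool:
        results = list(
            tqdm(
                pool.imap_unordered(partial(work, scale=3), items),
                total=len(items),
            )
        )
    print(sum(results))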
@ -1,181 +0,0 @@
from collections.abc import Iterable
from functools import partial
from pathlib import Path
from typing import Callable, List, Optional, Set

from pydantic import Field
from soundevent import data

from batdetect2.configs import BaseConfig, load_config
from batdetect2.terms import TagInfo, get_tag_from_info

__all__ = [
    "TargetConfig",
    "load_target_config",
    "build_encoder",
    "build_decoder",
    "filter_sound_event",
]


class ReplaceConfig(BaseConfig):
    """Configuration for replacing tags."""

    original: TagInfo
    replacement: TagInfo


class TargetConfig(BaseConfig):
    """Configuration for target generation."""

    classes: List[TagInfo] = Field(
        default_factory=lambda: [
            TagInfo(key="class", value=value) for value in DEFAULT_SPECIES_LIST
        ]
    )
    generic_class: Optional[TagInfo] = Field(
        default_factory=lambda: TagInfo(key="class", value="Bat")
    )

    include: Optional[List[TagInfo]] = Field(
        default_factory=lambda: [TagInfo(key="event", value="Echolocation")]
    )

    exclude: Optional[List[TagInfo]] = Field(
        default_factory=lambda: [
            TagInfo(key="class", value=""),
            TagInfo(key="class", value=" "),
            TagInfo(key="class", value="Unknown"),
        ]
    )

    replace: Optional[List[ReplaceConfig]] = None


def build_sound_event_filter(
    include: Optional[List[TagInfo]] = None,
    exclude: Optional[List[TagInfo]] = None,
) -> Callable[[data.SoundEventAnnotation], bool]:
    include_tags = (
        {get_tag_from_info(tag) for tag in include} if include else None
    )
    exclude_tags = (
        {get_tag_from_info(tag) for tag in exclude} if exclude else None
    )
    return partial(
        filter_sound_event,
        include=include_tags,
        exclude=exclude_tags,
    )


def get_tag_label(tag_info: TagInfo) -> str:
    return tag_info.label if tag_info.label else tag_info.value


def get_class_names(classes: List[TagInfo]) -> List[str]:
    return sorted({get_tag_label(tag) for tag in classes})


def build_replacer(
    rules: List[ReplaceConfig],
) -> Callable[[data.Tag], data.Tag]:
    mapping = {
        get_tag_from_info(rule.original): get_tag_from_info(rule.replacement)
        for rule in rules
    }

    def replacer(tag: data.Tag) -> data.Tag:
        return mapping.get(tag, tag)

    return replacer


def build_encoder(
    classes: List[TagInfo],
    replacement_rules: Optional[List[ReplaceConfig]] = None,
) -> Callable[[Iterable[data.Tag]], Optional[str]]:
    target_tags = set([get_tag_from_info(tag) for tag in classes])

    tag_mapping = {
        tag: get_tag_label(tag_info)
        for tag, tag_info in zip(target_tags, classes)
    }

    replacer = (
        build_replacer(replacement_rules) if replacement_rules else lambda x: x
    )

    def encoder(
        tags: Iterable[data.Tag],
    ) -> Optional[str]:
        sanitized_tags = {replacer(tag) for tag in tags}

        intersection = sanitized_tags & target_tags

        if not intersection:
            return None

        first = intersection.pop()
        return tag_mapping[first]

    return encoder


def build_decoder(
    classes: List[TagInfo],
) -> Callable[[str], List[data.Tag]]:
    target_tags = set([get_tag_from_info(tag) for tag in classes])
    tag_mapping = {
        get_tag_label(tag_info): tag
        for tag, tag_info in zip(target_tags, classes)
    }

    def decoder(label: str) -> List[data.Tag]:
        tag = tag_mapping.get(label)
        return [tag] if tag else []

    return decoder


def filter_sound_event(
    sound_event_annotation: data.SoundEventAnnotation,
    include: Optional[Set[data.Tag]] = None,
    exclude: Optional[Set[data.Tag]] = None,
) -> bool:
    tags = set(sound_event_annotation.tags)

    if include is not None and not tags & include:
        return False

    if exclude is not None and tags & exclude:
        return False

    return True


def load_target_config(
    path: Path, field: Optional[str] = None
) -> TargetConfig:
    return load_config(path, schema=TargetConfig, field=field)


DEFAULT_SPECIES_LIST = [
    "Barbastellus barbastellus",
    "Eptesicus serotinus",
    "Myotis alcathoe",
    "Myotis bechsteinii",
    "Myotis brandtii",
    "Myotis daubentonii",
    "Myotis mystacinus",
    "Myotis nattereri",
    "Nyctalus leisleri",
    "Nyctalus noctula",
    "Pipistrellus nathusii",
    "Pipistrellus pipistrellus",
    "Pipistrellus pygmaeus",
    "Plecotus auritus",
    "Plecotus austriacus",
    "Rhinolophus ferrumequinum",
    "Rhinolophus hipposideros",
]
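A hedged round-trip sketch of the encoder/decoder pair defined in this file; it assumes soundevent's data.Tag still accepts key/value directly, and note that because target_tags is a set, the zip-based tag-to-label mapping is order-sensitive:

from soundevent import data

from batdetect2.terms import TagInfo
from batdetect2.train.targets import build_decoder, build_encoder

classes = [
    TagInfo(key="class", value="Myotis daubentonii"),
    TagInfo(key="class", value="Pipistrellus pipistrellus"),
]
encoder = build_encoder(classes)
decoder = build_decoder(classes)

tags = [data.Tag(key="class", value="Myotis daubentonii")]
label = encoder(tags)  # a class label, or None if no tag matches
print(label, decoder(label) if label is not None else [])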
@ -1,68 +1,82 @@
from typing import Optional, Union
from typing import Callable, NamedTuple, Optional

from lightning import LightningModule
from lightning.pytorch import Trainer
from soundevent.data import PathLike
import torch
from soundevent import data
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from batdetect2.configs import BaseConfig, load_config
from batdetect2.train.dataset import LabeledDataset

__all__ = [
    "train",
    "TrainerConfig",
    "load_trainer_config",
]
from batdetect2.data.datasets import ClipAnnotationDataset
from batdetect2.models.typing import DetectionModel


class TrainerConfig(BaseConfig):
    accelerator: str = "auto"
    accumulate_grad_batches: int = 1
    deterministic: bool = True
    check_val_every_n_epoch: int = 1
    devices: Union[str, int] = "auto"
    enable_checkpointing: bool = True
    gradient_clip_val: Optional[float] = None
    limit_train_batches: Optional[Union[int, float]] = None
    limit_test_batches: Optional[Union[int, float]] = None
    limit_val_batches: Optional[Union[int, float]] = None
    log_every_n_steps: Optional[int] = None
    max_epochs: Optional[int] = None
    min_epochs: Optional[int] = 100
    max_steps: Optional[int] = None
    min_steps: Optional[int] = None
    max_time: Optional[str] = None
    precision: Optional[str] = None
    reload_dataloaders_every_n_epochs: Optional[int] = None
    val_check_interval: Optional[Union[int, float]] = None
class TrainInputs(NamedTuple):
    spec: torch.Tensor
    detection_heatmap: torch.Tensor
    class_heatmap: torch.Tensor
    size_heatmap: torch.Tensor


def load_trainer_config(path: PathLike, field: Optional[str] = None):
    return load_config(path, schema=TrainerConfig, field=field)


def train(
    module: LightningModule,
    train_dataset: LabeledDataset,
    trainer_config: Optional[TrainerConfig] = None,
    dev_run: bool = False,
    overfit_batches: bool = False,
    profiler: Optional[str] = None,
def train_loop(
    model: DetectionModel,
    train_dataset: ClipAnnotationDataset[TrainInputs],
    validation_dataset: ClipAnnotationDataset[TrainInputs],
    device: Optional[torch.device] = None,
    num_epochs: int = 100,
    learning_rate: float = 1e-4,
):
    trainer_config = trainer_config or TrainerConfig()
    trainer = Trainer(
        **trainer_config.model_dump(
            exclude_unset=True,
            exclude_none=True,
        ),
        fast_dev_run=dev_run,
        overfit_batches=overfit_batches,
        profiler=profiler,
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    validation_loader = DataLoader(validation_dataset, batch_size=32)

    model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = CosineAnnealingLR(
        optimizer,
        num_epochs * len(train_loader),
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=module.config.train.batch_size,
        shuffle=True,
        num_workers=7,

    for epoch in range(num_epochs):
        train_loss = train_single_epoch(
            model,
            train_loader,
            optimizer,
            device,
            scheduler,
        )
    trainer.fit(module, train_dataloaders=train_loader)


def train_single_epoch(
    model: DetectionModel,
    train_loader: DataLoader,
    optimizer: Adam,
    device: torch.device,
    scheduler: CosineAnnealingLR,
):
    model.train()
    train_loss = tu.AverageMeter()

    for batch in train_loader:
        optimizer.zero_grad()

        spec = batch.spec.to(device)
        detection_heatmap = batch.detection_heatmap.to(device)
        class_heatmap = batch.class_heatmap.to(device)
        size_heatmap = batch.size_heatmap.to(device)

        outputs = model(spec)

        loss = loss_fun(
            outputs,
            gt_det,
            gt_size,
            gt_class,
            det_criterion,
            params,
            class_inv_freq,
        )

        train_loss.update(loss.item(), data.shape[0])
        loss.backward()
        optimizer.step()
        scheduler.step()
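A hedged sketch of the config-to-Trainer handoff shown above; the module path is an assumption, and only fields the user actually set are forwarded to Lightning:

from lightning.pytorch import Trainer

from batdetect2.train.train import TrainerConfig  # module path assumed

config = TrainerConfig(max_epochs=200)
trainer = Trainer(
    **config.model_dump(exclude_unset=True, exclude_none=True)
)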
@ -16,6 +16,7 @@ def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):


def split_diff(ann_dir, wav_dir, load_extra=True):

    train_sets = []
    if load_extra:
        train_sets.append(
@ -143,6 +144,7 @@ def split_diff(ann_dir, wav_dir, load_extra=True):


def split_same(ann_dir, wav_dir, load_extra=True):

    train_sets = []
    if load_extra:
        train_sets.append(
@ -69,8 +69,7 @@ def get_genus_mapping(class_names: List[str]) -> Tuple[List[str], List[int]]:


def standardize_low_freq(
    data: List[types.FileAnnotation],
    class_of_interest: str,
    data: List[types.FileAnnotation], class_of_interest: str,
) -> List[types.FileAnnotation]:
    # address the issue of highly variable low frequency annotations
    # this often happens for constant frequency calls
@ -1,12 +1,13 @@
"""Types used in the code base."""

from typing import Any, List, NamedTuple, Optional


import numpy as np
import torch

from typing import TypedDict
try:
    from typing import TypedDict
except ImportError:
    from typing_extensions import TypedDict


try:
@ -14,8 +15,9 @@ try:
except ImportError:
    from typing_extensions import Protocol


try:
    from typing import NotRequired  # type: ignore
    from typing import NotRequired
except ImportError:
    from typing_extensions import NotRequired

@ -595,7 +597,8 @@ class FeatureExtractor(Protocol):
        self,
        prediction: Prediction,
        **kwargs: Any,
    ) -> float: ...
    ) -> float:
        ...


class DatasetDict(TypedDict):
@ -6,8 +6,6 @@ import librosa.core.spectrum
import numpy as np
import torch

from batdetect2.detector import parameters

from . import wavfile

__all__ = [
@ -17,44 +15,20 @@ __all__ = [
]


def time_to_x_coords(
    time_in_file: float,
    samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
    window_duration: float = parameters.FFT_WIN_LENGTH_S,
    window_overlap: float = parameters.FFT_OVERLAP,
) -> float:
    nfft = np.floor(window_duration * samplerate)  # int() uses floor
    noverlap = np.floor(window_overlap * nfft)
    return (time_in_file * samplerate - noverlap) / (nfft - noverlap)
def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
    nfft = np.floor(fft_win_length * sampling_rate)  # int() uses floor
    noverlap = np.floor(fft_overlap * nfft)
    return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)


def x_coords_to_time(
    x_pos: int,
    samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
    window_duration: float = parameters.FFT_WIN_LENGTH_S,
    window_overlap: float = parameters.FFT_OVERLAP,
) -> float:
    n_fft = np.floor(window_duration * samplerate)
    n_overlap = np.floor(window_overlap * n_fft)
    n_step = n_fft - n_overlap
    return ((x_pos * n_step) + n_overlap) / samplerate
# NOTE this is also defined in post_process
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
    nfft = np.floor(fft_win_length * sampling_rate)
    noverlap = np.floor(fft_overlap * nfft)
    return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
    # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5)  # 0.5 is for center of temporal window


def x_coord_to_sample(
    x_pos: int,
    samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
    window_duration: float = parameters.FFT_WIN_LENGTH_S,
    window_overlap: float = parameters.FFT_OVERLAP,
    resize_factor: float = parameters.RESIZE_FACTOR,
) -> int:
    n_fft = np.floor(window_duration * samplerate)
    n_overlap = np.floor(window_overlap * n_fft)
    n_step = n_fft - n_overlap
    x_pos = int(x_pos / resize_factor)
    return int((x_pos * n_step) + n_overlap)
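The typed rewrites keep the same arithmetic as the positional originals, only renaming parameters and pulling defaults from `parameters`. A standalone round-trip check; the values of 256 kHz, a 2 ms window, and 75% overlap mirror the package defaults but are hard-coded here as assumptions so the snippet runs without batdetect2:

    import numpy as np

    def time_to_x(t, samplerate=256_000, window_duration=0.002, window_overlap=0.75):
        nfft = np.floor(window_duration * samplerate)  # 512 samples per window
        noverlap = np.floor(window_overlap * nfft)     # 384 shared samples
        return (t * samplerate - noverlap) / (nfft - noverlap)

    def x_to_time(x, samplerate=256_000, window_duration=0.002, window_overlap=0.75):
        nfft = np.floor(window_duration * samplerate)
        noverlap = np.floor(window_overlap * nfft)
        return (x * (nfft - noverlap) + noverlap) / samplerate

    x = time_to_x(0.5)          # 997.0 spectrogram columns at t = 0.5 s
    assert x_to_time(x) == 0.5  # the two conversions invert each other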
def generate_spectrogram(
    audio,
    sampling_rate,
@ -90,7 +64,7 @@ def generate_spectrogram(
                np.abs(
                    np.hanning(
                        int(params["fft_win_length"] * sampling_rate)
                    ).astype(np.float32)
                )
            )
            ** 2
        ).sum()
@ -100,7 +74,7 @@ def generate_spectrogram(
        # log_scaling = (1.0 / sampling_rate)*10e4
        spec = np.log1p(log_scaling * spec)
    elif params["spec_scale"] == "pcen":
        spec = pcen(spec, sampling_rate)
        spec = pcen(spec , sampling_rate)

    elif params["spec_scale"] == "none":
        pass
@ -220,118 +194,55 @@ def load_audio(
    return sampling_rate, audio_raw
def compute_spectrogram_width(
    length: int,
    samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
    window_duration: float = parameters.FFT_WIN_LENGTH_S,
    window_overlap: float = parameters.FFT_OVERLAP,
    resize_factor: float = parameters.RESIZE_FACTOR,
) -> int:
    n_fft = int(window_duration * samplerate)
    n_overlap = int(window_overlap * n_fft)
    n_step = n_fft - n_overlap
    width = (length - n_overlap) // n_step
    return int(width * resize_factor)
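A worked example of the `compute_spectrogram_width` formula, under the same assumed defaults (256 kHz sample rate, 512-sample window, 75% overlap, resize factor 0.5):

    n_fft = int(0.002 * 256_000)             # 512-sample analysis window
    n_overlap = int(0.75 * n_fft)            # 384 samples shared between windows
    n_step = n_fft - n_overlap               # 128-sample hop
    width = (256_000 - n_overlap) // n_step  # 1997 STFT frames in one second
    print(int(width * 0.5))                  # 998 columns after resizing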
def pad_audio(
    audio: np.ndarray,
    samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
    window_duration: float = parameters.FFT_WIN_LENGTH_S,
    window_overlap: float = parameters.FFT_OVERLAP,
    resize_factor: float = parameters.RESIZE_FACTOR,
    divide_factor: int = parameters.SPEC_DIVIDE_FACTOR,
    fixed_width: Optional[int] = None,
    audio_raw,
    fs,
    ms,
    overlap_perc,
    resize_factor,
    divide_factor,
    fixed_width=None,
):
    """Pad audio to be evenly divisible by `divide_factor`.
    # Adds zeros to the end of the raw data so that the generated spectrogram
    # will be evenly divisible by `divide_factor`
    # Also deals with very short audio clips and fixed_width during training

    This function pads the audio signal with zeros to ensure that the
    generated spectrogram length will be evenly divisible by `divide_factor`.
    This is important for the model to work correctly.
    # This code could be clearer, clean up
    nfft = int(ms * fs)
    noverlap = int(overlap_perc * nfft)
    step = nfft - noverlap
    min_size = int(divide_factor * (1.0 / resize_factor))
    spec_width = (audio_raw.shape[0] - noverlap) // step
    spec_width_rs = spec_width * resize_factor

    This `divide_factor` comes from the model architecture as it downscales
    the spectrogram by this factor, so the input must be divisible by this
    integer number.

    Parameters
    ----------
    audio : np.ndarray
        The audio signal.
    samplerate : int
        The sampling rate of the audio signal.
    window_duration : float
        The window size in seconds used for the spectrogram computation.
    window_overlap : float
        The overlap between windows in the spectrogram computation.
    resize_factor : float
        This factor is used to resize the spectrogram after the STFT
        computation. Default is 0.5 which means that the spectrogram will be
        reduced by half. Important to take into account for the final size of
        the spectrogram.
    divide_factor : int
        The factor by which the spectrogram will be divided.
    fixed_width : int, optional
        If provided, the audio will be padded or cut so that the resulting
        spectrogram width will be equal to this value.

    Returns
    -------
    np.ndarray
        The padded audio signal.
    """
    spec_width = compute_spectrogram_width(
        audio.shape[0],
        samplerate=samplerate,
        window_duration=window_duration,
        window_overlap=window_overlap,
        resize_factor=resize_factor,
    if fixed_width is not None and spec_width < fixed_width:
        # too small
        # used during training to ensure all the batches are the same size
        diff = fixed_width * step + noverlap - audio_raw.shape[0]
        audio_raw = np.hstack(
            (audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
        )

    if fixed_width:
        target_samples = x_coord_to_sample(
            fixed_width,
            samplerate=samplerate,
            window_duration=window_duration,
            window_overlap=window_overlap,
            resize_factor=resize_factor,
        )
    elif fixed_width is not None and spec_width > fixed_width:
        # too big
        # used during training to ensure all the batches are the same size
        diff = fixed_width * step + noverlap - audio_raw.shape[0]
        audio_raw = audio_raw[:diff]

        if spec_width < fixed_width:
    elif (
        spec_width_rs < min_size
        or (np.floor(spec_width_rs) % divide_factor) != 0
    ):
        # need to be at least min_size
            diff = target_samples - audio.shape[0]
            return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))

        if spec_width > fixed_width:
            return audio[:target_samples]

        return audio

    min_width = int(divide_factor / resize_factor)

    if spec_width < min_width:
        target_samples = x_coord_to_sample(
            min_width,
            samplerate=samplerate,
            window_duration=window_duration,
            window_overlap=window_overlap,
            resize_factor=resize_factor,
        div_amt = np.ceil(spec_width_rs / float(divide_factor))
        div_amt = np.maximum(1, div_amt)
        target_size = int(div_amt * divide_factor * (1.0 / resize_factor))
        diff = target_size * step + noverlap - audio_raw.shape[0]
        audio_raw = np.hstack(
            (audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
        )
        diff = target_samples - audio.shape[0]
        return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))

    if (spec_width % divide_factor) == 0:
        return audio

    target_width = int(np.ceil(spec_width / divide_factor)) * divide_factor
    target_samples = x_coord_to_sample(
        target_width,
        samplerate=samplerate,
        window_duration=window_duration,
        window_overlap=window_overlap,
        resize_factor=resize_factor,
    )
    diff = target_samples - audio.shape[0]
    return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
    return audio_raw


def gen_mag_spectrogram(x, fs, ms, overlap_perc):
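Both versions of `pad_audio` enforce the same core rule: append zeros until the spectrogram width is a multiple of the model's downscaling factor. A minimal sketch of that rule, ignoring `resize_factor` and `fixed_width` for brevity; the divide factor of 32 is an assumption standing in for SPEC_DIVIDE_FACTOR:

    import numpy as np

    def pad_to_multiple(audio, n_fft=512, n_overlap=384, divide_factor=32):
        n_step = n_fft - n_overlap
        spec_width = (audio.shape[0] - n_overlap) // n_step
        if spec_width % divide_factor == 0:
            return audio
        target_width = int(np.ceil(spec_width / divide_factor)) * divide_factor
        target_samples = target_width * n_step + n_overlap
        pad = target_samples - audio.shape[0]
        return np.hstack((audio, np.zeros(pad, dtype=audio.dtype)))

    padded = pad_to_multiple(np.zeros(10_000, dtype=np.float32))
    assert ((padded.shape[0] - 384) // 128) % 32 == 0  # width is now 96 = 3 * 32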
@ -8,11 +8,6 @@ import pandas as pd
import torch
import torch.nn.functional as F

try:
    from numpy.exceptions import AxisError
except ImportError:
    from numpy import AxisError  # type: ignore

import batdetect2.detector.compute_features as feats
import batdetect2.detector.post_process as pp
import batdetect2.utils.audio_utils as au
@ -85,7 +80,6 @@ def load_model(
    model_path: str = DEFAULT_MODEL_PATH,
    load_weights: bool = True,
    device: Union[torch.device, str, None] = None,
    weights_only: bool = True,
) -> Tuple[DetectionModel, ModelParameters]:
    """Load model from file.

@ -106,11 +100,7 @@ def load_model(
    if not os.path.isfile(model_path):
        raise FileNotFoundError("Model file not found.")

    net_params = torch.load(
        model_path,
        map_location=device,
        weights_only=weights_only,
    )
    net_params = torch.load(model_path, map_location=device)

    params = net_params["params"]

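For context on the `weights_only` flag this hunk adds and removes: recent PyTorch releases restrict `torch.load` to a safe subset of types when `weights_only=True` (and newer versions make that the default), rejecting pickled arbitrary objects. A sketch with a throwaway checkpoint; the field names are illustrative, not the repository's checkpoint schema:

    import torch

    checkpoint = {"params": {"num_filters": 32}, "state_dict": {}}
    torch.save(checkpoint, "/tmp/example_ckpt.pth.tar")

    # Plain dicts, tensors, and primitives survive the weights_only restriction;
    # checkpoints that pickle custom classes need weights_only=False (trusted files only).
    net_params = torch.load("/tmp/example_ckpt.pth.tar", map_location="cpu", weights_only=True)
    assert net_params["params"]["num_filters"] == 32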
@ -252,7 +242,7 @@ def format_single_result(
        )
        class_name = class_names[np.argmax(class_overall)]
        annotations = get_annotations_from_preds(predictions, class_names)
    except (AxisError, ValueError):
    except (np.AxisError, ValueError):
        # No detections
        class_overall = np.zeros(len(class_names))
        class_name = "None"
@ -409,7 +399,7 @@ def save_results_to_file(results, op_path: str) -> None:

def compute_spectrogram(
    audio: np.ndarray,
    sampling_rate: int,
    sampling_rate: float,
    params: SpectrogramParameters,
    device: torch.device,
) -> Tuple[float, torch.Tensor]:
@ -627,7 +617,7 @@ def process_spectrogram(

def _process_audio_array(
    audio: np.ndarray,
    sampling_rate: int,
    sampling_rate: float,
    model: DetectionModel,
    config: ProcessingConfiguration,
    device: torch.device,
@ -748,7 +738,9 @@ def process_file(

    # Get original sampling rate
    file_samp_rate = librosa.get_samplerate(audio_file)
    orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
    orig_samp_rate = file_samp_rate * float(
        config.get("time_expansion", 1.0) or 1.0
    )

    # load audio file
    sampling_rate, audio_full = au.load_audio(
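The rewritten `time_expansion` expression covers two edge cases plus a type concern: `config.get(...)` may return `None` or `0` for an unset expansion factor, and `float()` normalizes integer inputs so the resulting sample rate stays a float. Illustrative values:

    for raw in (None, 0, 1, 10):
        factor = float(raw or 1.0)
        print(repr(raw), "->", factor)  # None -> 1.0, 0 -> 1.0, 1 -> 1.0, 10 -> 10.0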
@ -417,9 +417,7 @@ def plot_confusion_matrix(
    cm_norm = cm.sum(1)

    valid_inds = np.where(cm_norm > 0)[0]
    cm[valid_inds, :] = (
        cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
    )
    cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
    cm[np.where(cm_norm == -0)[0], :] = np.nan

    if verbose:
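The change above is a pure reflow; the surrounding logic row-normalizes the confusion matrix while leaving classes with no samples as NaN rather than dividing by zero. A compact reproduction of the pattern:

    import numpy as np

    cm = np.array([[8.0, 2.0], [0.0, 0.0]])    # second class has no samples
    row_sums = cm.sum(1)
    valid = np.where(row_sums > 0)[0]
    cm[valid, :] = cm[valid, :] / row_sums[valid][..., np.newaxis]
    cm[np.where(row_sums == 0)[0], :] = np.nan  # empty rows become NaN, not 0/0
    print(cm)                                   # [[0.8 0.2] [nan nan]]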
@ -155,9 +155,9 @@ class InteractivePlotter:

        # draw bounding box around call
        self.ax[1].patches[0].remove()
        spec_width_orig = self.spec_slices[self.current_id].shape[
            1
        ] / (1.0 + 2.0 * self.spec_pad)
        spec_width_orig = self.spec_slices[self.current_id].shape[1] / (
            1.0 + 2.0 * self.spec_pad
        )
        xx = w_diff + self.spec_pad * spec_width_orig
        ww = spec_width_orig
        yy = self.call_info[self.current_id]["low_freq"] / 1000
@ -183,9 +183,7 @@ class InteractivePlotter:
            round(self.call_info[self.current_id]["start_time"], 3)
        )
        + ", prob="
        + str(
            round(self.call_info[self.current_id]["det_prob"], 3)
        )
        + str(round(self.call_info[self.current_id]["det_prob"], 3))
        )
        self.ax[0].set_xlabel(info_str)

@ -8,7 +8,6 @@ Functions
`write`: Write a numpy array as a WAV file.

"""

from __future__ import absolute_import, division, print_function

import os
@ -157,6 +156,7 @@ def read(filename, mmap=False):
    fid = open(filename, "rb")

    try:

        # some files seem to have the size recorded in the header greater than
        # the actual file size.
        fid.seek(0, os.SEEK_END)
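The comment added in this hunk concerns WAV files whose RIFF header claims more bytes than the file actually holds; the reader defends itself by measuring the real size before parsing chunks. A sketch of the seek-to-end idiom on an in-memory stand-in:

    import io
    import os

    buf = io.BytesIO(b"RIFF" + b"\x00" * 60)  # stand-in for an open WAV file
    buf.seek(0, os.SEEK_END)
    actual_size = buf.tell()                  # true byte count, whatever the header says
    buf.seek(0)                               # rewind before parsing chunks
    print(actual_size)                        # 64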
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -6,11 +6,11 @@
   "id": "cfb0b360-a204-4c27-a18f-3902e8758879",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-19T17:33:02.699871Z",
     "iopub.status.busy": "2024-11-19T17:33:02.699590Z",
     "iopub.status.idle": "2024-11-19T17:33:02.710312Z",
     "shell.execute_reply": "2024-11-19T17:33:02.709798Z",
     "shell.execute_reply.started": "2024-11-19T17:33:02.699839Z"
     "iopub.execute_input": "2024-07-16T00:27:20.598611Z",
     "iopub.status.busy": "2024-07-16T00:27:20.596274Z",
     "iopub.status.idle": "2024-07-16T00:27:20.670888Z",
     "shell.execute_reply": "2024-07-16T00:27:20.668193Z",
     "shell.execute_reply.started": "2024-07-16T00:27:20.598423Z"
    }
   },
   "outputs": [],
@ -25,11 +25,11 @@
   "id": "326c5432-94e6-4abf-a332-fe902559461b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-19T17:33:02.711324Z",
     "iopub.status.busy": "2024-11-19T17:33:02.711067Z",
     "iopub.status.idle": "2024-11-19T17:33:09.092380Z",
     "shell.execute_reply": "2024-11-19T17:33:09.091830Z",
     "shell.execute_reply.started": "2024-11-19T17:33:02.711304Z"
     "iopub.execute_input": "2024-07-16T00:27:20.676278Z",
     "iopub.status.busy": "2024-07-16T00:27:20.675545Z",
     "iopub.status.idle": "2024-07-16T00:27:25.872556Z",
     "shell.execute_reply": "2024-07-16T00:27:25.871725Z",
     "shell.execute_reply.started": "2024-07-16T00:27:20.676206Z"
    }
   },
   "outputs": [
@ -37,7 +37,7 @@
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/santiago/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "/home/santiago/Software/bat_detectors/batdetect2/.venv/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
@ -45,35 +45,26 @@
   "source": [
    "from pathlib import Path\n",
    "from typing import List, Optional\n",
    "import torch\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "from batdetect2.train.modules import DetectorModel\n",
    "from soundevent import data\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from batdetect2.data.labels import ClassMapper\n",
    "from batdetect2.models.detectors import DetectorModel\n",
    "from batdetect2.train.augmentations import (\n",
    "    add_echo,\n",
    "    select_random_subclip,\n",
    "    warp_spectrogram,\n",
    ")\n",
    "from batdetect2.train.dataset import LabeledDataset, get_files\n",
    "from batdetect2.train.preprocess import PreprocessingConfig\n",
    "from soundevent import data\n",
    "import matplotlib.pyplot as plt\n",
    "from soundevent.types import ClassMapper\n",
    "from torch.utils.data import DataLoader"
    "from batdetect2.train.preprocess import PreprocessingConfig"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9402a473-0b25-4123-9fa8-ad1f71a4237a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-18T22:39:12.395329Z",
     "iopub.status.busy": "2024-11-18T22:39:12.393444Z",
     "iopub.status.idle": "2024-11-18T22:39:12.405938Z",
     "shell.execute_reply": "2024-11-18T22:39:12.402980Z",
     "shell.execute_reply.started": "2024-11-18T22:39:12.395236Z"
    }
   },
   "id": "fa202af2-5c0d-4b5d-91a3-097ef5cd4272",
   "metadata": {},
   "source": [
    "## Training Datasets"
   ]
@ -84,11 +75,11 @@
   "id": "cfd97d83-8c2b-46c8-9eae-cea59f53bc61",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-19T17:33:09.093487Z",
     "iopub.status.busy": "2024-11-19T17:33:09.092990Z",
     "iopub.status.idle": "2024-11-19T17:33:09.121636Z",
     "shell.execute_reply": "2024-11-19T17:33:09.121143Z",
     "shell.execute_reply.started": "2024-11-19T17:33:09.093459Z"
     "iopub.execute_input": "2024-07-16T00:27:25.874255Z",
     "iopub.status.busy": "2024-07-16T00:27:25.873473Z",
     "iopub.status.idle": "2024-07-16T00:27:25.912952Z",
     "shell.execute_reply": "2024-07-16T00:27:25.911844Z",
     "shell.execute_reply.started": "2024-07-16T00:27:25.874206Z"
    }
   },
   "outputs": [],
@ -102,11 +93,11 @@
   "id": "d5131ae9-2efd-4758-b6e5-189a6d90789b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-19T17:33:09.122685Z",
     "iopub.status.busy": "2024-11-19T17:33:09.122270Z",
     "iopub.status.idle": "2024-11-19T17:33:09.151386Z",
     "shell.execute_reply": "2024-11-19T17:33:09.150788Z",
     "shell.execute_reply.started": "2024-11-19T17:33:09.122661Z"
     "iopub.execute_input": "2024-07-16T00:27:25.914456Z",
     "iopub.status.busy": "2024-07-16T00:27:25.914027Z",
     "iopub.status.idle": "2024-07-16T00:27:25.954939Z",
     "shell.execute_reply": "2024-07-16T00:27:25.953906Z",
     "shell.execute_reply.started": "2024-07-16T00:27:25.914410Z"
    }
   },
   "outputs": [],
@ -120,11 +111,11 @@
   "id": "bc733d3d-7829-4e90-896d-a0dc76b33288",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-19T17:33:09.152327Z",
     "iopub.status.busy": "2024-11-19T17:33:09.152060Z",
     "iopub.status.idle": "2024-11-19T17:33:09.184041Z",
     "shell.execute_reply": "2024-11-19T17:33:09.183372Z",
     "shell.execute_reply.started": "2024-11-19T17:33:09.152305Z"
     "iopub.execute_input": "2024-07-16T00:27:25.956758Z",
     "iopub.status.busy": "2024-07-16T00:27:25.956260Z",
     "iopub.status.idle": "2024-07-16T00:27:25.997664Z",
     "shell.execute_reply": "2024-07-16T00:27:25.996074Z",
     "shell.execute_reply.started": "2024-07-16T00:27:25.956705Z"
    }
   },
   "outputs": [],
@ -138,11 +129,11 @@
   "id": "dfbb94ab-7b12-4689-9c15-4dc34cd17cb2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-19T17:33:09.186393Z",
     "iopub.status.busy": "2024-11-19T17:33:09.186117Z",
     "iopub.status.idle": "2024-11-19T17:33:09.220175Z",
     "shell.execute_reply": "2024-11-19T17:33:09.219322Z",
     "shell.execute_reply.started": "2024-11-19T17:33:09.186375Z"
     "iopub.execute_input": "2024-07-16T00:27:26.003195Z",
     "iopub.status.busy": "2024-07-16T00:27:26.002783Z",
     "iopub.status.idle": "2024-07-16T00:27:26.054400Z",
     "shell.execute_reply": "2024-07-16T00:27:26.053294Z",
     "shell.execute_reply.started": "2024-07-16T00:27:26.003158Z"
    }
   },
   "outputs": [],
@ -161,11 +152,11 @@
   "id": "e2eedaa9-6be3-481a-8786-7618515d98f8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-19T17:33:09.221653Z",
     "iopub.status.busy": "2024-11-19T17:33:09.221242Z",
     "iopub.status.idle": "2024-11-19T17:33:09.260977Z",
     "shell.execute_reply": "2024-11-19T17:33:09.260375Z",
     "shell.execute_reply.started": "2024-11-19T17:33:09.221616Z"
     "iopub.execute_input": "2024-07-16T00:27:26.056060Z",
     "iopub.status.busy": "2024-07-16T00:27:26.055706Z",
     "iopub.status.idle": "2024-07-16T00:27:26.103227Z",
     "shell.execute_reply": "2024-07-16T00:27:26.102190Z",
     "shell.execute_reply.started": "2024-07-16T00:27:26.056025Z"
    }
   },
   "outputs": [],
@ -177,6 +168,7 @@
    "        \"Myotis mystacinus\",\n",
    "        \"Pipistrellus pipistrellus\",\n",
    "        \"Rhinolophus ferrumequinum\",\n",
    "        \"social\",\n",
    "    ]\n",
    "\n",
    "    def encode(self, x: data.SoundEventAnnotation) -> Optional[str]:\n",
@ -205,11 +197,11 @@
   "id": "1ff6072c-511e-42fe-a74f-282f269b80f0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-19T17:33:09.262337Z",
     "iopub.status.busy": "2024-11-19T17:33:09.261775Z",
     "iopub.status.idle": "2024-11-19T17:33:09.309793Z",
     "shell.execute_reply": "2024-11-19T17:33:09.309216Z",
     "shell.execute_reply.started": "2024-11-19T17:33:09.262307Z"
     "iopub.execute_input": "2024-07-16T00:27:26.104877Z",
     "iopub.status.busy": "2024-07-16T00:27:26.104538Z",
     "iopub.status.idle": "2024-07-16T00:27:26.159676Z",
     "shell.execute_reply": "2024-07-16T00:27:26.157914Z",
     "shell.execute_reply.started": "2024-07-16T00:27:26.104843Z"
    }
   },
   "outputs": [],
@ -223,11 +215,11 @@
   "id": "3a763ee6-15bc-4105-a409-f06e0ad21a06",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-19T17:33:09.310695Z",
     "iopub.status.busy": "2024-11-19T17:33:09.310438Z",
     "iopub.status.idle": "2024-11-19T17:33:09.366636Z",
     "shell.execute_reply": "2024-11-19T17:33:09.366059Z",
     "shell.execute_reply.started": "2024-11-19T17:33:09.310669Z"
     "iopub.execute_input": "2024-07-16T00:27:26.162346Z",
     "iopub.status.busy": "2024-07-16T00:27:26.161885Z",
     "iopub.status.idle": "2024-07-16T00:27:26.374668Z",
     "shell.execute_reply": "2024-07-16T00:27:26.373691Z",
     "shell.execute_reply.started": "2024-07-16T00:27:26.162305Z"
    }
   },
   "outputs": [
@ -237,6 +229,7 @@
     "text": [
      "GPU available: False, used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
@ -255,11 +248,11 @@
   "id": "0b86d49d-3314-4257-94f5-f964855be385",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-19T17:33:09.367499Z",
     "iopub.status.busy": "2024-11-19T17:33:09.367242Z",
     "iopub.status.idle": "2024-11-19T17:33:10.811300Z",
     "shell.execute_reply": "2024-11-19T17:33:10.809823Z",
     "shell.execute_reply.started": "2024-11-19T17:33:09.367473Z"
     "iopub.execute_input": "2024-07-16T00:27:26.375918Z",
     "iopub.status.busy": "2024-07-16T00:27:26.375632Z",
     "iopub.status.idle": "2024-07-16T00:27:28.829650Z",
     "shell.execute_reply": "2024-07-16T00:27:28.828219Z",
     "shell.execute_reply.started": "2024-07-16T00:27:26.375889Z"
    }
   },
   "outputs": [
@ -268,67 +261,37 @@
     "output_type": "stream",
     "text": [
      "\n",
      "  | Name              | Type      | Params | Mode \n",
      "--------------------------------------------------------\n",
      "0 | feature_extractor | Net2DFast | 119 K  | train\n",
      "1 | classifier        | Conv2d    | 54     | train\n",
      "2 | bbox              | Conv2d    | 18     | train\n",
      "--------------------------------------------------------\n",
      "  | Name              | Type      | Params\n",
      "------------------------------------------------\n",
      "0 | feature_extractor | Net2DFast | 119 K \n",
      "1 | classifier        | Conv2d    | 54    \n",
      "2 | bbox              | Conv2d    | 18    \n",
      "------------------------------------------------\n",
      "119 K     Trainable params\n",
      "448       Non-trainable params\n",
      "119 K     Total params\n",
      "0.480     Total estimated model params size (MB)\n",
      "32        Modules in train mode\n",
      "0         Modules in eval mode\n"
      "0.480     Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s]class heatmap shape torch.Size([3, 4, 128, 512])\n",
      "class props shape torch.Size([3, 5, 128, 512])\n"
      "Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.59it/s, v_num=13]"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1",
     "output_type": "error",
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdetector\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:538\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 536\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m=\u001b[39m TrainerStatus\u001b[38;5;241m.\u001b[39mRUNNING\n\u001b[1;32m 537\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 538\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 539\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 540\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:47\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 47\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 50\u001b[0m _call_teardown_hook(trainer)\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:574\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 567\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 568\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_select_ckpt_path(\n\u001b[1;32m 569\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 570\u001b[0m ckpt_path,\n\u001b[1;32m 571\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 572\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 573\u001b[0m )\n\u001b[0;32m--> 574\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 576\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 577\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:981\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 976\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_signal_connector\u001b[38;5;241m.\u001b[39mregister_signal_handlers()\n\u001b[1;32m 978\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 979\u001b[0m \u001b[38;5;66;03m# RUN THE TRAINER\u001b[39;00m\n\u001b[1;32m 980\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[0;32m--> 981\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 983\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 984\u001b[0m \u001b[38;5;66;03m# POST-Training CLEAN UP\u001b[39;00m\n\u001b[1;32m 985\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 986\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:1025\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1023\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_sanity_check()\n\u001b[1;32m 1024\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mset_detect_anomaly(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_detect_anomaly):\n\u001b[0;32m-> 1025\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1026\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1027\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected state \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:205\u001b[0m, in \u001b[0;36m_FitLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start()\n\u001b[0;32m--> 205\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:363\u001b[0m, in \u001b[0;36m_FitLoop.advance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_epoch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_fetcher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 363\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py:140\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.run\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone:\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 140\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_fetcher\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end(data_fetcher)\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py:250\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.advance\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_batch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 248\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mlightning_module\u001b[38;5;241m.\u001b[39mautomatic_optimization:\n\u001b[1;32m 249\u001b[0m \u001b[38;5;66;03m# in automatic optimization, there can only be one optimizer\u001b[39;00m\n\u001b[0;32m--> 250\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautomatic_optimization\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 252\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmanual_optimization\u001b[38;5;241m.\u001b[39mrun(kwargs)\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:190\u001b[0m, in \u001b[0;36m_AutomaticOptimization.run\u001b[0;34m(self, optimizer, batch_idx, kwargs)\u001b[0m\n\u001b[1;32m 183\u001b[0m closure()\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;66;03m# BACKWARD PASS\u001b[39;00m\n\u001b[1;32m 187\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;66;03m# gradient update with accumulated gradients\u001b[39;00m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 190\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 192\u001b[0m result \u001b[38;5;241m=\u001b[39m closure\u001b[38;5;241m.\u001b[39mconsume_result()\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:268\u001b[0m, in \u001b[0;36m_AutomaticOptimization._optimizer_step\u001b[0;34m(self, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_ready()\n\u001b[1;32m 267\u001b[0m \u001b[38;5;66;03m# model hook\u001b[39;00m\n\u001b[0;32m--> 268\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_lightning_module_hook\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moptimizer_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_step_and_backward_closure\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m should_accumulate:\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_completed()\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:167\u001b[0m, in \u001b[0;36m_call_lightning_module_hook\u001b[0;34m(trainer, hook_name, pl_module, *args, **kwargs)\u001b[0m\n\u001b[1;32m 164\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m hook_name\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[LightningModule]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpl_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 167\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 170\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/core/module.py:1306\u001b[0m, in \u001b[0;36mLightningModule.optimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_closure)\u001b[0m\n\u001b[1;32m 1275\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21moptimizer_step\u001b[39m(\n\u001b[1;32m 1276\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 1277\u001b[0m epoch: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1280\u001b[0m optimizer_closure: Optional[Callable[[], Any]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1281\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1282\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls\u001b[39;00m\n\u001b[1;32m 1283\u001b[0m \u001b[38;5;124;03m the optimizer.\u001b[39;00m\n\u001b[1;32m 1284\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1304\u001b[0m \n\u001b[1;32m 1305\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1306\u001b[0m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimizer_closure\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/core/optimizer.py:153\u001b[0m, in \u001b[0;36mLightningOptimizer.step\u001b[0;34m(self, closure, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MisconfigurationException(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWhen `optimizer.step(closure)` is called, the closure should be callable\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 152\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_strategy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 153\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_strategy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_on_after_step()\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m step_output\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py:238\u001b[0m, in \u001b[0;36mStrategy.optimizer_step\u001b[0;34m(self, optimizer, closure, model, **kwargs)\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;66;03m# TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed\u001b[39;00m\n\u001b[1;32m 237\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, pl\u001b[38;5;241m.\u001b[39mLightningModule)\n\u001b[0;32m--> 238\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprecision_plugin\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py:122\u001b[0m, in \u001b[0;36mPrecision.optimizer_step\u001b[0;34m(self, optimizer, model, closure, **kwargs)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Hook to run the optimizer step.\"\"\"\u001b[39;00m\n\u001b[1;32m 121\u001b[0m closure \u001b[38;5;241m=\u001b[39m partial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wrap_closure, model, optimizer, closure)\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:130\u001b[0m, in \u001b[0;36mLRScheduler.__init__.<locals>.patch_track_step_called.<locals>.wrap_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 128\u001b[0m opt \u001b[38;5;241m=\u001b[39m opt_ref()\n\u001b[1;32m 129\u001b[0m opt\u001b[38;5;241m.\u001b[39m_opt_called \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m \u001b[38;5;66;03m# type: ignore[union-attr]\u001b[39;00m\n\u001b[0;32m--> 130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__get__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mopt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;18;43m__class__\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:484\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 480\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 481\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 482\u001b[0m )\n\u001b[0;32m--> 484\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 485\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m 487\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:89\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.<locals>._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 87\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefaults[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdifferentiable\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 88\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n\u001b[0;32m---> 89\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 91\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/adam.py:205\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m closure \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39menable_grad():\n\u001b[0;32m--> 205\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m group \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparam_groups:\n\u001b[1;32m 208\u001b[0m params_with_grad: List[Tensor] \u001b[38;5;241m=\u001b[39m []\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py:108\u001b[0m, in \u001b[0;36mPrecision._wrap_closure\u001b[0;34m(self, model, optimizer, closure)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrap_closure\u001b[39m(\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 97\u001b[0m model: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpl.LightningModule\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 98\u001b[0m optimizer: Steppable,\n\u001b[1;32m 99\u001b[0m closure: Callable[[], Any],\n\u001b[1;32m 100\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m 101\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``\u001b[39;00m\n\u001b[1;32m 102\u001b[0m \u001b[38;5;124;03m hook is called.\u001b[39;00m\n\u001b[1;32m 103\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 106\u001b[0m \n\u001b[1;32m 107\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 108\u001b[0m closure_result \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_after_closure(model, optimizer)\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m closure_result\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:144\u001b[0m, in \u001b[0;36mClosure.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Optional[Tensor]:\n\u001b[0;32m--> 144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result\u001b[38;5;241m.\u001b[39mloss\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:129\u001b[0m, in \u001b[0;36mClosure.closure\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241m.\u001b[39menable_grad()\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mclosure\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ClosureResult:\n\u001b[0;32m--> 129\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_step_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m step_output\u001b[38;5;241m.\u001b[39mclosure_loss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwarning_cache\u001b[38;5;241m.\u001b[39mwarn(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`training_step` returned `None`. If this was on purpose, ignore this warning...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:317\u001b[0m, in \u001b[0;36m_AutomaticOptimization._training_step\u001b[0;34m(self, kwargs)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Performs the actual train step with the tied hooks.\u001b[39;00m\n\u001b[1;32m 307\u001b[0m \n\u001b[1;32m 308\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 313\u001b[0m \n\u001b[1;32m 314\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 315\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\n\u001b[0;32m--> 317\u001b[0m training_step_output \u001b[38;5;241m=\u001b[39m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtraining_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mpost_training_step() \u001b[38;5;66;03m# unused hook - call anyway for backward compatibility\u001b[39;00m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m training_step_output \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mworld_size \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:319\u001b[0m, in \u001b[0;36m_call_strategy_hook\u001b[0;34m(trainer, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 319\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 322\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.training_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module:\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_redirection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining_step\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/modules.py:167\u001b[0m, in \u001b[0;36mDetectorModel.training_step\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtraining_step\u001b[39m(\u001b[38;5;28mself\u001b[39m, batch: TrainExample):\n\u001b[1;32m 166\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward(batch\u001b[38;5;241m.\u001b[39mspec)\n\u001b[0;32m--> 167\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/modules.py:150\u001b[0m, in \u001b[0;36mDetectorModel.compute_loss\u001b[0;34m(self, outputs, batch)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mclass props shape\u001b[39m\u001b[38;5;124m\"\u001b[39m, outputs\u001b[38;5;241m.\u001b[39mclass_probs\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 149\u001b[0m valid_mask \u001b[38;5;241m=\u001b[39m batch\u001b[38;5;241m.\u001b[39mclass_heatmap\u001b[38;5;241m.\u001b[39many(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, keepdim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[0;32m--> 150\u001b[0m classification_loss \u001b[38;5;241m=\u001b[39m \u001b[43mlosses\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal_loss\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 151\u001b[0m \u001b[43m \u001b[49m\u001b[43moutputs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_probs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 152\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_heatmap\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 153\u001b[0m \u001b[43m \u001b[49m\u001b[43mweights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalid_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalid_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclassification\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbeta\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 156\u001b[0m \u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclassification\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43malpha\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 157\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 159\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\n\u001b[1;32m 160\u001b[0m detection_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39mdetection\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 161\u001b[0m \u001b[38;5;241m+\u001b[39m size_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39msize\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 162\u001b[0m \u001b[38;5;241m+\u001b[39m classification_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39mclassification\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 163\u001b[0m )\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/losses.py:38\u001b[0m, in \u001b[0;36mfocal_loss\u001b[0;34m(pred, gt, weights, valid_mask, eps, beta, alpha)\u001b[0m\n\u001b[1;32m 35\u001b[0m pos_inds \u001b[38;5;241m=\u001b[39m gt\u001b[38;5;241m.\u001b[39meq(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[1;32m 36\u001b[0m neg_inds \u001b[38;5;241m=\u001b[39m gt\u001b[38;5;241m.\u001b[39mlt(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[0;32m---> 38\u001b[0m pos_loss \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpred\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43meps\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpow\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpred\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpos_inds\u001b[49m\n\u001b[1;32m 39\u001b[0m neg_loss \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 40\u001b[0m torch\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m pred \u001b[38;5;241m+\u001b[39m eps)\n\u001b[1;32m 41\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpow(pred, alpha)\n\u001b[1;32m 42\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpow(\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m gt, beta)\n\u001b[1;32m 43\u001b[0m \u001b[38;5;241m*\u001b[39m neg_inds\n\u001b[1;32m 44\u001b[0m )\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m weights \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
||||
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1"
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=2` reached.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.54it/s, v_num=13]\n"
]
}
],
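The stderr line is Lightning's normal end-of-run message. A self-contained sketch that reproduces it, with a throwaway module standing in for the repository's `DetectorModel` (the module, data, and flags here are illustrative only):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class TinyModule(pl.LightningModule):
    # Minimal stand-in, only to trigger the `max_epochs=2` message above.
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

loader = DataLoader(TensorDataset(torch.rand(8, 4), torch.rand(8, 1)), batch_size=4)
trainer = pl.Trainer(max_epochs=2, logger=False, enable_checkpointing=False)
trainer.fit(TinyModule(), loader)  # ends with "`Trainer.fit` stopped: `max_epochs=2` reached."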
@ -338,14 +301,15 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"id": "2f6924db-e520-49a1-bbe8-6c4956e46314",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.811729Z",
"iopub.status.idle": "2024-11-19T17:33:10.811955Z",
"shell.execute_reply": "2024-11-19T17:33:10.811858Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.811849Z"
"iopub.execute_input": "2024-07-16T00:27:28.832222Z",
"iopub.status.busy": "2024-07-16T00:27:28.831642Z",
"iopub.status.idle": "2024-07-16T00:27:29.000595Z",
"shell.execute_reply": "2024-07-16T00:27:28.998078Z",
"shell.execute_reply.started": "2024-07-16T00:27:28.832157Z"
}
},
"outputs": [],
@ -355,54 +319,44 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"id": "23943e13-6875-49b8-9f18-2ba6528aa673",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.812924Z",
"iopub.status.idle": "2024-11-19T17:33:10.813260Z",
"shell.execute_reply": "2024-11-19T17:33:10.813104Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.813087Z"
"iopub.execute_input": "2024-07-16T00:27:29.004279Z",
"iopub.status.busy": "2024-07-16T00:27:29.003486Z",
"iopub.status.idle": "2024-07-16T00:27:29.595626Z",
"shell.execute_reply": "2024-07-16T00:27:29.594734Z",
"shell.execute_reply.started": "2024-07-16T00:27:29.004200Z"
}
},
"outputs": [],
"source": [
"spec = detector.compute_spectrogram(clip_annotation.clip)\n",
"outputs = detector(torch.tensor(spec.values).unsqueeze(0).unsqueeze(0))"
"predictions = detector.compute_clip_predictions(clip_annotation.clip)"
]
},
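The source change in this cell is an API simplification: the removed lines compute a spectrogram by hand, add batch and channel dimensions, and call the model directly, while the added line asks the model for clip-level predictions in one step. Wrapped as a function so the sketch stands alone; `detector` and `clip_annotation` are assumed to come from earlier cells:

import torch

def compare_inference_flows(detector, clip_annotation):
    # Removed flow: explicit spectrogram, manual tensor shaping, raw forward.
    spec = detector.compute_spectrogram(clip_annotation.clip)
    outputs = detector(torch.tensor(spec.values).unsqueeze(0).unsqueeze(0))

    # Added flow: one call returning structured sound-event predictions.
    predictions = detector.compute_clip_predictions(clip_annotation.clip)
    return outputs, predictions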
{
"cell_type": "code",
"execution_count": null,
"id": "dd1fe346-0873-4b14-ae1b-92ef1f4f27a5",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.814343Z",
"iopub.status.idle": "2024-11-19T17:33:10.814806Z",
"shell.execute_reply": "2024-11-19T17:33:10.814628Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.814611Z"
}
},
"outputs": [],
"source": [
"_, ax= plt.subplots(figsize=(15, 5))\n",
"spec.plot(ax=ax, add_colorbar=False)\n",
"ax.pcolormesh(spec.time, spec.frequency, outputs.detection_probs.detach().squeeze())"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 18,
"id": "eadd36ef-a04a-4665-b703-cec84cf1673b",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.815603Z",
"iopub.status.idle": "2024-11-19T17:33:10.816065Z",
"shell.execute_reply": "2024-11-19T17:33:10.815894Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.815877Z"
"iopub.execute_input": "2024-07-16T00:28:47.178783Z",
"iopub.status.busy": "2024-07-16T00:28:47.178143Z",
"iopub.status.idle": "2024-07-16T00:28:47.246613Z",
"shell.execute_reply": "2024-07-16T00:28:47.245496Z",
"shell.execute_reply.started": "2024-07-16T00:28:47.178729Z"
}
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Num predicted soundevents: 50\n"
]
}
],
"source": [
"print(f\"Num predicted soundevents: {len(predictions.sound_events)}\")"
]
@ -410,7 +364,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "e4e54f3e-6ddc-4fe5-8ce0-b527ff6f18ae",
"id": "d3883c04-d91a-4d1d-b677-196c0179dde1",
"metadata": {},
"outputs": [],
"source": []
@ -432,7 +386,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
"version": "3.9.18"
}
},
"nbformat": 4,
@ -1,23 +1,34 @@
[tool]
rye = { dev-dependencies = [
    "ipykernel>=6.29.4",
    "setuptools>=69.5.1",
    "pytest>=8.1.1",
] }
[tool.pdm]
[tool.pdm.dev-dependencies]
dev = [
    "pytest>=7.2.2",
]

[project]
name = "batdetect2"
version = "1.1.1"
version = "1.0.8"
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
authors = [
    { "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
    { "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" },
    { "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" }
]
dependencies = [
    "click>=8.1.7",
    "librosa>=0.10.1",
    "matplotlib>=3.7.1",
    "numpy>=1.23.5",
    "pandas>=1.5.3",
    "scikit-learn>=1.2.2",
    "scipy>=1.10.1",
    "torch>=1.13.1,<2.5.0",
    "torchaudio>=1.13.1,<2.5.0",
    "torchvision>=0.14.0",
    "soundevent[audio,geometry,plot]>=2.3",
    "torch>=1.13.1",
    "torchaudio",
    "torchvision",
    "soundevent[audio,geometry,plot]>=2.0.1",
    "click>=8.1.7",
    "netcdf4>=1.6.5",
    "tqdm>=4.66.2",
@ -26,12 +37,8 @@ dependencies = [
"onnx>=1.16.0",
|
||||
"lightning[extra]>=2.2.2",
|
||||
"tensorboard>=2.16.2",
|
||||
"omegaconf>=2.3.0",
|
||||
"pyyaml>=6.0.2",
|
||||
"hydra-core>=1.3.2",
|
||||
"numba>=0.60",
|
||||
]
|
||||
requires-python = ">=3.9,<3.13"
|
||||
requires-python = ">=3.9"
|
||||
readme = "README.md"
|
||||
license = { text = "CC-by-nc-4" }
|
||||
classifiers = [
|
||||
@ -39,10 +46,8 @@ classifiers = [
|
||||
"Intended Audience :: Science/Research",
|
||||
"Natural Language :: English",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"Topic :: Multimedia :: Sound/Audio :: Analysis",
|
||||
@ -58,40 +63,41 @@ keywords = [
|
||||
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
requires = ["pdm-pep517>=1.0.0"]
build-backend = "pdm.pep517.api"

[project.scripts]
batdetect2 = "batdetect2.cli:cli"

[tool.uv]
dev-dependencies = [
    "debugpy>=1.8.8",
    "hypothesis>=6.118.7",
    "pytest>=7.2.2",
    "ruff>=0.7.3",
    "ipykernel>=6.29.4",
    "setuptools>=69.5.1",
    "basedpyright>=1.28.4",
]
[tool.black]
line-length = 79

[tool.isort]
profile = "black"
line_length = 79

[tool.ruff]
line-length = 79
target-version = "py39"

[tool.ruff.format]
docstring-code-format = true
docstring-code-line-length = 79
[[tool.mypy.overrides]]
module = [
    "librosa",
    "pandas",
]
ignore_missing_imports = true

[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
[tool.pylsp-mypy]
enabled = false
live_mode = true
strict = true

[tool.ruff.lint.pydocstyle]
[tool.pydocstyle]
convention = "numpy"

[tool.pyright]
include = ["batdetect2", "tests"]
include = [
    "bat_detect",
    "tests",
]
venvPath = "."
venv = ".venv"
pythonVersion = "3.9"
pythonPlatform = "All"
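Since the pyproject diff above moves the package between versions 1.0.8 and 1.1.1, a quick standard-library check of which side is installed locally (the distribution name matches `name = "batdetect2"` in the diff):

from importlib.metadata import version

print(version("batdetect2"))  # "1.1.1" after this change set, "1.0.8" before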
@ -16,6 +16,7 @@ import batdetect2.train.train_utils as tu
import batdetect2.utils.audio_utils as au

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "audio_path", type=str, help="Input directory for audio"
@ -64,9 +65,7 @@ if __name__ == "__main__":
    else:
        # load uk data - special case
        print("\nLoading:", args["uk_split"], "\n")
        dataset_name = (
            "uk_" + args["uk_split"]
        )  # should be uk_diff, or uk_same
        dataset_name = "uk_" + args["uk_split"]  # should be uk_diff, or uk_same
        datasets, _ = ts.get_train_test_data(
            args["ann_file"],
            args["audio_path"],
@ -33,6 +33,7 @@ def filter_anns(anns, start_time, stop_time):


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("audio_file", type=str, help="Path to audio file")
    parser.add_argument("model_path", type=str, help="Path to BatDetect model")
@ -142,9 +143,7 @@

    # run model and filter detections so only keep ones in relevant time range
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    results = du.process_file(
        args_cmd["audio_file"], model, run_config, device
    )
    results = du.process_file(args_cmd["audio_file"], model, run_config, device)
    pred_anns = filter_anns(
        results["pred_dict"]["annotation"],
        args_cmd["start_time"],
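The hunks above show only the signature and call site of `filter_anns`, not its body. A hypothetical reconstruction, assuming BatDetect2's prediction dicts carry a per-call `start_time` key in seconds; treat the key name and the filtering rule as illustrative, not the script's actual logic:

def filter_anns(anns, start_time, stop_time):
    # Hypothetical: keep predicted calls whose start falls inside the
    # requested window; the "start_time" key is an assumption here.
    return [
        ann for ann in anns
        if start_time <= ann["start_time"] <= stop_time
    ]

# Example with the assumed dict shape:
anns = [{"start_time": 0.1}, {"start_time": 0.9}]
print(filter_anns(anns, 0.0, 0.5))  # [{'start_time': 0.1}]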
@ -25,9 +25,7 @@ import batdetect2.utils.plot_utils as viz

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "audio_file", type=str, help="Path to input audio file"
    )
    parser.add_argument("audio_file", type=str, help="Path to input audio file")
    parser.add_argument(
        "model_path", type=str, help="Path to trained BatDetect model"
    )
@ -198,6 +198,7 @@ def save_summary_image(
    )
    ii = 0
    for row in ax:

        if type(row) != np.ndarray:
            row = np.array([row])

@ -214,9 +215,7 @@
            )
            col.grid(color="w", alpha=0.3, linewidth=0.3)
            col.set_xticks([])
            col.title.set_text(
                str(ii + 1) + " " + species_names[order[ii]]
            )
            col.title.set_text(str(ii + 1) + " " + species_names[order[ii]])
            col.tick_params(axis="both", which="major", labelsize=7)
            ii += 1
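The `if type(row) != np.ndarray` wrap in the hunk above exists because `plt.subplots` returns a lone `Axes` object (or a 1-D array) when the grid has a single row or column, so iterating `ax` does not always yield arrays. A small sketch of the same normalization using `np.atleast_2d`, the idiomatic alternative:

import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so this sketch runs anywhere
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=1, ncols=3)
for row in np.atleast_2d(ax):  # always a 2-D grid, no manual wrapping
    for col in row:
        col.set_xticks([])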
@ -1,109 +0,0 @@
import uuid
from pathlib import Path
from typing import Callable, List, Optional

import numpy as np
import pytest
import soundfile as sf
from soundevent import data


@pytest.fixture
def example_data_dir() -> Path:
    pkg_dir = Path(__file__).parent.parent
    example_data_dir = pkg_dir / "example_data"
    assert example_data_dir.exists()
    return example_data_dir


@pytest.fixture
def example_audio_dir(example_data_dir: Path) -> Path:
    example_audio_dir = example_data_dir / "audio"
    assert example_audio_dir.exists()
    return example_audio_dir


@pytest.fixture
def example_anns_dir(example_data_dir: Path) -> Path:
    example_anns_dir = example_data_dir / "anns"
    assert example_anns_dir.exists()
    return example_anns_dir


@pytest.fixture
def example_audio_files(example_audio_dir: Path) -> List[Path]:
    audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]"))
    assert len(audio_files) == 3
    return audio_files


@pytest.fixture
def data_dir() -> Path:
    dir = Path(__file__).parent / "data"
    assert dir.exists()
    return dir


@pytest.fixture
def contrib_dir(data_dir) -> Path:
    dir = data_dir / "contrib"
    assert dir.exists()
    return dir


@pytest.fixture
def wav_factory(tmp_path: Path):
    def _wav_factory(
        path: Optional[Path] = None,
        duration: float = 0.3,
        channels: int = 1,
        samplerate: int = 441_000,
        bit_depth: int = 16,
    ) -> Path:
        path = path or tmp_path / f"{uuid.uuid4()}.wav"
        frames = int(samplerate * duration)
        shape = (frames, channels)
        subtype = f"PCM_{bit_depth}"

        if bit_depth == 16:
            dtype = np.int16
        elif bit_depth == 32:
            dtype = np.int32
        else:
            raise ValueError(f"Unsupported bit depth: {bit_depth}")

        wav = np.random.uniform(
            low=np.iinfo(dtype).min,
            high=np.iinfo(dtype).max,
            size=shape,
        ).astype(dtype)
        sf.write(str(path), wav, samplerate, subtype=subtype)
        return path

    return _wav_factory


@pytest.fixture
def recording_factory(wav_factory: Callable[..., Path]):
    def _recording_factory(
        tags: Optional[list[data.Tag]] = None,
        path: Optional[Path] = None,
        recording_id: Optional[uuid.UUID] = None,
        duration: float = 1,
        channels: int = 1,
        samplerate: int = 256_000,
        time_expansion: float = 1,
    ) -> data.Recording:
        path = path or wav_factory(
            duration=duration,
            channels=channels,
            samplerate=samplerate,
        )
        return data.Recording.from_file(
            path=path,
            uuid=recording_id or uuid.uuid4(),
            time_expansion=time_expansion,
            tags=tags or [],
        )

    return _recording_factory
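For reference, a sketch of how the removed factory fixtures above composed in a test; the `path`, `samplerate`, and `duration` attributes on soundevent's `data.Recording` are assumptions here:

def test_recording_factory_roundtrip(recording_factory):
    # Uses the fixtures defined (and deleted) above: a synthetic WAV is
    # written under tmp_path and wrapped in a soundevent Recording.
    recording = recording_factory(duration=0.5, samplerate=256_000)
    assert recording.path.exists()
    assert recording.samplerate == 256_000  # assumed attribute name
    assert abs(recording.duration - 0.5) < 1e-3  # assumed attribute name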
Binary file not shown.
@ -1,13 +1,14 @@
"""Test bat detect module API."""

from pathlib import Path

import os
from glob import glob
from pathlib import Path

import numpy as np
import soundfile as sf
import torch
from torch import nn
import soundfile as sf

from batdetect2 import api

@ -266,6 +267,7 @@ def test_process_file_with_spec_slices():
    assert len(results["spec_slices"]) == len(detections)



def test_process_file_with_empty_predictions_does_not_fail(
    tmp_path: Path,
):
@ -1,136 +0,0 @@
import numpy as np
import torch
import torch.nn.functional as F
from hypothesis import given
from hypothesis import strategies as st

from batdetect2.detector import parameters
from batdetect2.utils import audio_utils, detector_utils


@given(duration=st.floats(min_value=0.1, max_value=2))
def test_can_compute_correct_spectrogram_width(duration: float):
    samplerate = parameters.TARGET_SAMPLERATE_HZ
    params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS

    length = int(duration * samplerate)
    audio = np.random.rand(length)

    spectrogram, _ = audio_utils.generate_spectrogram(
        audio,
        samplerate,
        params,
    )

    # convert to pytorch
    spectrogram = torch.from_numpy(spectrogram)

    # add batch and channel dimensions
    spectrogram = spectrogram.unsqueeze(0).unsqueeze(0)

    # resize the spec
    resize_factor = params["resize_factor"]
    spec_op_shape = (
        int(params["spec_height"] * resize_factor),
        int(spectrogram.shape[-1] * resize_factor),
    )
    spectrogram = F.interpolate(
        spectrogram,
        size=spec_op_shape,
        mode="bilinear",
        align_corners=False,
    )

    expected_width = audio_utils.compute_spectrogram_width(
        length,
        samplerate=parameters.TARGET_SAMPLERATE_HZ,
        window_duration=params["fft_win_length"],
        window_overlap=params["fft_overlap"],
        resize_factor=params["resize_factor"],
    )

    assert spectrogram.shape[-1] == expected_width

@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||
def test_pad_audio_without_fixed_size(duration: float):
|
||||
# Test the pad_audio function
|
||||
# This function is used to pad audio with zeros to a specific length
|
||||
# It is used in the generate_spectrogram function
|
||||
# The function is tested with a simplepas
|
||||
samplerate = parameters.TARGET_SAMPLERATE_HZ
|
||||
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
|
||||
|
||||
length = int(duration * samplerate)
|
||||
audio = np.random.rand(length)
|
||||
|
||||
# pad the audio to be divisible by divide factor
|
||||
padded_audio = audio_utils.pad_audio(
|
||||
audio,
|
||||
samplerate=samplerate,
|
||||
window_duration=params["fft_win_length"],
|
||||
window_overlap=params["fft_overlap"],
|
||||
resize_factor=params["resize_factor"],
|
||||
divide_factor=params["spec_divide_factor"],
|
||||
)
|
||||
|
||||
# check that the padded audio is divisible by the divide factor
|
||||
expected_width = audio_utils.compute_spectrogram_width(
|
||||
len(padded_audio),
|
||||
samplerate=parameters.TARGET_SAMPLERATE_HZ,
|
||||
window_duration=params["fft_win_length"],
|
||||
window_overlap=params["fft_overlap"],
|
||||
resize_factor=params["resize_factor"],
|
||||
)
|
||||
|
||||
assert expected_width % params["spec_divide_factor"] == 0
|
||||
|
||||
|
||||
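The padding tests above and below both reduce to one invariant: pad the waveform until the spectrogram's frame count hits a target. An illustrative, self-contained sketch of that arithmetic (not the library's `pad_audio`; the window, hop, and divide-factor values here are made up):

import math

def pad_length_for_divisibility(n_samples, win, hop, divide_factor):
    # Frames produced by a hop-`hop` STFT, then the smallest padded
    # length whose frame count is a multiple of `divide_factor`.
    frames = 1 + (n_samples - win) // hop
    target = math.ceil(frames / divide_factor) * divide_factor
    return (target - 1) * hop + win

win, hop, divide_factor = 512, 256, 32
padded = pad_length_for_divisibility(100_000, win, hop, divide_factor)
assert (1 + (padded - win) // hop) % divide_factor == 0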
@given(duration=st.floats(min_value=0.1, max_value=2))
def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor(
    duration: float,
):
    samplerate = parameters.TARGET_SAMPLERATE_HZ
    params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
    length = int(duration * samplerate)
    audio = np.random.rand(length)
    _, spectrogram = detector_utils.compute_spectrogram(
        audio,
        samplerate,
        params,
        torch.device("cpu"),
    )
    assert spectrogram.shape[-1] % params["spec_divide_factor"] == 0


@given(
    duration=st.floats(min_value=0.1, max_value=2),
    width=st.integers(min_value=128, max_value=1024),
)
def test_pad_audio_with_fixed_width(duration: float, width: int):
    samplerate = parameters.TARGET_SAMPLERATE_HZ
    params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS

    length = int(duration * samplerate)
    audio = np.random.rand(length)

    # pad the audio to the requested fixed width
    padded_audio = audio_utils.pad_audio(
        audio,
        samplerate=samplerate,
        window_duration=params["fft_win_length"],
        window_overlap=params["fft_overlap"],
        resize_factor=params["resize_factor"],
        divide_factor=params["spec_divide_factor"],
        fixed_width=width,
    )

    # check that the padded audio produces exactly the fixed width
    expected_width = audio_utils.compute_spectrogram_width(
        len(padded_audio),
        samplerate=parameters.TARGET_SAMPLERATE_HZ,
        window_duration=params["fft_win_length"],
        window_overlap=params["fft_overlap"],
        resize_factor=params["resize_factor"],
    )
    assert expected_width == width
Some files were not shown because too many files have changed in this diff.