mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-30 15:12:06 +02:00

Compare commits: f63307757c ... 98c6da6d42 (71 commits)
Commits:

98c6da6d42, fe8d044af2, 62fa38557e, 213b6dfd29, bfa6049adc, e383a33cbf,
7689580a24, 1338ae7431, c2c4ac53fd, ff00da9a9a, 22cf47ed39, d9f7304a0f,
451093f2da, 30d3a2c92e, c17a25fa75, 29f7862153, dbc3ff9364, 150305a273,
904e8f23ea, 48e009fa9d, f7d6516550, e9e1f7ce2f, 113223be02, 35c916482c,
09bd8cf423, f6cdd4e87e, 9cf159efff, 36c90a600f, 1f0fb14d89, ee884da8b0,
6a9e33c729, 2100a3e483, 1d3cd2e305, d5753b95bb, 69f59ff559, 1a11174bc4,
c5c9476e52, 270b3f212d, f61d1d8c72, 4627ddd739, 3477d7b5b4, 394c66a2ee,
d085b3212c, c393c5c29b, 3c22ff28a7, 7dc28695b2, 505cca2dea, 7906842a16,
a4b22d6590, 25e0a53ad1, 039c002796, c97a87b2a4, d93d8284d0, 697b5dbddb,
d5bf8f5ad8, fcbccbe012, 4917641e2c, 39c3918103, 1ac3808fee, 9e0ad7fd78,
95bb0985e7, cb088359ae, c5030123aa, 1c1fbd8019, c65fe1c9f9, d05bec880a,
8597ef0a1c, 2d8a7b67f8, 68351d2224, 3f34164028, d84b7795f6
8  .bumpversion.cfg  Normal file

@@ -0,0 +1,8 @@
[bumpversion]
current_version = 1.1.1
commit = True
tag = True

[bumpversion:file:batdetect2/__init__.py]

[bumpversion:file:pyproject.toml]
35  .github/workflows/python-package.yml  vendored

@@ -1,34 +1,29 @@
-# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
-# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
-
 name: Python package

 on:
   push:
-    branches: [ "main" ]
+    branches: ["main"]
   pull_request:
-    branches: [ "main" ]
+    branches: ["main"]

 jobs:
   build:

     runs-on: ubuntu-latest
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]

     steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          python -m pip install pytest
-          pip install .
-      - name: Test with pytest
-        run: |
-          pytest
+      - uses: actions/checkout@v4
+      - name: Install uv
+        uses: astral-sh/setup-uv@v3
+        with:
+          enable-cache: true
+          cache-dependency-glob: "uv.lock"
+      - name: Set up Python ${{ matrix.python-version }}
+        run: uv python install ${{ matrix.python-version }}
+      - name: Install the project
+        run: uv sync --all-extras --dev
+      - name: Test with pytest
+        run: uv run pytest
41  .github/workflows/python-publish.yml  vendored

@@ -1,11 +1,3 @@
-# This workflow will upload a Python Package using Twine when a release is created
-# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
-
-# This workflow uses actions that are not certified by GitHub.
-# They are provided by a third-party and are governed by
-# separate terms of service, privacy policy, and support
-# documentation.
-
 name: Upload Python Package

 on:
@@ -17,23 +9,22 @@ permissions:

 jobs:
   deploy:

     runs-on: ubuntu-latest

     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python
         uses: actions/setup-python@v3
         with:
-          python-version: '3.x'
+          python-version: "3.x"
       - name: Install dependencies
         run: |
          python -m pip install --upgrade pip
          pip install build
       - name: Build package
         run: python -m build
       - name: Publish package
         uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
         with:
           user: __token__
           password: ${{ secrets.PYPI_API_TOKEN }}
4  .gitignore  vendored

@@ -103,13 +103,11 @@ experiments/*
 .ipynb_checkpoints
 *.ipynb

-# Bump2version
-.bumpversion.cfg
-
 # DO Include
 !batdetect2_notebook.ipynb
 !batdetect2/models/checkpoints/*.pth.tar
 !tests/data/*.wav
 !notebooks/*.ipynb
+!tests/data/**/*.wav
 notebooks/lightning_logs
 example_data/preprocessed
@@ -1 +1,6 @@
-__version__ = '1.0.8'
+import logging
+
+numba_logger = logging.getLogger("numba")
+numba_logger.setLevel(logging.WARNING)
+
+__version__ = "1.1.1"
@@ -1,9 +1,13 @@
 from batdetect2.cli.base import cli
 from batdetect2.cli.compat import detect
+from batdetect2.cli.data import data
+from batdetect2.cli.train import train

 __all__ = [
     "cli",
     "detect",
+    "data",
+    "train",
 ]
22  batdetect2/cli/ascii.py  Normal file

@@ -0,0 +1,22 @@
BATDETECT_ASCII_ART = """ .
      =#%: .%%#
     :%%%: .%%%%.
      %%%%.-===::%%%%*
      =%%%%+++++++%%%#.
  -:  .%%%#====+++#%%#  .-
.+***= .  =++. :  .=*+#%*= :***.
=+****+++==:%+#=+%  *##%%%%*=##*#**-=
 ++***+**+=: ##..   +##%%########**++
 .++*****#*+- :*:++ ##%#%%%%%####**++
  .++***+**++++-  :#%%%%%####*##***+=
   .+++***+#+++*########%%%##%#+*****++:
    .=++++++*+++##%##%%####%%##*:+****+=
     =++++++====*#%%#%###%%###-  +***+++.
      .+*++++= =+==##########=  :****++.
       =++*+:.   .:=#####=    .++**++-
        .****:     .        -+**++=
         *###=             .****==
         .#*#-             **#*:
          -###             -*##.
           +*=              *#*
"""
@@ -1,18 +1,14 @@
 """BatDetect2 command line interface."""

-import os
-
 import click

+# from batdetect2.cli.ascii import BATDETECT_ASCII_ART
+
 __all__ = [
     "cli",
 ]

-CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
-

 INFO_STR = """
 BatDetect2 - Detection and Classification
     Assumes audio files are mono, not stereo.
@@ -25,3 +21,4 @@ BatDetect2 - Detection and Classification
 def cli():
     """BatDetect2 - Bat Call Detection and Classification."""
     click.echo(INFO_STR)
+    # click.echo(BATDETECT_ASCII_ART)
@@ -1,14 +1,11 @@
-"""BatDetect2 command line interface."""
-
 import click

 from batdetect2 import api
+from batdetect2.cli.base import cli
 from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
 from batdetect2.types import ProcessingConfiguration
 from batdetect2.utils.detector_utils import save_results_to_file

-from batdetect2.cli.base import cli
-

 @cli.command()
 @click.argument(
@@ -114,10 +111,9 @@ def detect(
 ):
             results_path = audio_file.replace(audio_dir, ann_dir)
             save_results_to_file(results, results_path)
-        except (RuntimeError, ValueError, LookupError) as err:
+        except (RuntimeError, ValueError, LookupError, EOFError) as err:
             error_files.append(audio_file)
-            click.secho(f"Error processing file!: {err}", fg="red")
-            raise err
+            click.secho(f"Error processing file {audio_file}: {err}", fg="red")

     click.echo(f"\nResults saved to: {ann_dir}")
40  batdetect2/cli/data.py  Normal file

@@ -0,0 +1,40 @@
from pathlib import Path
from typing import Optional

import click

from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config

__all__ = ["data"]


@cli.group()
def data(): ...


@data.command()
@click.argument(
    "dataset_config",
    type=click.Path(exists=True),
)
@click.option(
    "--field",
    type=str,
    help="If the dataset info is in a nested field please specify here.",
)
@click.option(
    "--base-dir",
    type=click.Path(exists=True),
    help="The base directory to which all recording and annotations paths are relative to.",
)
def summary(
    dataset_config: Path,
    field: Optional[str] = None,
    base_dir: Optional[Path] = None,
):
    base_dir = base_dir or Path.cwd()
    dataset = load_dataset_from_config(
        dataset_config, field=field, base_dir=base_dir
    )
    print(f"Number of annotated clips: {len(dataset.clip_annotations)}")
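The new `data summary` command can be exercised in-process with click's test runner. A minimal sketch, assuming a dataset config exists at `dataset.yaml` (that path is an illustration, not part of this changeset):

```python
# Sketch only: dataset.yaml is a hypothetical config file.
from click.testing import CliRunner

from batdetect2.cli import cli

runner = CliRunner()
result = runner.invoke(cli, ["data", "summary", "dataset.yaml", "--base-dir", "."])
print(result.output)  # e.g. "Number of annotated clips: 123"
```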
188  batdetect2/cli/train.py  Normal file

@@ -0,0 +1,188 @@
from pathlib import Path
from typing import Optional

import click

from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.preprocess import (
    load_preprocessing_config,
)
from batdetect2.train import (
    load_label_config,
    load_target_config,
    preprocess_annotations,
)

__all__ = ["train"]


@cli.group()
def train(): ...


@train.command()
@click.argument(
    "dataset_config",
    type=click.Path(exists=True),
)
@click.argument(
    "output",
    type=click.Path(),
)
@click.option(
    "--dataset-field",
    type=str,
    help=(
        "Specifies the key to access the dataset information within the "
        "dataset configuration file, if the information is nested inside a "
        "dictionary. If the dataset information is at the top level of the "
        "config file, you don't need to specify this."
    ),
)
@click.option(
    "--base-dir",
    type=click.Path(exists=True),
    help=(
        "The main directory where your audio recordings and annotation "
        "files are stored. This helps the program find your data, "
        "especially if the paths in your dataset configuration file "
        "are relative."
    ),
)
@click.option(
    "--preprocess-config",
    type=click.Path(exists=True),
    help=(
        "Path to the preprocessing configuration file. This file tells "
        "the program how to prepare your audio data before training, such "
        "as resampling or applying filters."
    ),
)
@click.option(
    "--preprocess-config-field",
    type=str,
    help=(
        "If the preprocessing settings are inside a nested dictionary "
        "within the preprocessing configuration file, specify the key "
        "here to access them. If the preprocessing settings are at the "
        "top level, you don't need to specify this."
    ),
)
@click.option(
    "--label-config",
    type=click.Path(exists=True),
    help=(
        "Path to the label generation configuration file. This file "
        "contains settings for how to create labels from your "
        "annotations, which the model uses to learn."
    ),
)
@click.option(
    "--label-config-field",
    type=str,
    help=(
        "If the label generation settings are inside a nested dictionary "
        "within the label configuration file, specify the key here. If "
        "the settings are at the top level, leave this blank."
    ),
)
@click.option(
    "--target-config",
    type=click.Path(exists=True),
    help=(
        "Path to the training target configuration file. This file "
        "specifies what sounds the model should learn to predict."
    ),
)
@click.option(
    "--target-config-field",
    type=str,
    help=(
        "If the target settings are inside a nested dictionary "
        "within the target configuration file, specify the key here. "
        "If the settings are at the top level, you don't need to specify this."
    ),
)
@click.option(
    "--force",
    is_flag=True,
    help=(
        "If a preprocessed file already exists, this option tells the "
        "program to overwrite it with the new preprocessed data. Use "
        "this if you want to re-do the preprocessing even if the files "
        "already exist."
    ),
)
@click.option(
    "--num-workers",
    type=int,
    help=(
        "The maximum number of computer cores to use when processing "
        "your audio data. Using more cores can speed up the preprocessing, "
        "but don't use more than your computer has available. By default, "
        "the program will use all available cores."
    ),
)
def preprocess(
    dataset_config: Path,
    output: Path,
    base_dir: Optional[Path] = None,
    preprocess_config: Optional[Path] = None,
    target_config: Optional[Path] = None,
    label_config: Optional[Path] = None,
    force: bool = False,
    num_workers: Optional[int] = None,
    target_config_field: Optional[str] = None,
    preprocess_config_field: Optional[str] = None,
    label_config_field: Optional[str] = None,
    dataset_field: Optional[str] = None,
):
    output = Path(output)
    base_dir = base_dir or Path.cwd()

    preprocess = (
        load_preprocessing_config(
            preprocess_config,
            field=preprocess_config_field,
        )
        if preprocess_config
        else None
    )

    target = (
        load_target_config(
            target_config,
            field=target_config_field,
        )
        if target_config
        else None
    )

    label = (
        load_label_config(
            label_config,
            field=label_config_field,
        )
        if label_config
        else None
    )

    dataset = load_dataset_from_config(
        dataset_config,
        field=dataset_field,
        base_dir=base_dir,
    )

    if not output.exists():
        output.mkdir(parents=True)

    preprocess_annotations(
        dataset.clip_annotations,
        output_dir=output,
        replace=force,
        preprocessing_config=preprocess,
        label_config=label,
        target_config=target,
        max_workers=num_workers,
    )
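For orientation, a hedged sketch of the same pipeline the `train preprocess` command wires together, called directly from Python. The YAML path and output directory are assumptions; when no config files are given, the command passes `None` for the three config arguments, which the sketch mirrors:

```python
# Sketch only: dataset.yaml and preprocessed/ are hypothetical paths.
from pathlib import Path

from batdetect2.data import load_dataset_from_config
from batdetect2.train import preprocess_annotations

dataset = load_dataset_from_config("dataset.yaml", base_dir=Path.cwd())
preprocess_annotations(
    dataset.clip_annotations,
    output_dir=Path("preprocessed"),
    replace=False,
    preprocessing_config=None,
    label_config=None,
    target_config=None,
    max_workers=4,
)
```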
0  batdetect2/compat/__init__.py  Normal file
@@ -9,15 +9,16 @@ import numpy as np
 from pydantic import BaseModel, Field
 from soundevent import data
 from soundevent.geometry import compute_bounds
+from soundevent.types import ClassMapper

 from batdetect2 import types
-from batdetect2.data.labels import ClassMapper

 PathLike = Union[Path, str, os.PathLike]

 __all__ = [
     "convert_to_annotation_group",
-    "load_annotation_project",
+    "load_file_annotation",
+    "annotation_to_sound_event",
 ]

 SPECIES_TAG_KEY = "species"
@@ -195,18 +196,30 @@ def annotation_to_sound_event(
         uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
         sound_event=sound_event,
         tags=[
-            data.Tag(key=label_key, value=annotation.label),
-            data.Tag(key=event_key, value=annotation.event),
-            data.Tag(key=individual_key, value=str(annotation.individual)),
+            data.Tag(
+                term=data.term_from_key(label_key),
+                value=annotation.label,
+            ),
+            data.Tag(
+                term=data.term_from_key(event_key),
+                value=annotation.event,
+            ),
+            data.Tag(
+                term=data.term_from_key(individual_key),
+                value=str(annotation.individual),
+            ),
         ],
     )


 def file_annotation_to_clip(
     file_annotation: FileAnnotation,
-    audio_dir: PathLike = Path.cwd(),
+    audio_dir: Optional[PathLike] = None,
+    label_key: str = "class",
 ) -> data.Clip:
     """Convert file annotation to recording."""
+    audio_dir = audio_dir or Path.cwd()
+
     full_path = Path(audio_dir) / file_annotation.id

     if not full_path.exists():
@@ -215,6 +228,12 @@ def file_annotation_to_clip(
     recording = data.Recording.from_file(
         full_path,
         time_expansion=file_annotation.time_exp,
+        tags=[
+            data.Tag(
+                term=data.term_from_key(label_key),
+                value=file_annotation.label,
+            )
+        ],
     )

     return data.Clip(
@@ -241,7 +260,11 @@ def file_annotation_to_clip_annotation(
         uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
         clip=clip,
         notes=notes,
-        tags=[data.Tag(key=label_key, value=file_annotation.label)],
+        tags=[
+            data.Tag(
+                term=data.term_from_key(label_key), value=file_annotation.label
+            )
+        ],
         sound_events=[
             annotation_to_sound_event(
                 annotation,
@@ -281,52 +304,3 @@ def list_file_annotations(path: PathLike) -> List[Path]:
     """List all annotations in a directory."""
     path = Path(path)
     return [file for file in path.glob("*.json")]
-
-
-def load_annotation_project(
-    path: PathLike,
-    name: Optional[str] = None,
-    audio_dir: PathLike = Path.cwd(),
-) -> data.AnnotationProject:
-    """Convert annotations to annotation project."""
-    paths = list_file_annotations(path)
-
-    if name is None:
-        name = str(path)
-
-    annotations = []
-    tasks = []
-
-    for p in paths:
-        try:
-            file_annotation = load_file_annotation(p)
-        except FileNotFoundError:
-            continue
-
-        try:
-            clip = file_annotation_to_clip(
-                file_annotation,
-                audio_dir=audio_dir,
-            )
-        except FileNotFoundError:
-            continue
-
-        annotations.append(
-            file_annotation_to_clip_annotation(
-                file_annotation,
-                clip,
-            )
-        )
-
-        tasks.append(
-            file_annotation_to_annotation_task(
-                file_annotation,
-                clip,
-            )
-        )
-
-    return data.AnnotationProject(
-        name=name,
-        clip_annotations=annotations,
-        tasks=tasks,
-    )
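The change above builds tags from an explicit term obtained via `data.term_from_key` rather than from a plain key string. A one-line illustration of the new construction (the species value is made up):

```python
from soundevent import data

tag = data.Tag(term=data.term_from_key("class"), value="Myotis daubentonii")
```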
151  batdetect2/compat/params.py  Normal file

@@ -0,0 +1,151 @@
from batdetect2.preprocess import (
    AmplitudeScaleConfig,
    AudioConfig,
    FrequencyConfig,
    LogScaleConfig,
    PcenScaleConfig,
    PreprocessingConfig,
    ResampleConfig,
    Scales,
    SpecSizeConfig,
    SpectrogramConfig,
    STFTConfig,
)
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
from batdetect2.terms import TagInfo
from batdetect2.train.preprocess import (
    HeatmapsConfig,
    TargetConfig,
    TrainPreprocessingConfig,
)


def get_spectrogram_scale(scale: str) -> Scales:
    if scale == "pcen":
        return PcenScaleConfig()
    if scale == "log":
        return LogScaleConfig()
    return AmplitudeScaleConfig()


def get_preprocessing_config(params: dict) -> PreprocessingConfig:
    return PreprocessingConfig(
        audio=AudioConfig(
            resample=ResampleConfig(
                samplerate=params["target_samp_rate"],
                mode="poly",
            ),
            scale=params["scale_raw_audio"],
            center=params["scale_raw_audio"],
            duration=None,
        ),
        spectrogram=SpectrogramConfig(
            stft=STFTConfig(
                window_duration=params["fft_win_length"],
                window_overlap=params["fft_overlap"],
                window_fn="hann",
            ),
            frequencies=FrequencyConfig(
                min_freq=params["min_freq"],
                max_freq=params["max_freq"],
            ),
            scale=get_spectrogram_scale(params["spec_scale"]),
            denoise=params["denoise_spec_avg"],
            size=SpecSizeConfig(
                height=params["spec_height"],
                resize_factor=params["resize_factor"],
            ),
            max_scale=params["max_scale_spec"],
        ),
    )


def get_training_preprocessing_config(
    params: dict,
) -> TrainPreprocessingConfig:
    generic = params["generic_class"][0]
    preprocessing = get_preprocessing_config(params)

    freq_bin_width, time_bin_width = get_spectrogram_resolution(
        preprocessing.spectrogram
    )

    return TrainPreprocessingConfig(
        preprocessing=preprocessing,
        target=TargetConfig(
            classes=[
                TagInfo(key="class", value=class_name, label=class_name)
                for class_name in params["class_names"]
            ],
            generic_class=TagInfo(
                key="class",
                value=generic,
                label=generic,
            ),
            include=[
                TagInfo(key="event", value=event)
                for event in params["events_of_interest"]
            ],
            exclude=[
                TagInfo(key="class", value=value)
                for value in params["classes_to_ignore"]
            ],
        ),
        heatmaps=HeatmapsConfig(
            position="bottom-left",
            time_scale=1 / time_bin_width,
            frequency_scale=1 / freq_bin_width,
            sigma=params["target_sigma"],
        ),
    )


# 'standardize_classs_names_ip', 'convert_to_genus', 'genus_mapping',
# 'standardize_classs_names', 'genus_names',

# ['data_dir', 'ann_dir', 'train_split', 'model_name', 'num_filters',
#  'experiment', 'model_file_name', 'op_im_dir', 'op_im_dir_test', 'notes',
#  'spec_divide_factor', 'detection_overlap', 'ignore_start_end',
#  'detection_threshold', 'nms_kernel_size', 'nms_top_k_per_sec', 'aug_prob',
#  'augment_at_train', 'augment_at_train_combine', 'echo_max_delay',
#  'stretch_squeeze_delta', 'mask_max_time_perc', 'mask_max_freq_perc',
#  'spec_amp_scaling', 'aug_sampling_rates', 'train_loss', 'det_loss_weight',
#  'size_loss_weight', 'class_loss_weight', 'individual_loss_weight',
#  'emb_dim', 'lr', 'batch_size', 'num_workers', 'num_epochs',
#  'num_eval_epochs', 'device', 'save_test_image_during_train',
#  'save_test_image_after_train', 'train_sets', 'test_sets', 'class_inv_freq',
#  'ip_height']
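A hedged sketch of calling `get_preprocessing_config` with a legacy parameter dict. The values mirror the defaults that appear elsewhere in this changeset; `resize_factor` is an illustrative value, not taken from the diff:

```python
legacy_params = {
    "target_samp_rate": 256000,
    "scale_raw_audio": False,
    "fft_win_length": 512 / 256000.0,
    "fft_overlap": 0.75,
    "min_freq": 10000,
    "max_freq": 120000,
    "spec_scale": "pcen",
    "denoise_spec_avg": True,
    "spec_height": 128,
    "resize_factor": 0.5,  # illustrative, not from the diff
    "max_scale_spec": False,
}

config = get_preprocessing_config(legacy_params)
print(config.spectrogram.stft.window_duration)
```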
35  batdetect2/configs.py  Normal file

@@ -0,0 +1,35 @@
from typing import Any, Optional, Type, TypeVar

import yaml
from pydantic import BaseModel, ConfigDict
from soundevent.data import PathLike


class BaseConfig(BaseModel):
    model_config = ConfigDict(extra="forbid")


T = TypeVar("T", bound=BaseModel)


def get_object_field(obj: dict, field: str) -> Any:
    if "." not in field:
        return obj[field]

    field, rest = field.split(".", 1)
    subobj = obj[field]
    return get_object_field(subobj, rest)


def load_config(
    path: PathLike,
    schema: Type[T],
    field: Optional[str] = None,
) -> T:
    with open(path, "r") as file:
        config = yaml.safe_load(file)

    if field:
        config = get_object_field(config, field)

    return schema.model_validate(config)
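A usage sketch of `load_config` with a dotted `field`; the YAML file and the `TrainingConfig` model below are assumptions for illustration, not part of the changeset:

```python
from pydantic import BaseModel


class TrainingConfig(BaseModel):
    learning_rate: float = 1e-3
    batch_size: int = 32


# config.yaml (hypothetical):
#   train:
#     learning_rate: 0.001
#     batch_size: 16
config = load_config("config.yaml", schema=TrainingConfig, field="train")
print(config.batch_size)
```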
@@ -0,0 +1,14 @@
+from batdetect2.data.annotations import (
+    AnnotatedDataset,
+    load_annotated_dataset,
+)
+from batdetect2.data.data import load_dataset, load_dataset_from_config
+from batdetect2.data.types import Dataset
+
+__all__ = [
+    "AnnotatedDataset",
+    "Dataset",
+    "load_annotated_dataset",
+    "load_dataset",
+    "load_dataset_from_config",
+]
36  batdetect2/data/annotations.py  Normal file

@@ -0,0 +1,36 @@
import json
from pathlib import Path
from typing import Literal, Union

from batdetect2.configs import BaseConfig

__all__ = [
    "AOEFAnnotationFile",
    "AnnotationFormats",
    "BatDetect2AnnotationFile",
    "BatDetect2AnnotationFiles",
]


class BatDetect2AnnotationFiles(BaseConfig):
    format: Literal["batdetect2"] = "batdetect2"
    path: Path


class BatDetect2AnnotationFile(BaseConfig):
    format: Literal["batdetect2_file"] = "batdetect2_file"
    path: Path


class AOEFAnnotationFile(BaseConfig):
    format: Literal["aoef"] = "aoef"
    path: Path


AnnotationFormats = Union[
    BatDetect2AnnotationFiles,
    BatDetect2AnnotationFile,
    AOEFAnnotationFile,
]
55  batdetect2/data/annotations/__init__.py  Normal file

@@ -0,0 +1,55 @@
from pathlib import Path
from typing import Optional, Union

from soundevent import data

from batdetect2.data.annotations.aeof import (
    AOEFAnnotations,
    load_aoef_annotated_dataset,
)
from batdetect2.data.annotations.batdetect2_files import (
    BatDetect2FilesAnnotations,
    load_batdetect2_files_annotated_dataset,
)
from batdetect2.data.annotations.batdetect2_merged import (
    BatDetect2MergedAnnotations,
    load_batdetect2_merged_annotated_dataset,
)
from batdetect2.data.annotations.types import AnnotatedDataset

__all__ = [
    "load_annotated_dataset",
    "AnnotatedDataset",
    "AOEFAnnotations",
    "BatDetect2FilesAnnotations",
    "BatDetect2MergedAnnotations",
    "AnnotationFormats",
]


AnnotationFormats = Union[
    BatDetect2MergedAnnotations,
    BatDetect2FilesAnnotations,
    AOEFAnnotations,
]


def load_annotated_dataset(
    dataset: AnnotatedDataset,
    base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
    if isinstance(dataset, AOEFAnnotations):
        return load_aoef_annotated_dataset(dataset, base_dir=base_dir)

    if isinstance(dataset, BatDetect2MergedAnnotations):
        return load_batdetect2_merged_annotated_dataset(
            dataset, base_dir=base_dir
        )

    if isinstance(dataset, BatDetect2FilesAnnotations):
        return load_batdetect2_files_annotated_dataset(
            dataset,
            base_dir=base_dir,
        )

    raise NotImplementedError(f"Unknown annotation format: {dataset.name}")
37  batdetect2/data/annotations/aeof.py  Normal file

@@ -0,0 +1,37 @@
from pathlib import Path
from typing import Literal, Optional

from soundevent import data, io

from batdetect2.data.annotations.types import AnnotatedDataset

__all__ = [
    "AOEFAnnotations",
    "load_aoef_annotated_dataset",
]


class AOEFAnnotations(AnnotatedDataset):
    format: Literal["aoef"] = "aoef"
    annotations_path: Path


def load_aoef_annotated_dataset(
    dataset: AOEFAnnotations,
    base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
    audio_dir = dataset.audio_dir
    path = dataset.annotations_path

    if base_dir:
        audio_dir = base_dir / audio_dir
        path = base_dir / path

    loaded = io.load(path, audio_dir=audio_dir)

    if not isinstance(loaded, (data.AnnotationSet, data.AnnotationProject)):
        raise ValueError(
            f"The AOEF file at {path} does not contain a set of annotations"
        )

    return loaded
80  batdetect2/data/annotations/batdetect2_files.py  Normal file

@@ -0,0 +1,80 @@
import os
from pathlib import Path
from typing import Literal, Optional, Union

from soundevent import data

from batdetect2.data.annotations.legacy import (
    file_annotation_to_annotation_task,
    file_annotation_to_clip,
    file_annotation_to_clip_annotation,
    list_file_annotations,
    load_file_annotation,
)
from batdetect2.data.annotations.types import AnnotatedDataset

PathLike = Union[Path, str, os.PathLike]


__all__ = [
    "load_batdetect2_files_annotated_dataset",
    "BatDetect2FilesAnnotations",
]


class BatDetect2FilesAnnotations(AnnotatedDataset):
    format: Literal["batdetect2"] = "batdetect2"
    annotations_dir: Path


def load_batdetect2_files_annotated_dataset(
    dataset: BatDetect2FilesAnnotations,
    base_dir: Optional[PathLike] = None,
) -> data.AnnotationProject:
    """Convert annotations to annotation project."""
    audio_dir = dataset.audio_dir
    path = dataset.annotations_dir

    if base_dir:
        audio_dir = base_dir / audio_dir
        path = base_dir / path

    paths = list_file_annotations(path)

    annotations = []
    tasks = []

    for p in paths:
        try:
            file_annotation = load_file_annotation(p)
        except FileNotFoundError:
            continue

        try:
            clip = file_annotation_to_clip(
                file_annotation,
                audio_dir=audio_dir,
            )
        except FileNotFoundError:
            continue

        annotations.append(
            file_annotation_to_clip_annotation(
                file_annotation,
                clip,
            )
        )

        tasks.append(
            file_annotation_to_annotation_task(
                file_annotation,
                clip,
            )
        )

    return data.AnnotationProject(
        name=dataset.name,
        description=dataset.description,
        clip_annotations=annotations,
        tasks=tasks,
    )
64  batdetect2/data/annotations/batdetect2_merged.py  Normal file

@@ -0,0 +1,64 @@
import json
import os
from pathlib import Path
from typing import Literal, Optional, Union

from soundevent import data

from batdetect2.data.annotations.legacy import (
    FileAnnotation,
    file_annotation_to_annotation_task,
    file_annotation_to_clip,
    file_annotation_to_clip_annotation,
)
from batdetect2.data.annotations.types import AnnotatedDataset

PathLike = Union[Path, str, os.PathLike]

__all__ = [
    "BatDetect2MergedAnnotations",
    "load_batdetect2_merged_annotated_dataset",
]


class BatDetect2MergedAnnotations(AnnotatedDataset):
    format: Literal["batdetect2_file"] = "batdetect2_file"
    annotations_path: Path


def load_batdetect2_merged_annotated_dataset(
    dataset: BatDetect2MergedAnnotations,
    base_dir: Optional[PathLike] = None,
) -> data.AnnotationProject:
    audio_dir = dataset.audio_dir
    path = dataset.annotations_path

    if base_dir:
        audio_dir = base_dir / audio_dir
        path = base_dir / path

    content = json.loads(Path(path).read_text())

    annotations = []
    tasks = []

    for ann in content:
        try:
            ann = FileAnnotation.model_validate(ann)
        except ValueError:
            continue

        try:
            clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
        except FileNotFoundError:
            continue

        annotations.append(file_annotation_to_clip_annotation(ann, clip))
        tasks.append(file_annotation_to_annotation_task(ann, clip))

    return data.AnnotationProject(
        name=dataset.name,
        description=dataset.description,
        clip_annotations=annotations,
        tasks=tasks,
    )
304  batdetect2/data/annotations/legacy.py  Normal file

@@ -0,0 +1,304 @@
"""Compatibility functions between old and new data structures."""

import os
import uuid
from pathlib import Path
from typing import Callable, List, Optional, Union

import numpy as np
from pydantic import BaseModel, Field
from soundevent import data
from soundevent.geometry import compute_bounds
from soundevent.types import ClassMapper

from batdetect2 import types

PathLike = Union[Path, str, os.PathLike]

__all__ = [
    "convert_to_annotation_group",
]

SPECIES_TAG_KEY = "species"
ECHOLOCATION_EVENT = "Echolocation"
UNKNOWN_CLASS = "__UNKNOWN__"

NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")


EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]

ClassFn = Callable[[data.Recording], int]

IndividualFn = Callable[[data.SoundEventAnnotation], int]


def get_recording_class_name(recording: data.Recording) -> str:
    """Get the class name for a recording."""
    tag = data.find_tag(recording.tags, SPECIES_TAG_KEY)
    if tag is None:
        return UNKNOWN_CLASS
    return tag.value


def get_annotation_notes(annotation: data.ClipAnnotation) -> str:
    """Get the notes for a ClipAnnotation."""
    all_notes = [
        *annotation.notes,
        *annotation.clip.recording.notes,
    ]
    messages = [note.message for note in all_notes if note.message is not None]
    return "\n".join(messages)


def convert_to_annotation_group(
    annotation: data.ClipAnnotation,
    class_mapper: ClassMapper,
    event_fn: EventFn = lambda _: ECHOLOCATION_EVENT,
    class_fn: ClassFn = lambda _: 0,
    individual_fn: IndividualFn = lambda _: 0,
) -> types.AudioLoaderAnnotationGroup:
    """Convert a ClipAnnotation to an AudioLoaderAnnotationGroup."""
    recording = annotation.clip.recording

    start_times = []
    end_times = []
    low_freqs = []
    high_freqs = []
    class_ids = []
    x_inds = []
    y_inds = []
    individual_ids = []
    annotations: List[types.Annotation] = []
    class_id_file = class_fn(recording)

    for sound_event in annotation.sound_events:
        geometry = sound_event.sound_event.geometry

        if geometry is None:
            continue

        start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
        class_id = class_mapper.transform(sound_event) or -1
        event = event_fn(sound_event) or ""
        individual_id = individual_fn(sound_event) or -1

        start_times.append(start_time)
        end_times.append(end_time)
        low_freqs.append(low_freq)
        high_freqs.append(high_freq)
        class_ids.append(class_id)
        individual_ids.append(individual_id)

        # NOTE: This will be computed later so we just put a placeholder
        # here for now.
        x_inds.append(0)
        y_inds.append(0)

        annotations.append(
            {
                "start_time": start_time,
                "end_time": end_time,
                "low_freq": low_freq,
                "high_freq": high_freq,
                "class_prob": 1.0,
                "det_prob": 1.0,
                "individual": "0",
                "event": event,
                "class_id": class_id,  # type: ignore
            }
        )

    return {
        "id": str(recording.path),
        "duration": recording.duration,
        "issues": False,
        "file_path": str(recording.path),
        "time_exp": recording.time_expansion,
        "class_name": get_recording_class_name(recording),
        "notes": get_annotation_notes(annotation),
        "annotated": True,
        "start_times": np.array(start_times),
        "end_times": np.array(end_times),
        "low_freqs": np.array(low_freqs),
        "high_freqs": np.array(high_freqs),
        "class_ids": np.array(class_ids),
        "x_inds": np.array(x_inds),
        "y_inds": np.array(y_inds),
        "individual_ids": np.array(individual_ids),
        "annotation": annotations,
        "class_id_file": class_id_file,
    }


class Annotation(BaseModel):
    """Annotation class to hold batdetect annotations."""

    label: str = Field(alias="class")
    event: str
    individual: int = 0

    start_time: float
    end_time: float
    low_freq: float
    high_freq: float


class FileAnnotation(BaseModel):
    """FileAnnotation class to hold batdetect annotations for a file."""

    id: str
    duration: float
    time_exp: float = 1

    label: str = Field(alias="class_name")

    annotation: List[Annotation]

    annotated: bool = False
    issues: bool = False
    notes: str = ""


def load_file_annotation(path: PathLike) -> FileAnnotation:
    """Load annotation from batdetect format."""
    path = Path(path)
    return FileAnnotation.model_validate_json(path.read_text())


def annotation_to_sound_event(
    annotation: Annotation,
    recording: data.Recording,
    label_key: str = "class",
    event_key: str = "event",
    individual_key: str = "individual",
) -> data.SoundEventAnnotation:
    """Convert annotation to sound event annotation."""
    sound_event = data.SoundEvent(
        uuid=uuid.uuid5(
            NAMESPACE,
            f"{recording.hash}_{annotation.start_time}_{annotation.end_time}",
        ),
        recording=recording,
        geometry=data.BoundingBox(
            coordinates=[
                annotation.start_time,
                annotation.low_freq,
                annotation.end_time,
                annotation.high_freq,
            ],
        ),
    )

    return data.SoundEventAnnotation(
        uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
        sound_event=sound_event,
        tags=[
            data.Tag(
                term=data.term_from_key(label_key),
                value=annotation.label,
            ),
            data.Tag(
                term=data.term_from_key(event_key),
                value=annotation.event,
            ),
            data.Tag(
                term=data.term_from_key(individual_key),
                value=str(annotation.individual),
            ),
        ],
    )


def file_annotation_to_clip(
    file_annotation: FileAnnotation,
    audio_dir: Optional[PathLike] = None,
    label_key: str = "class",
) -> data.Clip:
    """Convert file annotation to recording."""
    audio_dir = audio_dir or Path.cwd()

    full_path = Path(audio_dir) / file_annotation.id

    if not full_path.exists():
        raise FileNotFoundError(f"File {full_path} not found.")

    recording = data.Recording.from_file(
        full_path,
        time_expansion=file_annotation.time_exp,
        tags=[
            data.Tag(
                term=data.term_from_key(label_key),
                value=file_annotation.label,
            )
        ],
    )

    return data.Clip(
        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip"),
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )


def file_annotation_to_clip_annotation(
    file_annotation: FileAnnotation,
    clip: data.Clip,
    label_key: str = "class",
    event_key: str = "event",
    individual_key: str = "individual",
) -> data.ClipAnnotation:
    """Convert file annotation to clip annotation."""
    notes = []
    if file_annotation.notes:
        notes.append(data.Note(message=file_annotation.notes))

    return data.ClipAnnotation(
        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
        clip=clip,
        notes=notes,
        tags=[
            data.Tag(
                term=data.term_from_key(label_key), value=file_annotation.label
            )
        ],
        sound_events=[
            annotation_to_sound_event(
                annotation,
                clip.recording,
                label_key=label_key,
                event_key=event_key,
                individual_key=individual_key,
            )
            for annotation in file_annotation.annotation
        ],
    )


def file_annotation_to_annotation_task(
    file_annotation: FileAnnotation,
    clip: data.Clip,
) -> data.AnnotationTask:
    status_badges = []

    if file_annotation.issues:
        status_badges.append(
            data.StatusBadge(state=data.AnnotationState.rejected)
        )
    elif file_annotation.annotated:
        status_badges.append(
            data.StatusBadge(state=data.AnnotationState.completed)
        )

    return data.AnnotationTask(
        uuid=uuid.uuid5(uuid.NAMESPACE_URL, f"{file_annotation.id}_task"),
        clip=clip,
        status_badges=status_badges,
    )


def list_file_annotations(path: PathLike) -> List[Path]:
    """List all annotations in a directory."""
    path = Path(path)
    return [file for file in path.glob("*.json")]
41  batdetect2/data/annotations/types.py  Normal file

@@ -0,0 +1,41 @@
from pathlib import Path
from typing import Literal, Union

from batdetect2.configs import BaseConfig

__all__ = [
    "AnnotatedDataset",
    "BatDetect2MergedAnnotations",
]


class AnnotatedDataset(BaseConfig):
    """Represents a single, cohesive source of audio recordings and annotations.

    A source typically groups recordings originating from a specific context,
    such as a single project, site, deployment, or recordist. All audio files
    belonging to a source should be located within a single directory,
    specified by `audio_dir`.

    Annotations associated with these recordings are defined by the
    `annotations` field, which supports various formats (e.g., AOEF files,
    specific CSV structures).
    Crucially, file paths referenced within the annotation data *must* be
    relative to the `audio_dir`. This ensures that the dataset definition
    remains portable across different systems and base directories.

    Attributes:
        name: A unique identifier for this data source.
        description: Detailed information about the source, including recording
            methods, annotation procedures, equipment used, potential biases,
            or any important caveats for users.
        audio_dir: The file system path to the directory containing the audio
            recordings for this source.
    """

    name: str
    audio_dir: Path
    description: str = ""
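The docstring above describes a source in the abstract; a short sketch of declaring one with the AOEF format added in this changeset (the paths are invented):

```python
from pathlib import Path

from batdetect2.data.annotations import AOEFAnnotations

source = AOEFAnnotations(
    name="uk-2023-survey",
    description="Example survey recordings.",
    audio_dir=Path("audio/uk_2023"),
    annotations_path=Path("annotations/uk_2023.soundevent.json"),
)
```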
37  batdetect2/data/data.py  Normal file

@@ -0,0 +1,37 @@
from pathlib import Path
from typing import Optional

from soundevent import data

from batdetect2.configs import load_config
from batdetect2.data.annotations import load_annotated_dataset
from batdetect2.data.types import Dataset

__all__ = [
    "load_dataset",
    "load_dataset_from_config",
]


def load_dataset(
    dataset: Dataset,
    base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
    clip_annotations = []
    for source in dataset.sources:
        annotated_source = load_annotated_dataset(source, base_dir=base_dir)
        clip_annotations.extend(annotated_source.clip_annotations)
    return data.AnnotationSet(clip_annotations=clip_annotations)


def load_dataset_from_config(
    path: data.PathLike,
    field: Optional[str] = None,
    base_dir: Optional[Path] = None,
):
    config = load_config(
        path=path,
        schema=Dataset,
        field=field,
    )
    return load_dataset(config, base_dir=base_dir)
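A usage sketch of `load_dataset_from_config`; the file name, nested field, and base directory are assumptions for illustration:

```python
from pathlib import Path

from batdetect2.data import load_dataset_from_config

annotations = load_dataset_from_config(
    "datasets.yaml",
    field="datasets.train",
    base_dir=Path("/data/bats"),
)
print(len(annotations.clip_annotations))
```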
@@ -1,33 +0,0 @@
-from typing import Callable, Generic, Iterable, List, TypeVar
-
-from soundevent import data
-from torch.utils.data import Dataset
-
-__all__ = [
-    "ClipDataset",
-]
-
-
-E = TypeVar("E")
-
-
-class ClipDataset(Dataset, Generic[E]):
-    clips: List[data.Clip]
-
-    transform: Callable[[data.Clip], E]
-
-    def __init__(
-        self,
-        clips: Iterable[data.Clip],
-        transform: Callable[[data.Clip], E],
-        name: str = "ClipDataset",
-    ):
-        self.clips = list(clips)
-        self.transform = transform
-        self.name = name
-
-    def __len__(self) -> int:
-        return len(self.clips)
-
-    def __getitem__(self, idx: int) -> E:
-        return self.transform(self.clips[idx])
@ -1,392 +0,0 @@
|
|||||||
"""Module containing functions for preprocessing audio clips."""
|
|
||||||
|
|
||||||
from typing import Optional, Union
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import librosa
|
|
||||||
import librosa.core.spectrum
|
|
||||||
import numpy as np
|
|
||||||
import xarray as xr
|
|
||||||
from numpy.typing import DTypeLike
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from scipy.signal import resample_poly
|
|
||||||
from soundevent import audio, data, arrays
|
|
||||||
from soundevent.arrays import operations as ops
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"PreprocessingConfig",
|
|
||||||
"preprocess_audio_clip",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
TARGET_SAMPLERATE_HZ = 256000
|
|
||||||
SCALE_RAW_AUDIO = False
|
|
||||||
FFT_WIN_LENGTH_S = 512 / 256000.0
|
|
||||||
FFT_OVERLAP = 0.75
|
|
||||||
MAX_FREQ_HZ = 120000
|
|
||||||
MIN_FREQ_HZ = 10000
|
|
||||||
DEFAULT_DURATION = 1
|
|
||||||
SPEC_HEIGHT = 128
|
|
||||||
SPEC_WIDTH = 256
|
|
||||||
SPEC_SCALE = "pcen"
|
|
||||||
SPEC_TIME_PERIOD = DEFAULT_DURATION / SPEC_WIDTH
|
|
||||||
DENOISE_SPEC_AVG = True
|
|
||||||
MAX_SCALE_SPEC = False
|
|
||||||
|
|
||||||
|
|
||||||
class PreprocessingConfig(BaseModel):
    """Configuration for preprocessing data."""

    target_samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)

    scale_audio: bool = Field(default=SCALE_RAW_AUDIO)

    fft_win_length: float = Field(default=FFT_WIN_LENGTH_S, gt=0)

    fft_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)

    max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)

    min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)

    spec_scale: str = Field(default=SPEC_SCALE)

    denoise_spec_avg: bool = DENOISE_SPEC_AVG

    max_scale_spec: bool = MAX_SCALE_SPEC

    duration: Optional[float] = DEFAULT_DURATION

    spec_height: int = SPEC_HEIGHT

    spec_time_period: float = SPEC_TIME_PERIOD

    @classmethod
    def from_file(
        cls,
        path: Union[str, Path],
    ) -> "PreprocessingConfig":
        """Load configuration from a file.

        Parameters
        ----------
        path
            Path to the configuration file.

        Returns
        -------
        PreprocessingConfig
            The configuration loaded from the file.

        Raises
        ------
        FileNotFoundError
            If the configuration file does not exist.
        pydantic.ValidationError
            If the configuration file is invalid.
        """
        path = Path(path)

        if not path.is_file():
            raise FileNotFoundError(f"Config file not found: {path}")

        return cls.model_validate_json(path.read_text())

    def to_file(self, path: Union[str, Path]) -> None:
        """Save configuration to a file."""
        path = Path(path)

        if not path.parent.exists():
            path.parent.mkdir(parents=True)

        path.write_text(self.model_dump_json())

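Since the class is a pydantic model, configurations round-trip through JSON with the helpers defined above; the file name here is only illustrative:

config = PreprocessingConfig(duration=0.5, spec_scale="log")
config.to_file("preprocessing.json")
reloaded = PreprocessingConfig.from_file("preprocessing.json")  # same settings back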
def preprocess_audio_clip(
    clip: data.Clip,
    config: PreprocessingConfig = PreprocessingConfig(),
) -> xr.DataArray:
    """Preprocesses audio clip to generate spectrogram.

    Parameters
    ----------
    clip
        The audio clip to preprocess.
    config
        Configuration for preprocessing.

    Returns
    -------
    xr.DataArray
        Preprocessed spectrogram.

    """
    wav = load_clip_audio(
        clip,
        target_sampling_rate=config.target_samplerate,
        scale=config.scale_audio,
    )

    spec = compute_spectrogram(
        wav,
        fft_win_length=config.fft_win_length,
        fft_overlap=config.fft_overlap,
        max_freq=config.max_freq,
        min_freq=config.min_freq,
        spec_scale=config.spec_scale,
        denoise_spec_avg=config.denoise_spec_avg,
        max_scale_spec=config.max_scale_spec,
    )

    if config.duration is not None:
        spec = adjust_spec_duration(clip, spec, config.duration)

    duration = arrays.get_dim_width(spec, dim="time")
    return ops.resize(
        spec,
        time=int(np.ceil(duration / config.spec_time_period)),
        frequency=config.spec_height,
        dtype=np.float32,
    )

def adjust_spec_duration(
    clip: data.Clip,
    spec: xr.DataArray,
    duration: float,
) -> xr.DataArray:
    current_duration = clip.end_time - clip.start_time

    if current_duration == duration:
        return spec

    if current_duration > duration:
        return arrays.crop_dim(
            spec,
            dim="time",
            start=clip.start_time,
            stop=clip.start_time + duration,
        )

    return arrays.extend_dim(
        spec,
        dim="time",
        start=clip.start_time,
        stop=clip.start_time + duration,
    )

def load_clip_audio(
    clip: data.Clip,
    target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
    scale: bool = SCALE_RAW_AUDIO,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    wav = audio.load_clip(clip).sel(channel=0).astype(dtype)

    wav = resample_audio(wav, target_sampling_rate, dtype=dtype)

    if scale:
        wav = ops.center(wav)
        wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))

    return wav.astype(dtype)

def resample_audio(
    wav: xr.DataArray,
    target_samplerate: int = TARGET_SAMPLERATE_HZ,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    if "time" not in wav.dims:
        raise ValueError("Audio must have a time dimension")

    time_axis: int = wav.get_axis_num("time")  # type: ignore

    start, stop = arrays.get_dim_range(wav, dim="time")
    step = arrays.get_dim_step(wav, dim="time")
    original_samplerate = int(1 / step)

    if original_samplerate == target_samplerate:
        return wav.astype(dtype)

    gcd = np.gcd(original_samplerate, target_samplerate)
    resampled = resample_poly(
        wav.values,
        target_samplerate // gcd,
        original_samplerate // gcd,
        axis=time_axis,
    )

    resampled_times = np.linspace(
        start,
        stop + step,
        len(resampled),
        endpoint=False,
        dtype=dtype,
    )

    return xr.DataArray(
        data=resampled.astype(dtype),
        dims=wav.dims,
        coords={
            **wav.coords,
            "time": arrays.create_time_dim_from_array(
                resampled_times,
                samplerate=target_samplerate,
            ),
        },
        attrs=wav.attrs,
    )

def compute_spectrogram(
    wav: xr.DataArray,
    fft_win_length: float = FFT_WIN_LENGTH_S,
    fft_overlap: float = FFT_OVERLAP,
    max_freq: int = MAX_FREQ_HZ,
    min_freq: int = MIN_FREQ_HZ,
    spec_scale: str = SPEC_SCALE,
    denoise_spec_avg: bool = True,
    max_scale_spec: bool = False,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    spec = gen_mag_spectrogram(
        wav,
        window_len=fft_win_length,
        overlap_perc=fft_overlap,
        dtype=dtype,
    )

    spec = arrays.crop_dim(
        spec,
        dim="frequency",
        start=min_freq,
        stop=max_freq,
    ).astype(dtype)

    spec = scale_spectrogram(spec, scale=spec_scale)

    if denoise_spec_avg:
        spec = denoise_spectrogram(spec)

    if max_scale_spec:
        spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))

    return spec.astype(dtype)

def gen_mag_spectrogram(
    wave: xr.DataArray,
    window_len: float,
    overlap_perc: float,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    start_time, end_time = arrays.get_dim_range(wave, dim="time")
    step = arrays.get_dim_step(wave, dim="time")
    sampling_rate = 1 / step

    hop_len = window_len * (1 - overlap_perc)
    nfft = int(window_len * sampling_rate)
    noverlap = int(overlap_perc * nfft)

    # compute spec
    spec, _ = librosa.core.spectrum._spectrogram(
        y=wave.data,
        power=1,
        n_fft=nfft,
        hop_length=nfft - noverlap,
        center=False,
    )

    return xr.DataArray(
        data=spec.astype(dtype),
        dims=["frequency", "time"],
        coords={
            "frequency": arrays.create_frequency_dim_from_array(
                np.linspace(
                    0,
                    sampling_rate / 2,
                    spec.shape[0],
                    endpoint=False,
                    dtype=dtype,
                ),
                step=sampling_rate / nfft,
            ),
            "time": arrays.create_time_dim_from_array(
                np.linspace(
                    start_time,
                    end_time - (window_len - hop_len),
                    spec.shape[1],
                    endpoint=False,
                    dtype=dtype,
                ),
                step=hop_len,
            ),
        },
        attrs={
            **wave.attrs,
            "original_samplerate": sampling_rate,
            "nfft": nfft,
            "noverlap": noverlap,
        },
    )

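For orientation, the default constants in this module work out to the following STFT parameters; this is plain arithmetic on the values defined above, not extra configuration:

# With TARGET_SAMPLERATE_HZ = 256000 and FFT_WIN_LENGTH_S = 512 / 256000:
nfft = int(FFT_WIN_LENGTH_S * TARGET_SAMPLERATE_HZ)  # 512 samples (a 2 ms window)
noverlap = int(FFT_OVERLAP * nfft)                   # 384 samples
hop = nfft - noverlap                                # 128 samples, i.e. 0.5 ms per frame
freq_resolution = TARGET_SAMPLERATE_HZ / nfft        # 500 Hz per frequency bin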
def denoise_spectrogram(
    spec: xr.DataArray,
) -> xr.DataArray:
    return xr.DataArray(
        data=(spec - spec.mean("time")).clip(0),
        dims=spec.dims,
        coords=spec.coords,
        attrs=spec.attrs,
    )

def scale_spectrogram(
    spec: xr.DataArray,
    scale: str = SPEC_SCALE,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    samplerate = spec.attrs["original_samplerate"]

    if scale == "pcen":
        smoothing_constant = get_pcen_smoothing_constant(samplerate / 10)
        return audio.pcen(
            spec * (2**31),
            smooth=smoothing_constant,
        ).astype(dtype)

    if scale == "log":
        return log_scale(spec, dtype=dtype)

    return spec

def log_scale(
    spec: xr.DataArray,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    samplerate = spec.attrs["original_samplerate"]
    nfft = spec.attrs["nfft"]
    log_scaling = (
        2.0
        * (1.0 / samplerate)
        * (1.0 / (np.abs(np.hanning(nfft)) ** 2).sum())
    )
    return xr.DataArray(
        data=np.log1p(log_scaling * spec).astype(dtype),
        dims=spec.dims,
        coords=spec.coords,
        attrs=spec.attrs,
    )

def get_pcen_smoothing_constant(
    sr: int,
    time_constant: float = 0.4,
    hop_length: int = 512,
) -> float:
    t_frames = time_constant * sr / float(hop_length)
    return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
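A minimal usage sketch for the pipeline removed above; `clip` stands in for any soundevent `data.Clip`, and the shapes quoted assume the default configuration:

config = PreprocessingConfig()  # 256 kHz, PCEN scaling, denoising on
spec = preprocess_audio_clip(clip, config=config)
# `spec` is an xarray.DataArray with dims ("frequency", "time"); with the
# defaults a 1 s clip is resized to spec_height = 128 frequency bins and
# ceil(1 / SPEC_TIME_PERIOD) = 256 time bins.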
29 batdetect2/data/types.py Normal file
@@ -0,0 +1,29 @@
from typing import Annotated, List

from pydantic import Field

from batdetect2.configs import BaseConfig
from batdetect2.data.annotations import AnnotationFormats


class Dataset(BaseConfig):
    """Represents a collection of one or more DatasetSources.

    In the context of batdetect2, a Dataset aggregates multiple `DatasetSource`
    instances. It serves as the primary unit for defining data splits,
    typically used for model training, validation, or testing phases.

    Attributes:
        name: A descriptive name for the overall dataset
            (e.g., "UK Training Set").
        description: A detailed explanation of the dataset's purpose,
            composition, how it was assembled, or any specific characteristics.
        sources: A list containing the `DatasetSource` objects included in this
            dataset.
    """

    name: str
    description: str
    sources: List[
        Annotated[AnnotationFormats, Field(..., discriminator="format")]
    ]
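The `sources` field is a pydantic discriminated union: each entry is dispatched to the annotation-format class whose `format` literal matches. A generic sketch of that mechanism, with two hypothetical classes that are not batdetect2's own formats:

from typing import Annotated, List, Literal, Union

from pydantic import BaseModel, Field


class FormatA(BaseModel):
    format: Literal["format_a"] = "format_a"
    path: str


class FormatB(BaseModel):
    format: Literal["format_b"] = "format_b"
    path: str


class Collection(BaseModel):
    sources: List[
        Annotated[Union[FormatA, FormatB], Field(discriminator="format")]
    ]


collection = Collection.model_validate(
    {"sources": [{"format": "format_b", "path": "annotations.json"}]}
)
# collection.sources[0] is parsed as a FormatB instance.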
@@ -1,4 +1,5 @@
"""Functions to compute features from predictions."""

from typing import Dict, List, Optional

import numpy as np

@@ -219,7 +220,6 @@ def compute_call_interval(
    return round(prediction["start_time"] - previous["end_time"], 5)


# NOTE: The order of the features in this dictionary is important. The
# features are extracted in this order and the order of the columns in the
# output csv file is determined by this order. In order to avoid breaking
@@ -206,7 +206,10 @@ class Net2DFastNoAttn(nn.Module):
            num_filts // 4, 2, kernel_size=1, padding=0
        )
        self.conv_classes_op = nn.Conv2d(
-            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0,
+            num_filts // 4,
+            self.num_classes + 1,
+            kernel_size=1,
+            padding=0,
        )

        if self.emb_dim > 0:
@@ -5,7 +5,10 @@ from typing import List, Optional, Union

from pydantic import BaseModel, Field, computed_field

-from batdetect2.train.train_utils import get_genus_mapping, get_short_class_names
+from batdetect2.train.legacy.train_utils import (
+    get_genus_mapping,
+    get_short_class_names,
+)
from batdetect2.types import ProcessingConfiguration, SpectrogramParameters

TARGET_SAMPLERATE_HZ = 256000
@@ -1,4 +1,5 @@
"""Post-processing of the output of the model."""

from typing import List, Tuple, Union

import numpy as np
@@ -1,16 +1,66 @@
+from typing import List
+
import numpy as np
-from sklearn.metrics import (
-    accuracy_score,
-    auc,
-    balanced_accuracy_score,
-    roc_curve,
-)
+import pandas as pd
+from sklearn.metrics import auc, roc_curve
+from soundevent import data
+from soundevent.evaluation import match_geometries
+
+from batdetect2.train.targets import build_encoder, get_class_names
+
+
+def match_predictions_and_annotations(
+    clip_annotation: data.ClipAnnotation,
+    clip_prediction: data.ClipPrediction,
+) -> List[data.Match]:
+    annotated_sound_events = [
+        sound_event_annotation
+        for sound_event_annotation in clip_annotation.sound_events
+        if sound_event_annotation.sound_event.geometry is not None
+    ]
+
+    predicted_sound_events = [
+        sound_event_prediction
+        for sound_event_prediction in clip_prediction.sound_events
+        if sound_event_prediction.sound_event.geometry is not None
+    ]
+
+    annotated_geometries: List[data.Geometry] = [
+        sound_event.sound_event.geometry
+        for sound_event in annotated_sound_events
+        if sound_event.sound_event.geometry is not None
+    ]
+
+    predicted_geometries: List[data.Geometry] = [
+        sound_event.sound_event.geometry
+        for sound_event in predicted_sound_events
+        if sound_event.sound_event.geometry is not None
+    ]
+
+    matches = []
+    for id1, id2, affinity in match_geometries(
+        annotated_geometries,
+        predicted_geometries,
+    ):
+        target = annotated_sound_events[id1] if id1 is not None else None
+        source = predicted_sound_events[id2] if id2 is not None else None
+        matches.append(
+            data.Match(source=source, target=target, affinity=affinity)
+        )
+
+    return matches
+
+
+def build_evaluation_dataframe(matches: List[data.Match]) -> pd.DataFrame:
+    ret = []
+
+    for match in matches:
+        pass
+
+
def compute_error_auc(op_str, gt, pred, prob):

    # classification error
-    pred_int = (pred > prob).astype(np.int)
+    pred_int = (pred > prob).astype(np.int32)
    class_acc = (pred_int == gt).mean() * 100.0

    # ROC - area under curve
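A sketch of how the new matching helper is meant to be called; `clip_annotation` and `clip_prediction` stand for soundevent objects describing the same clip. Note that `build_evaluation_dataframe` is still a stub at this revision.

matches = match_predictions_and_annotations(clip_annotation, clip_prediction)
for match in matches:
    # affinity is the geometric overlap; source/target are None for
    # unmatched annotations or predictions respectively.
    print(match.affinity, match.source, match.target)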
@@ -25,7 +75,6 @@ def compute_error_auc(op_str, gt, pred, prob):


def calc_average_precision(recall, precision):

    precision[np.isnan(precision)] = 0
    recall[np.isnan(recall)] = 0

@@ -91,7 +140,6 @@ def compute_pre_rec(
    pred_class = []
    file_ids = []
    for pid, pp in enumerate(preds):

        # filter predicted calls that are too near the start or end of the file
        file_dur = gts[pid]["duration"]
        valid_inds = (pp["start_times"] >= ignore_start_end) & (

@@ -141,7 +189,6 @@ def compute_pre_rec(
    gt_generic_class = []
    num_positives = 0
    for gg in gts:

        # filter ground truth calls that are too near the start or end of the file
        file_dur = gg["duration"]
        valid_inds = (gg["start_times"] >= ignore_start_end) & (

@@ -205,7 +252,6 @@ def compute_pre_rec(

        # valid detection that has not already been assigned
        if valid_det and (gt_assigned[gt_id][det_ind] == 0):

            count_as_true_pos = True
            if eval_mode == "top_class" and (
                gt_class[gt_id][det_ind] != pred_class[ind]
@@ -12,15 +12,14 @@ import pandas as pd
import torch
from sklearn.ensemble import RandomForestClassifier

-import batdetect2.train.evaluate as evl
-import batdetect2.train.train_utils as tu
+import batdetect2.evaluate.legacy.evaluate_models as evl
+import batdetect2.train.legacy.train_utils as tu
import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as pu
from batdetect2.detector import parameters


def get_blank_annotation(ip_str):

    res = {}
    res["class_name"] = ""
    res["duration"] = -1

@@ -77,7 +76,6 @@ def create_genus_mapping(gt_test, preds, class_names):


def load_tadarida_pred(ip_dir, dataset, file_of_interest):

    res, ann = get_blank_annotation("Generated by Tadarida")

    # create the annotations in the correct format

@@ -120,7 +118,6 @@ def load_sonobat_meta(
    class_names,
    only_accepted_species=True,
):

    sp_dict = {}
    for ss in class_names:
        sp_key = ss.split(" ")[0][:3] + ss.split(" ")[1][:3]

@@ -182,7 +179,6 @@ def load_sonobat_meta(


def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):

    # create the annotations in the correct format
    res, ann = get_blank_annotation("Generated by Sonobat")
    res_c = copy.deepcopy(res)

@@ -221,7 +217,6 @@ def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):


def bb_overlap(bb_g_in, bb_p_in):

    freq_scale = 10000000.0  # ensure that both axis are roughly the same range
    bb_g = [
        bb_g_in["start_time"],

@@ -465,7 +460,6 @@ def check_classes_in_train(gt_list, class_names):


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "op_dir",
@@ -8,10 +8,10 @@ import torch.utils.data
from torch.optim.lr_scheduler import CosineAnnealingLR

import batdetect2.detector.parameters as parameters
-import batdetect2.train.audio_dataloader as adl
+import batdetect2.train.legacy.audio_dataloader as adl
+import batdetect2.train.legacy.train_model as tm
+import batdetect2.train.legacy.train_utils as tu
import batdetect2.train.losses as losses
-import batdetect2.train.train_model as tm
-import batdetect2.train.train_utils as tu
import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as pu
from batdetect2 import types
@@ -1,11 +1,92 @@
-from batdetect2.models.feature_extractors import (
+from enum import Enum
+from typing import Optional, Tuple
+
+from soundevent.data import PathLike
+
+from batdetect2.configs import BaseConfig, load_config
+from batdetect2.models.backbones import (
    Net2DFast,
    Net2DFastNoAttn,
    Net2DFastNoCoordConv,
+    Net2DPlain,
)
+from batdetect2.models.heads import BBoxHead, ClassifierHead
+from batdetect2.models.typing import BackboneModel

__all__ = [
+    "BBoxHead",
+    "ClassifierHead",
+    "ModelConfig",
+    "ModelType",
    "Net2DFast",
    "Net2DFastNoAttn",
    "Net2DFastNoCoordConv",
+    "build_architecture",
+    "load_model_config",
]
+
+
+class ModelType(str, Enum):
+    Net2DFast = "Net2DFast"
+    Net2DFastNoAttn = "Net2DFastNoAttn"
+    Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
+    Net2DPlain = "Net2DPlain"
+
+
+class ModelConfig(BaseConfig):
+    name: ModelType = ModelType.Net2DFast
+    input_height: int = 128
+    encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
+    bottleneck_channels: int = 256
+    decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
+    out_channels: int = 32
+
+
+def load_model_config(
+    path: PathLike, field: Optional[str] = None
+) -> ModelConfig:
+    return load_config(path, schema=ModelConfig, field=field)
+
+
+def build_architecture(
+    config: Optional[ModelConfig] = None,
+) -> BackboneModel:
+    config = config or ModelConfig()
+
+    if config.name == ModelType.Net2DFast:
+        return Net2DFast(
+            input_height=config.input_height,
+            encoder_channels=config.encoder_channels,
+            bottleneck_channels=config.bottleneck_channels,
+            decoder_channels=config.decoder_channels,
+            out_channels=config.out_channels,
+        )
+
+    if config.name == ModelType.Net2DFastNoAttn:
+        return Net2DFastNoAttn(
+            input_height=config.input_height,
+            encoder_channels=config.encoder_channels,
+            bottleneck_channels=config.bottleneck_channels,
+            decoder_channels=config.decoder_channels,
+            out_channels=config.out_channels,
+        )
+
+    if config.name == ModelType.Net2DFastNoCoordConv:
+        return Net2DFastNoCoordConv(
+            input_height=config.input_height,
+            encoder_channels=config.encoder_channels,
+            bottleneck_channels=config.bottleneck_channels,
+            decoder_channels=config.decoder_channels,
+            out_channels=config.out_channels,
+        )
+
+    if config.name == ModelType.Net2DPlain:
+        return Net2DPlain(
+            input_height=config.input_height,
+            encoder_channels=config.encoder_channels,
+            bottleneck_channels=config.bottleneck_channels,
+            decoder_channels=config.decoder_channels,
+            out_channels=config.out_channels,
+        )
+
+    raise ValueError(f"Unknown model type: {config.name}")
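A short sketch of how the new config and factory fit together; the file path in the commented lines is illustrative only:

config = ModelConfig(name=ModelType.Net2DFastNoAttn, input_height=128)
backbone = build_architecture(config)

# or, loading the configuration from a file first:
# config = load_model_config("model.yaml", field="model")
# backbone = build_architecture(config)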
185 batdetect2/models/backbones.py Normal file
@@ -0,0 +1,185 @@
from typing import Sequence, Tuple

import torch
import torch.nn.functional as F

from batdetect2.models.blocks import (
    ConvBlock,
    Decoder,
    DownscalingLayer,
    Encoder,
    SelfAttention,
    UpscalingLayer,
    VerticalConv,
)
from batdetect2.models.typing import BackboneModel

__all__ = [
    "Net2DFast",
    "Net2DFastNoAttn",
    "Net2DFastNoCoordConv",
]


class Net2DPlain(BackboneModel):
    downscaling_layer_type: DownscalingLayer = "ConvBlockDownStandard"
    upscaling_layer_type: UpscalingLayer = "ConvBlockUpStandard"

    def __init__(
        self,
        input_height: int = 128,
        encoder_channels: Sequence[int] = (1, 32, 64, 128),
        bottleneck_channels: int = 256,
        decoder_channels: Sequence[int] = (256, 64, 32, 32),
        out_channels: int = 32,
    ):
        super().__init__()

        self.input_height = input_height
        self.encoder_channels = tuple(encoder_channels)
        self.decoder_channels = tuple(decoder_channels)
        self.out_channels = out_channels

        if len(encoder_channels) != len(decoder_channels):
            raise ValueError(
                f"Mismatched encoder and decoder channel lists. "
                f"The encoder has {len(encoder_channels)} channels "
                f"(implying {len(encoder_channels) - 1} layers), "
                f"while the decoder has {len(decoder_channels)} channels "
                f"(implying {len(decoder_channels) - 1} layers). "
                f"These lengths must be equal."
            )

        self.divide_factor = 2 ** (len(encoder_channels) - 1)
        if self.input_height % self.divide_factor != 0:
            raise ValueError(
                f"Input height ({self.input_height}) must be divisible by "
                f"the divide factor ({self.divide_factor}). "
                f"This ensures proper upscaling after downscaling to recover "
                f"the original input height."
            )

        self.encoder = Encoder(
            channels=encoder_channels,
            input_height=self.input_height,
            layer_type=self.downscaling_layer_type,
        )

        self.conv_same_1 = ConvBlock(
            in_channels=encoder_channels[-1],
            out_channels=bottleneck_channels,
        )

        # bottleneck
        self.conv_vert = VerticalConv(
            in_channels=bottleneck_channels,
            out_channels=bottleneck_channels,
            input_height=self.input_height // (2**self.encoder.depth),
        )

        self.decoder = Decoder(
            channels=decoder_channels,
            input_height=self.input_height,
            layer_type=self.upscaling_layer_type,
        )

        self.conv_same_2 = ConvBlock(
            in_channels=decoder_channels[-1],
            out_channels=out_channels,
        )

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)

        # encoder
        residuals = self.encoder(spec)
        residuals[-1] = self.conv_same_1(residuals[-1])

        # bottleneck
        x = self.conv_vert(residuals[-1])
        x = x.repeat([1, 1, residuals[-1].shape[-2], 1])

        # decoder
        x = self.decoder(x, residuals=residuals)

        # Restore original size
        x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)

        return self.conv_same_2(x)


class Net2DFast(Net2DPlain):
    downscaling_layer_type = "ConvBlockDownCoordF"
    upscaling_layer_type = "ConvBlockUpF"

    def __init__(
        self,
        input_height: int = 128,
        encoder_channels: Sequence[int] = (1, 32, 64, 128),
        bottleneck_channels: int = 256,
        decoder_channels: Sequence[int] = (256, 64, 32, 32),
        out_channels: int = 32,
    ):
        super().__init__(
            input_height=input_height,
            encoder_channels=encoder_channels,
            bottleneck_channels=bottleneck_channels,
            decoder_channels=decoder_channels,
            out_channels=out_channels,
        )

        self.att = SelfAttention(bottleneck_channels, bottleneck_channels)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)

        # encoder
        residuals = self.encoder(spec)
        residuals[-1] = self.conv_same_1(residuals[-1])

        # bottleneck
        x = self.conv_vert(residuals[-1])
        x = self.att(x)
        x = x.repeat([1, 1, residuals[-1].shape[-2], 1])

        # decoder
        x = self.decoder(x, residuals=residuals)

        # Restore original size
        x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)

        return self.conv_same_2(x)


class Net2DFastNoAttn(Net2DPlain):
    downscaling_layer_type = "ConvBlockDownCoordF"
    upscaling_layer_type = "ConvBlockUpF"


class Net2DFastNoCoordConv(Net2DFast):
    downscaling_layer_type = "ConvBlockDownStandard"
    upscaling_layer_type = "ConvBlockUpStandard"


def pad_adjust(
    spec: torch.Tensor,
    factor: int = 32,
) -> Tuple[torch.Tensor, int, int]:
    print(spec.shape)
    h, w = spec.shape[2:]
    h_pad = -h % factor
    w_pad = -w % factor
    return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad


def restore_pad(
    x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
) -> torch.Tensor:
    # Restore original size
    if h_pad > 0:
        x = x[:, :, :-h_pad, :]

    if w_pad > 0:
        x = x[:, :, :, :-w_pad]

    return x
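A shape sketch for the new backbones, assuming the blocks referenced above behave as documented: the encoder halves height and width once per layer, the decoder restores them, and padding is undone at the end, so input and output spatial sizes match.

import torch

net = Net2DFast(input_height=128)
spec = torch.randn(1, 1, 128, 512)  # (batch, channel, frequency, time)
features = net(spec)
# Expected: features.shape == (1, 32, 128, 512) with the default out_channels.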
@@ -4,18 +4,32 @@ All these classes are subclasses of `torch.nn.Module` and can be used to build
complex neural network architectures.
"""

-from typing import Tuple
+import sys
+from typing import Iterable, List, Literal, Sequence, Tuple

import torch
import torch.nn.functional as F
from torch import nn

+if sys.version_info >= (3, 10):
+    from itertools import pairwise
+else:
+
+    def pairwise(iterable: Sequence) -> Iterable:
+        for x, y in zip(iterable[:-1], iterable[1:]):
+            yield x, y
+
+
__all__ = [
-    "SelfAttention",
+    "ConvBlock",
    "ConvBlockDownCoordF",
    "ConvBlockDownStandard",
    "ConvBlockUpF",
    "ConvBlockUpStandard",
+    "SelfAttention",
+    "VerticalConv",
+    "DownscalingLayer",
+    "UpscalingLayer",
]
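The `pairwise` backport (and `itertools.pairwise` on Python 3.10+) yields consecutive pairs, which is how the channel tuples further down become per-layer (in_channels, out_channels) pairs:

list(pairwise((1, 32, 64, 128)))  # -> [(1, 32), (32, 64), (64, 128)]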
@@ -25,16 +39,21 @@ class SelfAttention(nn.Module):
    This module implements self-attention mechanism.
    """

-    def __init__(self, ip_dim: int, att_dim: int):
+    def __init__(
+        self,
+        in_channels: int,
+        attention_channels: int,
+        temperature: float = 1.0,
+    ):
        super().__init__()

-        # Note, does not encode position information (absolute or realtive)
-        self.temperature = 1.0
-        self.att_dim = att_dim
-        self.key_fun = nn.Linear(ip_dim, att_dim)
-        self.val_fun = nn.Linear(ip_dim, att_dim)
-        self.que_fun = nn.Linear(ip_dim, att_dim)
-        self.pro_fun = nn.Linear(att_dim, ip_dim)
+        # Note, does not encode position information (absolute or relative)
+        self.temperature = temperature
+        self.att_dim = attention_channels
+        self.key_fun = nn.Linear(in_channels, attention_channels)
+        self.value_fun = nn.Linear(in_channels, attention_channels)
+        self.query_fun = nn.Linear(in_channels, attention_channels)
+        self.pro_fun = nn.Linear(attention_channels, in_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.squeeze(2).permute(0, 2, 1)
@@ -43,11 +62,11 @@ class SelfAttention(nn.Module):
            x, self.key_fun.weight.T
        ) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
        query = torch.matmul(
-            x, self.que_fun.weight.T
-        ) + self.que_fun.bias.unsqueeze(0).unsqueeze(0)
+            x, self.query_fun.weight.T
+        ) + self.query_fun.bias.unsqueeze(0).unsqueeze(0)
        value = torch.matmul(
-            x, self.val_fun.weight.T
-        ) + self.val_fun.bias.unsqueeze(0).unsqueeze(0)
+            x, self.value_fun.weight.T
+        ) + self.value_fun.bias.unsqueeze(0).unsqueeze(0)

        kk_qq = torch.bmm(key, query.permute(0, 2, 1)) / (
            self.temperature * self.att_dim
@@ -63,6 +82,66 @@ class SelfAttention(nn.Module):
        return op


+class ConvBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 3,
+        pad_size: int = 1,
+    ):
+        super().__init__()
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            padding=pad_size,
+        )
+        self.conv_bn = nn.BatchNorm2d(out_channels)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.relu_(self.conv_bn(self.conv(x)))
+
+
+class VerticalConv(nn.Module):
+    """Convolutional layer over full height.
+
+    This layer applies a convolution that captures information across the
+    entire height of the input image. It uses a kernel with the same height as
+    the input, effectively condensing the vertical information into a single
+    output row.
+
+    More specifically:
+
+    * **Input:** (B, C, H, W) where B is the batch size, C is the number of
+      input channels, H is the image height, and W is the image width.
+    * **Kernel:** (C', H, 1) where C' is the number of output channels.
+    * **Output:** (B, C', 1, W) - The height dimension is 1 because the
+      convolution integrates information from all rows of the input.
+
+    This process effectively extracts features that span the full height of
+    the input image.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        input_height: int,
+    ):
+        super().__init__()
+        self.conv = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=(input_height, 1),
+            padding=0,
+        )
+        self.bn = nn.BatchNorm2d(out_channels)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.relu_(self.bn(self.conv(x)))
+
+
class ConvBlockDownCoordF(nn.Module):
    """Convolutional Block with Downsampling and Coord Feature.
@@ -72,27 +151,27 @@ class ConvBlockDownCoordF(nn.Module):

    def __init__(
        self,
-        in_chn: int,
-        out_chn: int,
-        ip_height: int,
-        k_size: int = 3,
+        in_channels: int,
+        out_channels: int,
+        input_height: int,
+        kernel_size: int = 3,
        pad_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()

        self.coords = nn.Parameter(
-            torch.linspace(-1, 1, ip_height)[None, None, ..., None],
+            torch.linspace(-1, 1, input_height)[None, None, ..., None],
            requires_grad=False,
        )
        self.conv = nn.Conv2d(
-            in_chn + 1,
-            out_chn,
-            kernel_size=k_size,
+            in_channels + 1,
+            out_channels,
+            kernel_size=kernel_size,
            padding=pad_size,
            stride=stride,
        )
-        self.conv_bn = nn.BatchNorm2d(out_chn)
+        self.conv_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
@@ -110,26 +189,28 @@ class ConvBlockDownStandard(nn.Module):

    def __init__(
        self,
-        in_chn,
-        out_chn,
-        k_size=3,
-        pad_size=1,
-        stride=1,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 3,
+        pad_size: int = 1,
+        stride: int = 1,
    ):
        super(ConvBlockDownStandard, self).__init__()
        self.conv = nn.Conv2d(
-            in_chn,
-            out_chn,
-            kernel_size=k_size,
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
            padding=pad_size,
            stride=stride,
        )
-        self.conv_bn = nn.BatchNorm2d(out_chn)
+        self.conv_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = F.max_pool2d(self.conv(x), 2, 2)
-        x = F.relu(self.conv_bn(x), inplace=True)
-        return x
+        return F.relu(self.conv_bn(x), inplace=True)
+
+
+DownscalingLayer = Literal["ConvBlockDownStandard", "ConvBlockDownCoordF"]


class ConvBlockUpF(nn.Module):
@@ -141,10 +222,10 @@ class ConvBlockUpF(nn.Module):

    def __init__(
        self,
-        in_chn: int,
-        out_chn: int,
-        ip_height: int,
-        k_size: int = 3,
+        in_channels: int,
+        out_channels: int,
+        input_height: int,
+        kernel_size: int = 3,
        pad_size: int = 1,
        up_mode: str = "bilinear",
        up_scale: Tuple[int, int] = (2, 2),
@@ -154,15 +235,18 @@ class ConvBlockUpF(nn.Module):
        self.up_scale = up_scale
        self.up_mode = up_mode
        self.coords = nn.Parameter(
-            torch.linspace(-1, 1, ip_height * up_scale[0])[
+            torch.linspace(-1, 1, input_height * up_scale[0])[
                None, None, ..., None
            ],
            requires_grad=False,
        )
        self.conv = nn.Conv2d(
-            in_chn + 1, out_chn, kernel_size=k_size, padding=pad_size
+            in_channels + 1,
+            out_channels,
+            kernel_size=kernel_size,
+            padding=pad_size,
        )
-        self.conv_bn = nn.BatchNorm2d(out_chn)
+        self.conv_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        op = F.interpolate(
@@ -189,9 +273,9 @@ class ConvBlockUpStandard(nn.Module):

    def __init__(
        self,
-        in_chn: int,
-        out_chn: int,
-        k_size: int = 3,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 3,
        pad_size: int = 1,
        up_mode: str = "bilinear",
        up_scale: Tuple[int, int] = (2, 2),
@@ -200,9 +284,12 @@ class ConvBlockUpStandard(nn.Module):
        self.up_scale = up_scale
        self.up_mode = up_mode
        self.conv = nn.Conv2d(
-            in_chn, out_chn, kernel_size=k_size, padding=pad_size
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            padding=pad_size,
        )
-        self.conv_bn = nn.BatchNorm2d(out_chn)
+        self.conv_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        op = F.interpolate(
@ -217,3 +304,143 @@ class ConvBlockUpStandard(nn.Module):
|
|||||||
op = self.conv(op)
|
op = self.conv(op)
|
||||||
op = F.relu(self.conv_bn(op), inplace=True)
|
op = F.relu(self.conv_bn(op), inplace=True)
|
||||||
return op
|
return op
|
||||||
|
|
||||||
|
|
||||||
|
UpscalingLayer = Literal["ConvBlockUpStandard", "ConvBlockUpF"]
|
||||||
|
|
||||||
|
|
||||||
|
def build_downscaling_layer(
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
input_height: int,
|
||||||
|
layer_type: DownscalingLayer,
|
||||||
|
) -> nn.Module:
|
||||||
|
if layer_type == "ConvBlockDownStandard":
|
||||||
|
return ConvBlockDownStandard(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
if layer_type == "ConvBlockDownCoordF":
|
||||||
|
return ConvBlockDownCoordF(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
input_height=input_height,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid downscaling layer type {layer_type}. "
|
||||||
|
f"Valid values: ConvBlockDownCoordF, ConvBlockDownStandard"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: Sequence[int] = (1, 32, 62, 128),
|
||||||
|
input_height: int = 128,
|
||||||
|
layer_type: Literal[
|
||||||
|
"ConvBlockDownStandard", "ConvBlockDownCoordF"
|
||||||
|
] = "ConvBlockDownStandard",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.channels = channels
|
||||||
|
self.input_height = input_height
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
build_downscaling_layer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
input_height=input_height // (2**layer_num),
|
||||||
|
layer_type=layer_type,
|
||||||
|
)
|
||||||
|
for layer_num, (in_channels, out_channels) in enumerate(
|
||||||
|
pairwise(channels)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.depth = len(self.layers)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||||
|
outputs = []
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
x = layer(x)
|
||||||
|
outputs.append(x)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def build_upscaling_layer(
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
input_height: int,
|
||||||
|
layer_type: UpscalingLayer,
|
||||||
|
) -> nn.Module:
|
||||||
|
if layer_type == "ConvBlockUpStandard":
|
||||||
|
return ConvBlockUpStandard(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
if layer_type == "ConvBlockUpF":
|
||||||
|
return ConvBlockUpF(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
input_height=input_height,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid upscaling layer type {layer_type}. "
|
||||||
|
f"Valid values: ConvBlockUpStandard, ConvBlockUpF"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: Sequence[int] = (256, 62, 32, 32),
|
||||||
|
input_height: int = 128,
|
||||||
|
layer_type: Literal[
|
||||||
|
"ConvBlockUpStandard", "ConvBlockUpF"
|
||||||
|
] = "ConvBlockUpStandard",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.channels = channels
|
||||||
|
self.input_height = input_height
|
||||||
|
self.depth = len(self.channels) - 1
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
build_upscaling_layer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
input_height=input_height
|
||||||
|
// (2 ** (self.depth - layer_num)),
|
||||||
|
layer_type=layer_type,
|
||||||
|
)
|
||||||
|
for layer_num, (in_channels, out_channels) in enumerate(
|
||||||
|
pairwise(channels)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residuals: List[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if len(residuals) != len(self.layers):
|
||||||
|
raise ValueError(
|
||||||
|
f"Incorrect number of residuals provided. "
|
||||||
|
f"Expected {len(self.layers)} (matching the number of layers), "
|
||||||
|
f"but got {len(residuals)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
for layer, res in zip(self.layers, residuals[::-1]):
|
||||||
|
x = layer(x + res)
|
||||||
|
|
||||||
|
return x
|
||||||
|
15
batdetect2/models/decoder.py
Normal file
15
batdetect2/models/decoder.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import sys
|
||||||
|
from typing import Iterable, List, Literal, Sequence
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from batdetect2.models.blocks import ConvBlockUpF, ConvBlockUpStandard
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 10):
|
||||||
|
from itertools import pairwise
|
||||||
|
else:
|
||||||
|
|
||||||
|
def pairwise(iterable: Sequence) -> Iterable:
|
||||||
|
for x, y in zip(iterable[:-1], iterable[1:]):
|
||||||
|
yield x, y
|
@ -1,139 +0,0 @@
|
|||||||
from typing import Type
|
|
||||||
|
|
||||||
import pytorch_lightning as L
|
|
||||||
import torch
|
|
||||||
import xarray as xr
|
|
||||||
from soundevent import data
|
|
||||||
from torch import nn, optim
|
|
||||||
|
|
||||||
from batdetect2.data.preprocessing import (
|
|
||||||
preprocess_audio_clip,
|
|
||||||
PreprocessingConfig,
|
|
||||||
)
|
|
||||||
from batdetect2.data.labels import ClassMapper
|
|
||||||
from batdetect2.models.feature_extractors import Net2DFast
|
|
||||||
from batdetect2.models.post_process import (
|
|
||||||
PostprocessConfig,
|
|
||||||
postprocess_model_outputs,
|
|
||||||
)
|
|
||||||
from batdetect2.models.typing import FeatureExtractorModel, ModelOutput
|
|
||||||
from batdetect2.train import losses
|
|
||||||
from batdetect2.train.dataset import TrainExample
|
|
||||||
|
|
||||||
|
|
||||||
class DetectorModel(L.LightningModule):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
class_mapper: ClassMapper,
|
|
||||||
feature_extractor_class: Type[FeatureExtractorModel] = Net2DFast,
|
|
||||||
learning_rate: float = 1e-3,
|
|
||||||
input_height: int = 128,
|
|
||||||
num_features: int = 32,
|
|
||||||
preprocessing_config: PreprocessingConfig = PreprocessingConfig(),
|
|
||||||
postprocessing_config: PostprocessConfig = PostprocessConfig(),
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.save_hyperparameters()
|
|
||||||
|
|
||||||
self.preprocessing_config = preprocessing_config
|
|
||||||
self.postprocessing_config = postprocessing_config
|
|
||||||
self.class_mapper = class_mapper
|
|
||||||
self.learning_rate = learning_rate
|
|
||||||
self.input_height = input_height
|
|
||||||
self.num_features = num_features
|
|
||||||
self.num_classes = class_mapper.num_classes
|
|
||||||
|
|
||||||
self.feature_extractor = feature_extractor_class(
|
|
||||||
input_height=input_height,
|
|
||||||
num_features=num_features,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.classifier = nn.Conv2d(
|
|
||||||
self.feature_extractor.num_features // 4,
|
|
||||||
self.num_classes + 1,
|
|
||||||
kernel_size=1,
|
|
||||||
padding=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.bbox = nn.Conv2d(
|
|
||||||
self.feature_extractor.num_features // 4,
|
|
||||||
2,
|
|
||||||
kernel_size=1,
|
|
||||||
padding=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore
|
|
||||||
features = self.feature_extractor(spec)
|
|
||||||
classification_logits = self.classifier(features)
|
|
||||||
classification_probs = torch.softmax(classification_logits, dim=1)
|
|
||||||
detection_probs = classification_probs[:, :-1].sum(dim=1, keepdim=True)
|
|
||||||
return ModelOutput(
|
|
||||||
detection_probs=detection_probs,
|
|
||||||
size_preds=self.bbox(features),
|
|
||||||
class_probs=classification_probs[:, :-1],
|
|
||||||
features=features,
|
|
||||||
)
|
|
||||||
|
|
||||||
def compute_spectrogram(self, clip: data.Clip) -> xr.DataArray:
|
|
||||||
return preprocess_audio_clip(
|
|
||||||
clip,
|
|
||||||
config=self.preprocessing_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
def compute_clip_features(self, clip: data.Clip) -> torch.Tensor:
|
|
||||||
spectrogram = self.compute_spectrogram(clip)
|
|
||||||
return self.feature_extractor(
|
|
||||||
torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
|
|
||||||
)
|
|
||||||
|
|
||||||
def compute_clip_predictions(self, clip: data.Clip) -> data.ClipPrediction:
|
|
||||||
spectrogram = self.compute_spectrogram(clip)
|
|
||||||
spec_tensor = (
|
|
||||||
torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
|
|
||||||
)
|
|
||||||
outputs = self(spec_tensor)
|
|
||||||
return postprocess_model_outputs(
|
|
||||||
outputs,
|
|
||||||
[clip],
|
|
||||||
class_mapper=self.class_mapper,
|
|
||||||
config=self.postprocessing_config,
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
def compute_loss(
|
|
||||||
self,
|
|
||||||
outputs: ModelOutput,
|
|
||||||
batch: TrainExample,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
detection_loss = losses.focal_loss(
|
|
||||||
outputs.detection_probs,
|
|
||||||
batch.detection_heatmap,
|
|
||||||
)
|
|
||||||
|
|
||||||
size_loss = losses.bbox_size_loss(
|
|
||||||
outputs.size_preds,
|
|
||||||
batch.size_heatmap,
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
|
|
||||||
classification_loss = losses.focal_loss(
|
|
||||||
outputs.class_probs,
|
|
||||||
batch.class_heatmap,
|
|
||||||
valid_mask=valid_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
return detection_loss + size_loss + classification_loss
|
|
||||||
|
|
||||||
def training_step( # type: ignore
|
|
||||||
self,
|
|
||||||
batch: TrainExample,
|
|
||||||
):
|
|
||||||
outputs = self.forward(batch.spec)
|
|
||||||
loss = self.compute_loss(outputs, batch)
|
|
||||||
self.log("train_loss", loss)
|
|
||||||
return loss
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
|
||||||
optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
|
|
||||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
|
|
||||||
return [optimizer], [scheduler]
|
|
15
batdetect2/models/encoder.py
Normal file
15
batdetect2/models/encoder.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import sys
|
||||||
|
from typing import Iterable, List, Literal, Sequence
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from batdetect2.models.blocks import ConvBlockDownCoordF, ConvBlockDownStandard
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 10):
|
||||||
|
from itertools import pairwise
|
||||||
|
else:
|
||||||
|
|
||||||
|
def pairwise(iterable: Sequence) -> Iterable:
|
||||||
|
for x, y in zip(iterable[:-1], iterable[1:]):
|
||||||
|
yield x, y
|
@@ -1,319 +0,0 @@ (previous implementation removed)

from typing import Tuple

import torch
import torch.fft
import torch.nn.functional as F
from torch import nn

from batdetect2.models.blocks import (
    ConvBlockDownCoordF,
    ConvBlockDownStandard,
    ConvBlockUpF,
    ConvBlockUpStandard,
    SelfAttention,
)
from batdetect2.models.typing import FeatureExtractorModel

__all__ = [
    "Net2DFast",
    "Net2DFastNoAttn",
    "Net2DFastNoCoordConv",
]


class Net2DFast(FeatureExtractorModel):
    def __init__(
        self,
        num_features: int,
        input_height: int = 128,
    ):
        super().__init__()
        self.num_features = num_features
        self.input_height = input_height
        self.bottleneck_height = self.input_height // 32

        # encoder
        self.conv_dn_0 = ConvBlockDownCoordF(
            1,
            self.num_features // 4,
            self.input_height,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownCoordF(
            self.num_features // 4,
            self.num_features // 2,
            self.input_height // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownCoordF(
            self.num_features // 2,
            self.num_features,
            self.input_height // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(
            self.num_features,
            self.num_features * 2,
            3,
            padding=1,
        )
        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)

        # bottleneck
        self.conv_1d = nn.Conv2d(
            self.num_features * 2,
            self.num_features * 2,
            (self.input_height // 8, 1),
            padding=0,
        )
        self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
        self.att = SelfAttention(self.num_features * 2, self.num_features * 2)

        # decoder
        self.conv_up_2 = ConvBlockUpF(
            self.num_features * 2,
            self.num_features // 2,
            self.input_height // 8,
        )
        self.conv_up_3 = ConvBlockUpF(
            self.num_features // 2,
            self.num_features // 4,
            self.input_height // 4,
        )
        self.conv_up_4 = ConvBlockUpF(
            self.num_features // 4,
            self.num_features // 4,
            self.input_height // 2,
        )

        self.conv_op = nn.Conv2d(
            self.num_features // 4,
            self.num_features // 4,
            kernel_size=3,
            padding=1,
        )
        self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)

    def pad_adjust(self, spec: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        h, w = spec.shape[2:]
        h_pad = (32 - h % 32) % 32
        w_pad = (32 - w % 32) % 32
        return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # encoder
        spec, h_pad, w_pad = self.pad_adjust(spec)

        x1 = self.conv_dn_0(spec)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))

        # bottleneck
        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
        x = self.att(x)
        x = x.repeat([1, 1, self.bottleneck_height * 4, 1])

        # decoder
        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        # Restore original size
        if h_pad > 0:
            x = x[:, :, :-h_pad, :]

        if w_pad > 0:
            x = x[:, :, :, :-w_pad]

        return F.relu_(self.conv_op_bn(self.conv_op(x)))


class Net2DFastNoAttn(FeatureExtractorModel):
    def __init__(
        self,
        num_features: int,
        input_height: int = 128,
    ):
        super().__init__()

        self.num_features = num_features
        self.input_height = input_height
        self.bottleneck_height = self.input_height // 32

        self.conv_dn_0 = ConvBlockDownCoordF(
            1,
            self.num_features // 4,
            self.input_height,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownCoordF(
            self.num_features // 4,
            self.num_features // 2,
            self.input_height // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownCoordF(
            self.num_features // 2,
            self.num_features,
            self.input_height // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(
            self.num_features,
            self.num_features * 2,
            3,
            padding=1,
        )
        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)

        self.conv_1d = nn.Conv2d(
            self.num_features * 2,
            self.num_features * 2,
            (self.input_height // 8, 1),
            padding=0,
        )
        self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)

        self.conv_up_2 = ConvBlockUpF(
            self.num_features * 2,
            self.num_features // 2,
            self.input_height // 8,
        )
        self.conv_up_3 = ConvBlockUpF(
            self.num_features // 2,
            self.num_features // 4,
            self.input_height // 4,
        )
        self.conv_up_4 = ConvBlockUpF(
            self.num_features // 4,
            self.num_features // 4,
            self.input_height // 2,
        )

        self.conv_op = nn.Conv2d(
            self.num_features // 4,
            self.num_features // 4,
            kernel_size=3,
            padding=1,
        )
        self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        x1 = self.conv_dn_0(spec)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))

        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
        x = x.repeat([1, 1, self.bottleneck_height * 4, 1])

        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        return F.relu_(self.conv_op_bn(self.conv_op(x)))


class Net2DFastNoCoordConv(FeatureExtractorModel):
    def __init__(
        self,
        num_features: int,
        input_height: int = 128,
    ):
        super().__init__()

        self.num_features = num_features
        self.input_height = input_height
        self.bottleneck_height = self.input_height // 32

        self.conv_dn_0 = ConvBlockDownStandard(
            1,
            self.num_features // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownStandard(
            self.num_features // 4,
            self.num_features // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownStandard(
            self.num_features // 2,
            self.num_features,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(
            self.num_features,
            self.num_features * 2,
            3,
            padding=1,
        )
        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)

        self.conv_1d = nn.Conv2d(
            self.num_features * 2,
            self.num_features * 2,
            (self.input_height // 8, 1),
            padding=0,
        )
        self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)

        self.att = SelfAttention(self.num_features * 2, self.num_features * 2)

        self.conv_up_2 = ConvBlockUpStandard(
            self.num_features * 2,
            self.num_features // 2,
            self.input_height // 8,
        )
        self.conv_up_3 = ConvBlockUpStandard(
            self.num_features // 2,
            self.num_features // 4,
            self.input_height // 4,
        )
        self.conv_up_4 = ConvBlockUpStandard(
            self.num_features // 4,
            self.num_features // 4,
            self.input_height // 2,
        )

        self.conv_op = nn.Conv2d(
            self.num_features // 4,
            self.num_features // 4,
            kernel_size=3,
            padding=1,
        )
        self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        x1 = self.conv_dn_0(spec)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))

        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
        x = self.att(x)
        x = x.repeat([1, 1, self.bottleneck_height * 4, 1])

        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        return F.relu_(self.conv_op_bn(self.conv_op(x)))
batdetect2/models/heads.py (new file, +51 lines)

from typing import NamedTuple

import torch
from torch import nn

__all__ = ["ClassifierHead"]


class Output(NamedTuple):
    detection: torch.Tensor
    classification: torch.Tensor


class ClassifierHead(nn.Module):
    def __init__(self, num_classes: int, in_channels: int):
        super().__init__()

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.classifier = nn.Conv2d(
            self.in_channels,
            # Add one to account for the background class
            self.num_classes + 1,
            kernel_size=1,
            padding=0,
        )

    def forward(self, features: torch.Tensor) -> Output:
        logits = self.classifier(features)
        probs = torch.softmax(logits, dim=1)
        detection_probs = probs[:, :-1].sum(dim=1, keepdim=True)
        return Output(
            detection=detection_probs,
            classification=probs[:, :-1],
        )


class BBoxHead(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.bbox = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=2,
            kernel_size=1,
            padding=0,
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.bbox(features)
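The detection probability above comes from a softmax over num_classes + 1 channels, where the extra channel models background; summing every non-background channel is equivalent to 1 - P(background). A small, self-contained numerical check of that identity (shapes are illustrative only):

import torch

logits = torch.randn(1, 5 + 1, 4, 4)  # 5 classes plus 1 background channel
probs = torch.softmax(logits, dim=1)

detection = probs[:, :-1].sum(dim=1, keepdim=True)
background = probs[:, -1:]

# Both formulations agree up to floating-point error.
assert torch.allclose(detection, 1 - background, atol=1e-6)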
batdetect2/models/typing.py (modified)

@@ -1,12 +1,12 @@
 from abc import ABC, abstractmethod
-from typing import NamedTuple
+from typing import NamedTuple, Tuple

 import torch
 import torch.nn as nn

 __all__ = [
     "ModelOutput",
-    "FeatureExtractorModel",
+    "BackboneModel",
 ]


@@ -41,16 +41,31 @@ class ModelOutput(NamedTuple):
     """Tensor with intermediate features."""


-class FeatureExtractorModel(ABC, nn.Module):
+class BackboneModel(ABC, nn.Module):
     input_height: int
     """Height of the input spectrogram."""

-    num_features: int
-    """Dimension of the feature tensor."""
+    encoder_channels: Tuple[int, ...]
+    """Tuple specifying the number of channels for each convolutional layer
+    in the encoder. The length of this tuple determines the number of
+    encoder layers."""
+
+    decoder_channels: Tuple[int, ...]
+    """Tuple specifying the number of channels for each convolutional layer
+    in the decoder. The length of this tuple determines the number of
+    decoder layers."""
+
+    bottleneck_channels: int
+    """Number of channels in the bottleneck layer, which connects the
+    encoder and decoder."""
+
+    out_channels: int
+    """Number of channels in the final output feature map produced by the
+    backbone model."""

     @abstractmethod
     def forward(self, spec: torch.Tensor) -> torch.Tensor:
-        """Forward pass of the encoder model."""
+        """Forward pass of the model."""


 class DetectionModel(ABC, nn.Module):
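As a rough sketch of what a concrete class satisfying this interface looks like, here is a toy module exposing the attributes the new BackboneModel declares. The tiny architecture is invented purely for illustration and is not the library's backbone:

import torch
from torch import nn


class ToyBackbone(nn.Module):
    """Illustrative backbone exposing the attributes BackboneModel expects."""

    input_height: int = 128
    encoder_channels = (1, 16)
    decoder_channels = (16, 16)
    bottleneck_channels: int = 16
    out_channels: int = 16

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, self.out_channels, kernel_size=3, padding=1)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # Same spatial size in and out, out_channels feature maps.
        return torch.relu(self.conv(spec))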
batdetect2/modules.py (new file, +181 lines)

from pathlib import Path
from typing import Optional

import lightning as L
import torch
from pydantic import Field
from soundevent import data
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from batdetect2.configs import BaseConfig
from batdetect2.evaluate.evaluate import match_predictions_and_annotations
from batdetect2.models import (
    BBoxHead,
    ClassifierHead,
    ModelConfig,
    build_architecture,
)
from batdetect2.models.typing import ModelOutput
from batdetect2.post_process import (
    PostprocessConfig,
    postprocess_model_outputs,
)
from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
from batdetect2.train.config import TrainingConfig
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.losses import compute_loss
from batdetect2.train.targets import (
    TargetConfig,
    build_decoder,
    build_encoder,
    get_class_names,
)

__all__ = [
    "DetectorModel",
]


class ModuleConfig(BaseConfig):
    train: TrainingConfig = Field(default_factory=TrainingConfig)
    targets: TargetConfig = Field(default_factory=TargetConfig)
    architecture: ModelConfig = Field(default_factory=ModelConfig)
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
    postprocessing: PostprocessConfig = Field(
        default_factory=PostprocessConfig
    )


class DetectorModel(L.LightningModule):
    config: ModuleConfig

    def __init__(
        self,
        config: Optional[ModuleConfig] = None,
    ):
        super().__init__()

        self.config = config or ModuleConfig()
        self.save_hyperparameters()

        self.backbone = build_architecture(self.config.architecture)

        self.classifier = ClassifierHead(
            num_classes=len(self.config.targets.classes),
            in_channels=self.backbone.out_channels,
        )

        self.bbox = BBoxHead(in_channels=self.backbone.out_channels)

        conf = self.config.train.loss.classification
        self.class_weights = (
            torch.tensor(conf.class_weights) if conf.class_weights else None
        )

        # Training targets
        self.class_names = get_class_names(self.config.targets.classes)
        self.encoder = build_encoder(
            self.config.targets.classes,
            replacement_rules=self.config.targets.replace,
        )
        self.decoder = build_decoder(self.config.targets.classes)

        self.validation_predictions = []

        self.example_input_array = torch.randn([1, 1, 128, 512])

    def forward(self, spec: torch.Tensor) -> ModelOutput:  # type: ignore
        features = self.backbone(spec)
        detection_probs, classification_probs = self.classifier(features)
        size_preds = self.bbox(features)
        return ModelOutput(
            detection_probs=detection_probs,
            size_preds=size_preds,
            class_probs=classification_probs,
            features=features,
        )

    def training_step(self, batch: TrainExample):
        outputs = self.forward(batch.spec)
        losses = compute_loss(
            batch,
            outputs,
            conf=self.config.train.loss,
            class_weights=self.class_weights,
        )

        self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
        self.log("train/loss/detection", losses.total, logger=True)
        self.log("train/loss/size", losses.total, logger=True)
        self.log("train/loss/classification", losses.total, logger=True)

        return losses.total

    def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
        outputs = self.forward(batch.spec)

        losses = compute_loss(
            batch,
            outputs,
            conf=self.config.train.loss,
            class_weights=self.class_weights,
        )

        self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
        self.log("val/loss/detection", losses.total, logger=True)
        self.log("val/loss/size", losses.total, logger=True)
        self.log("val/loss/classification", losses.total, logger=True)

        dataloaders = self.trainer.val_dataloaders
        assert isinstance(dataloaders, DataLoader)
        dataset = dataloaders.dataset
        assert isinstance(dataset, LabeledDataset)
        clip_annotation = dataset.get_clip_annotation(batch_idx)

        clip_prediction = postprocess_model_outputs(
            outputs,
            clips=[clip_annotation.clip],
            classes=self.class_names,
            decoder=self.decoder,
            config=self.config.postprocessing,
        )[0]

        matches = match_predictions_and_annotations(
            clip_annotation,
            clip_prediction,
        )

        self.validation_predictions.extend(matches)

    def on_validation_epoch_end(self) -> None:
        self.validation_predictions.clear()

    def configure_optimizers(self):
        conf = self.config.train.optimizer
        optimizer = Adam(self.parameters(), lr=conf.learning_rate)
        scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max)
        return [optimizer], [scheduler]

    def process_clip(
        self,
        clip: data.Clip,
        audio_dir: Optional[Path] = None,
    ) -> data.ClipPrediction:
        spec = preprocess_audio_clip(
            clip,
            config=self.config.preprocessing,
            audio_dir=audio_dir,
        )
        tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)
        outputs = self.forward(tensor)
        return postprocess_model_outputs(
            outputs,
            clips=[clip],
            classes=self.class_names,
            decoder=self.decoder,
            config=self.config.postprocessing,
        )[0]
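As a rough usage sketch, assuming the package layout shown in this diff (DetectorModel with a default ModuleConfig) and dataloaders built from LabeledDataset instances elsewhere, training with Lightning would look roughly like this; the dataloader construction is an assumption and is left commented out:

import lightning as L
from torch.utils.data import DataLoader

from batdetect2.modules import DetectorModel

model = DetectorModel()  # uses the default ModuleConfig

# train_dataset / val_dataset are assumed to be LabeledDataset instances
# built elsewhere; their construction is not part of this diff.
# train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=1)

trainer = L.Trainer(max_epochs=100)
# trainer.fit(model, train_loader, val_loader)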
@@ -2,10 +2,10 @@

 from typing import List, Optional, Tuple, Union, cast

+import matplotlib.ticker as tick
 import numpy as np
 import torch
 from matplotlib import axes, patches
-import matplotlib.ticker as tick
 from matplotlib import pyplot as plt

 from batdetect2.detector.parameters import DEFAULT_PROCESSING_CONFIGURATIONS
@@ -102,7 +102,6 @@ def spectrogram(
     return ax


-
 def spectrogram_with_detections(
     spec: Union[torch.Tensor, np.ndarray],
     dets: List[Annotation],
@@ -231,11 +230,11 @@ def detection(
         figsize (Optional[Tuple[int, int]], optional): Figure size. Defaults
             to None. If `ax` is None, this will be used to create a new figure
             of the given size.
         linewidth (float, optional): Line width of the detection.
             Defaults to 1.
         edgecolor (str, optional): Edge color of the detection.
             Defaults to "w", i.e. white.
         facecolor (str, optional): Face color of the detection.
             Defaults to "none", i.e. transparent.
         with_name (bool, optional): Whether to plot the name of the
             predicted class next to the detection. Defaults to True.
@@ -17,6 +17,6 @@ def create_ax(
 ) -> axes.Axes:
     """Create a new axis if none is provided"""
     if ax is None:
         _, ax = plt.subplots(figsize=figsize, **kwargs)  # type: ignore

     return ax  # type: ignore
batdetect2/post_process.py (modified)

@@ -1,19 +1,20 @@
 """Module for postprocessing model outputs."""

-from typing import Callable, List, Tuple, Union
-from pydantic import BaseModel, Field
+from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union

 import numpy as np
 import torch
+from pydantic import Field
 from soundevent import data
 from torch import nn

-from batdetect2.data.labels import ClassMapper
+from batdetect2.configs import BaseConfig, load_config
 from batdetect2.models.typing import ModelOutput

 __all__ = [
-    "postprocess_model_outputs",
     "PostprocessConfig",
+    "load_postprocess_config",
+    "postprocess_model_outputs",
 ]

 NMS_KERNEL_SIZE = 9
@@ -21,7 +22,7 @@ DETECTION_THRESHOLD = 0.01
 TOP_K_PER_SEC = 200


-class PostprocessConfig(BaseModel):
+class PostprocessConfig(BaseConfig):
     """Configuration for postprocessing model outputs."""

     nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
@@ -31,14 +32,29 @@ class PostprocessConfig(BaseModel):
     top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)


-TagFunction = Callable[[int], List[data.Tag]]
+def load_postprocess_config(
+    path: data.PathLike,
+    field: Optional[str] = None,
+) -> PostprocessConfig:
+    return load_config(path, schema=PostprocessConfig, field=field)
+
+
+class RawPrediction(NamedTuple):
+    start_time: float
+    end_time: float
+    low_freq: float
+    high_freq: float
+    detection_score: float
+    class_scores: Dict[str, float]
+    features: np.ndarray


 def postprocess_model_outputs(
     outputs: ModelOutput,
     clips: List[data.Clip],
-    class_mapper: ClassMapper,
-    config: PostprocessConfig,
+    classes: List[str],
+    decoder: Callable[[str], List[data.Tag]],
+    config: Optional[PostprocessConfig] = None,
 ) -> List[data.ClipPrediction]:
     """Postprocesses model outputs to generate clip predictions.

@@ -68,6 +84,9 @@ def postprocess_model_outputs(
     ValueError
         If the number of predictions does not match the number of clips.
     """
+    config = config or PostprocessConfig()
+
     num_predictions = len(outputs.detection_probs)

     if num_predictions == 0:
@@ -108,7 +127,8 @@ def postprocess_model_outputs(
             size_preds,
             class_probs,
             features,
-            class_mapper=class_mapper,
+            classes=classes,
+            decoder=decoder,
             min_freq=config.min_freq,
             max_freq=config.max_freq,
             detection_threshold=config.detection_threshold,
@@ -124,6 +144,82 @@ def postprocess_model_outputs(
     return predictions


+def compute_predictions_from_outputs(
+    start: float,
+    end: float,
+    scores: torch.Tensor,
+    y_pos: torch.Tensor,
+    x_pos: torch.Tensor,
+    size_preds: torch.Tensor,
+    class_probs: torch.Tensor,
+    features: torch.Tensor,
+    classes: List[str],
+    min_freq: int = 10000,
+    max_freq: int = 120000,
+    detection_threshold: float = DETECTION_THRESHOLD,
+) -> List[RawPrediction]:
+    _, freq_bins, time_bins = size_preds.shape
+
+    sorted_indices = torch.argsort(x_pos)
+    valid_indices = sorted_indices[
+        scores[sorted_indices] > detection_threshold
+    ]
+
+    scores = scores[valid_indices]
+    x_pos = x_pos[valid_indices]
+    y_pos = y_pos[valid_indices]
+
+    predictions: List[RawPrediction] = []
+    for score, x, y in zip(scores, x_pos, y_pos):
+        width, height = size_preds[:, y, x]
+        class_prob = class_probs[:, y, x].detach().numpy()
+        feats = features[:, y, x].detach().numpy()
+
+        start_time = np.interp(
+            x.item(),
+            [0, time_bins],
+            [start, end],
+        )
+
+        end_time = np.interp(
+            x.item() + width.item(),
+            [0, time_bins],
+            [start, end],
+        )
+
+        low_freq = np.interp(
+            y.item(),
+            [0, freq_bins],
+            [max_freq, min_freq],
+        )
+
+        high_freq = np.interp(
+            y.item() - height.item(),
+            [0, freq_bins],
+            [max_freq, min_freq],
+        )
+
+        start_time, end_time = sorted([float(start_time), float(end_time)])
+        low_freq, high_freq = sorted([float(low_freq), float(high_freq)])
+
+        predictions.append(
+            RawPrediction(
+                start_time=start_time,
+                end_time=end_time,
+                low_freq=low_freq,
+                high_freq=high_freq,
+                detection_score=score.item(),
+                features=feats,
+                class_scores={
+                    class_name: prob
+                    for class_name, prob in zip(classes, class_prob)
+                },
+            )
+        )
+
+    return predictions
+
+
 def compute_sound_events_from_outputs(
     clip: data.Clip,
     scores: torch.Tensor,
@@ -132,7 +228,8 @@ def compute_sound_events_from_outputs(
     size_preds: torch.Tensor,
     class_probs: torch.Tensor,
     features: torch.Tensor,
-    class_mapper: ClassMapper,
+    classes: List[str],
+    decoder: Callable[[str], List[data.Tag]],
     min_freq: int = 10000,
     max_freq: int = 120000,
     detection_threshold: float = DETECTION_THRESHOLD,
@@ -181,12 +278,13 @@ def compute_sound_events_from_outputs(
     predicted_tags: List[data.PredictedTag] = []

     for label_id, class_score in enumerate(class_prob):
-        corresponding_tags = class_mapper.inverse_transform(label_id)
+        class_name = classes[label_id]
+        corresponding_tags = decoder(class_name)
         predicted_tags.extend(
             [
                 data.PredictedTag(
                     tag=tag,
-                    score=class_score.item(),
+                    score=max(min(class_score.item(), 1), 0),
                 )
                 for tag in corresponding_tags
             ]
@@ -207,7 +305,7 @@ def compute_sound_events_from_outputs(
         ),
         features=[
             data.Feature(
-                name=f"batdetect2_{i}",
+                term=data.term_from_key(f"batdetect2_{i}"),
                 value=value.item(),
             )
             for i, value in enumerate(feature)
@@ -217,7 +315,7 @@ def compute_sound_events_from_outputs(
     predictions.append(
         data.SoundEventPrediction(
             sound_event=sound_event,
-            score=score.item(),
+            score=max(min(score.item(), 1), 0),
             tags=predicted_tags,
         )
     )
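The coordinate recovery in compute_predictions_from_outputs maps grid indices back to clip time and frequency with np.interp; the frequency axis is passed as [max_freq, min_freq] because row 0 of the output grid corresponds to the highest frequency. A standalone numeric illustration of that mapping (values chosen arbitrarily):

import numpy as np

time_bins, freq_bins = 512, 128
start, end = 0.0, 0.5            # clip span in seconds
min_freq, max_freq = 10_000, 120_000

x, y = 256, 32                   # detection position on the output grid

start_time = np.interp(x, [0, time_bins], [start, end])
low_freq = np.interp(y, [0, freq_bins], [max_freq, min_freq])

print(start_time)  # 0.25 s: halfway through the clip
print(low_freq)    # 92500.0 Hz: a quarter of the way down from max_freq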
batdetect2/preprocess/__init__.py (new file, +68 lines)

"""Module containing functions for preprocessing audio clips."""

from typing import Optional

import xarray as xr
from soundevent import data

from batdetect2.preprocess.audio import (
    AudioConfig,
    ResampleConfig,
    load_clip_audio,
)
from batdetect2.preprocess.config import (
    PreprocessingConfig,
    load_preprocessing_config,
)
from batdetect2.preprocess.spectrogram import (
    AmplitudeScaleConfig,
    FrequencyConfig,
    LogScaleConfig,
    PcenScaleConfig,
    Scales,
    SpecSizeConfig,
    SpectrogramConfig,
    STFTConfig,
    compute_spectrogram,
)

__all__ = [
    "AmplitudeScaleConfig",
    "AudioConfig",
    "FrequencyConfig",
    "LogScaleConfig",
    "PcenScaleConfig",
    "PreprocessingConfig",
    "ResampleConfig",
    "STFTConfig",
    "Scales",
    "SpecSizeConfig",
    "SpectrogramConfig",
    "load_preprocessing_config",
    "preprocess_audio_clip",
]


def preprocess_audio_clip(
    clip: data.Clip,
    config: Optional[PreprocessingConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
    """Preprocesses audio clip to generate spectrogram.

    Parameters
    ----------
    clip
        The audio clip to preprocess.
    config
        Configuration for preprocessing.

    Returns
    -------
    xr.DataArray
        Preprocessed spectrogram.

    """
    config = config or PreprocessingConfig()
    wav = load_clip_audio(clip, config=config.audio, audio_dir=audio_dir)
    return compute_spectrogram(wav, config=config.spectrogram)
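A short usage sketch of the public entry point above, assuming a recording on disk (the file path is a placeholder):

from soundevent import data

from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip

recording = data.Recording.from_file("example.wav")  # placeholder path
clip = data.Clip(
    recording=recording,
    start_time=0.0,
    end_time=min(recording.duration, 0.5),
)

spec = preprocess_audio_clip(clip, config=PreprocessingConfig())
print(spec.dims, spec.shape)  # ("frequency", "time") at the configured size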
batdetect2/preprocess/arrays.py (new file, +61 lines)

import numpy as np


def extend_width(
    array: np.ndarray,
    extra: int,
    axis: int = -1,
    value: float = 0,
) -> np.ndarray:
    dims = len(array.shape)
    axis = axis % dims
    pad = [[0, 0] if index != axis else [0, extra] for index in range(dims)]
    return np.pad(
        array,
        pad,
        mode="constant",
        constant_values=value,
    )


def make_width_divisible(
    array: np.ndarray,
    factor: int,
    axis: int = -1,
    value: float = 0,
) -> np.ndarray:
    width = array.shape[axis]

    if width % factor == 0:
        return array

    extra = (-width) % factor
    return extend_width(array, extra, axis=axis, value=value)


def adjust_width(
    array: np.ndarray,
    width: int,
    axis: int = -1,
    value: float = 0,
) -> np.ndarray:
    dims = len(array.shape)
    axis = axis % dims
    current_width = array.shape[axis]

    if current_width == width:
        return array

    if current_width < width:
        return extend_width(
            array,
            extra=width - current_width,
            axis=axis,
            value=value,
        )

    slices = [
        slice(None, None) if index != axis else slice(None, width)
        for index in range(dims)
    ]
    return array[tuple(slices)]
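These helpers pad or crop along a single axis only. A quick illustration of their behaviour, assuming the module above is importable:

import numpy as np

from batdetect2.preprocess.arrays import adjust_width, make_width_divisible

spec = np.ones((128, 100), dtype=np.float32)

# Zero-pad on the right until the time axis is a multiple of 32.
divisible = make_width_divisible(spec, factor=32)
print(divisible.shape)  # (128, 128)

# Crop (or pad) to an exact width.
cropped = adjust_width(spec, width=64)
print(cropped.shape)  # (128, 64)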
batdetect2/preprocess/audio.py (new file, +199 lines)

from typing import Optional

import numpy as np
import xarray as xr
from numpy.typing import DTypeLike
from pydantic import Field
from scipy.signal import resample, resample_poly
from soundevent import arrays, audio, data
from soundevent.arrays import operations as ops

from batdetect2.configs import BaseConfig

TARGET_SAMPLERATE_HZ = 256_000
SCALE_RAW_AUDIO = False
DEFAULT_DURATION = None


class ResampleConfig(BaseConfig):
    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
    mode: str = "poly"


class AudioConfig(BaseConfig):
    resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
    scale: bool = SCALE_RAW_AUDIO
    center: bool = True
    duration: Optional[float] = DEFAULT_DURATION


def load_file_audio(
    path: data.PathLike,
    config: Optional[AudioConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    recording = data.Recording.from_file(path)
    return load_recording_audio(
        recording,
        config=config,
        dtype=dtype,
        audio_dir=audio_dir,
    )


def load_recording_audio(
    recording: data.Recording,
    config: Optional[AudioConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    clip = data.Clip(
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )
    return load_clip_audio(
        clip,
        config=config,
        dtype=dtype,
        audio_dir=audio_dir,
    )


def load_clip_audio(
    clip: data.Clip,
    config: Optional[AudioConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    config = config or AudioConfig()

    wav = (
        audio.load_clip(clip, audio_dir=audio_dir).sel(channel=0).astype(dtype)
    )

    if config.duration is not None:
        wav = adjust_audio_duration(wav, duration=config.duration)

    if config.resample:
        wav = resample_audio(
            wav,
            samplerate=config.resample.samplerate,
            dtype=dtype,
        )

    if config.center:
        wav = ops.center(wav)

    if config.scale:
        wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))

    return wav.astype(dtype)


def adjust_audio_duration(
    wave: xr.DataArray,
    duration: float,
) -> xr.DataArray:
    start_time, end_time = arrays.get_dim_range(wave, dim="time")
    current_duration = end_time - start_time

    if current_duration == duration:
        return wave

    if current_duration > duration:
        return arrays.crop_dim(
            wave,
            dim="time",
            start=start_time,
            stop=start_time + duration,
        )

    return arrays.extend_dim(
        wave,
        dim="time",
        start=start_time,
        stop=start_time + duration,
    )


def resample_audio(
    wav: xr.DataArray,
    samplerate: int = TARGET_SAMPLERATE_HZ,
    mode: str = "poly",
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    if "time" not in wav.dims:
        raise ValueError("Audio must have a time dimension")

    time_axis: int = wav.get_axis_num("time")  # type: ignore
    step = arrays.get_dim_step(wav, dim="time")
    original_samplerate = int(1 / step)

    if original_samplerate == samplerate:
        return wav.astype(dtype)

    if mode == "poly":
        resampled = resample_audio_poly(
            wav,
            sr_orig=original_samplerate,
            sr_new=samplerate,
            axis=time_axis,
        )
    elif mode == "fourier":
        resampled = resample_audio_fourier(
            wav,
            sr_orig=original_samplerate,
            sr_new=samplerate,
            axis=time_axis,
        )
    else:
        raise NotImplementedError(f"Resampling mode '{mode}' not implemented")

    start, stop = arrays.get_dim_range(wav, dim="time")
    times = np.linspace(
        start,
        stop + step,
        len(resampled),
        endpoint=False,
        dtype=dtype,
    )

    return xr.DataArray(
        data=resampled.astype(dtype),
        dims=wav.dims,
        coords={
            **wav.coords,
            "time": arrays.create_time_dim_from_array(
                times,
                samplerate=samplerate,
            ),
        },
        attrs=wav.attrs,
    )


def resample_audio_poly(
    array: xr.DataArray,
    sr_orig: int,
    sr_new: int,
    axis: int = -1,
) -> np.ndarray:
    gcd = np.gcd(sr_orig, sr_new)
    return resample_poly(
        array.values,
        sr_new // gcd,
        sr_orig // gcd,
        axis=axis,
    )


def resample_audio_fourier(
    array: xr.DataArray,
    sr_orig: int,
    sr_new: int,
    axis: int = -1,
) -> np.ndarray:
    ratio = sr_new / sr_orig
    return resample(array, int(array.shape[axis] * ratio), axis=axis)  # type: ignore
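resample_audio_poly reduces the rate ratio by the greatest common divisor before calling scipy.signal.resample_poly. A small standalone check of the up/down factors it would use (the sample rates are examples):

import numpy as np
from scipy.signal import resample_poly

sr_orig, sr_new = 384_000, 256_000
gcd = np.gcd(sr_orig, sr_new)
up, down = sr_new // gcd, sr_orig // gcd
print(up, down)  # 2 3 -> resample 384 kHz audio down to 256 kHz

wav = np.random.randn(sr_orig // 100)  # 10 ms of fake audio
out = resample_poly(wav, up, down)
print(len(wav), len(out))              # 3840 -> 2560 samples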
batdetect2/preprocess/config.py (new file, +31 lines)

from typing import Optional

from pydantic import Field
from soundevent.data import PathLike

from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import (
    AudioConfig,
)
from batdetect2.preprocess.spectrogram import (
    SpectrogramConfig,
)

__all__ = [
    "PreprocessingConfig",
    "load_preprocessing_config",
]


class PreprocessingConfig(BaseConfig):
    """Configuration for preprocessing data."""

    audio: AudioConfig = Field(default_factory=AudioConfig)
    spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)


def load_preprocessing_config(
    path: PathLike,
    field: Optional[str] = None,
) -> PreprocessingConfig:
    return load_config(path, schema=PreprocessingConfig, field=field)
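As a hedged sketch of how a preprocessing section of a config file might look and be loaded, assuming load_config reads YAML and that the field names match the Pydantic models in this diff exactly (the file name and the "preprocessing" nesting are illustrative):

from batdetect2.preprocess.config import load_preprocessing_config

config_text = """
preprocessing:
  audio:
    resample:
      samplerate: 256000
      mode: poly
  spectrogram:
    denoise: true
    frequencies:
      min_freq: 10000
      max_freq: 120000
"""

with open("config.yaml", "w") as fh:  # placeholder path
    fh.write(config_text)

config = load_preprocessing_config("config.yaml", field="preprocessing")
print(config.audio.resample.samplerate)  # 256000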
batdetect2/preprocess/spectrogram.py (new file, +323 lines)

from typing import Literal, Optional, Union

import librosa
import librosa.core.spectrum
import numpy as np
import xarray as xr
from numpy.typing import DTypeLike
from pydantic import Field
from soundevent import arrays, audio
from soundevent.arrays import operations as ops

from batdetect2.configs import BaseConfig


class STFTConfig(BaseConfig):
    window_duration: float = Field(default=0.002, gt=0)
    window_overlap: float = Field(default=0.75, ge=0, lt=1)
    window_fn: str = "hann"


class FrequencyConfig(BaseConfig):
    max_freq: int = Field(default=120_000, gt=0)
    min_freq: int = Field(default=10_000, gt=0)


class SpecSizeConfig(BaseConfig):
    height: int = 128
    """Height of the spectrogram in pixels. This value determines the
    number of frequency bands and corresponds to the vertical dimension
    of the spectrogram."""

    resize_factor: Optional[float] = 0.5
    """Factor by which to resize the spectrogram along the time axis.
    A value of 0.5 reduces the temporal dimension by half, while a
    value of 2.0 doubles it. If None, no resizing is performed."""


class LogScaleConfig(BaseConfig):
    name: Literal["log"] = "log"


class PcenScaleConfig(BaseConfig):
    name: Literal["pcen"] = "pcen"
    time_constant: float = 0.4
    hop_length: int = 512
    gain: float = 0.98
    bias: float = 2
    power: float = 0.5


class AmplitudeScaleConfig(BaseConfig):
    name: Literal["amplitude"] = "amplitude"


Scales = Union[LogScaleConfig, PcenScaleConfig, AmplitudeScaleConfig]


class SpectrogramConfig(BaseConfig):
    stft: STFTConfig = Field(default_factory=STFTConfig)
    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
    scale: Scales = Field(
        default_factory=PcenScaleConfig,
        discriminator="name",
    )
    size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
    denoise: bool = True
    max_scale: bool = False


def compute_spectrogram(
    wav: xr.DataArray,
    config: Optional[SpectrogramConfig] = None,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    config = config or SpectrogramConfig()

    spec = stft(
        wav,
        window_duration=config.stft.window_duration,
        window_overlap=config.stft.window_overlap,
        window_fn=config.stft.window_fn,
        dtype=dtype,
    )

    spec = crop_spectrogram_frequencies(
        spec,
        min_freq=config.frequencies.min_freq,
        max_freq=config.frequencies.max_freq,
    )

    spec = scale_spectrogram(spec, scale=config.scale)

    if config.denoise:
        spec = denoise_spectrogram(spec)

    if config.size:
        spec = resize_spectrogram(
            spec,
            height=config.size.height,
            resize_factor=config.size.resize_factor,
        )

    if config.max_scale:
        spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))

    return spec.astype(dtype)


def crop_spectrogram_frequencies(
    spec: xr.DataArray,
    min_freq: int = 10_000,
    max_freq: int = 120_000,
) -> xr.DataArray:
    return arrays.crop_dim(
        spec,
        dim="frequency",
        start=min_freq,
        stop=max_freq,
    ).astype(spec.dtype)


def stft(
    wave: xr.DataArray,
    window_duration: float,
    window_overlap: float,
    window_fn: str = "hann",
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    start_time, end_time = arrays.get_dim_range(wave, dim="time")
    step = arrays.get_dim_step(wave, dim="time")
    sampling_rate = 1 / step

    nfft = int(window_duration * sampling_rate)
    noverlap = int(window_overlap * nfft)
    hop_len = nfft - noverlap
    hop_duration = hop_len / sampling_rate

    spec, _ = librosa.core.spectrum._spectrogram(
        y=wave.data.astype(dtype),
        power=1,
        n_fft=nfft,
        hop_length=nfft - noverlap,
        center=False,
        window=window_fn,
    )

    return xr.DataArray(
        data=spec.astype(dtype),
        dims=["frequency", "time"],
        coords={
            "frequency": arrays.create_frequency_dim_from_array(
                np.linspace(
                    0,
                    sampling_rate / 2,
                    spec.shape[0],
                    endpoint=False,
                    dtype=dtype,
                ),
                step=sampling_rate / nfft,
            ),
            "time": arrays.create_time_dim_from_array(
                np.linspace(
                    start_time,
                    end_time - (window_duration - hop_duration),
                    spec.shape[1],
                    endpoint=False,
                    dtype=dtype,
                ),
                step=hop_duration,
            ),
        },
        attrs={
            **wave.attrs,
            "original_samplerate": sampling_rate,
            "nfft": nfft,
            "noverlap": noverlap,
        },
    )


def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
    return xr.DataArray(
        data=(spec - spec.mean("time")).clip(0),
        dims=spec.dims,
        coords=spec.coords,
        attrs=spec.attrs,
    )


def scale_spectrogram(
    spec: xr.DataArray,
    scale: Scales,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    if scale.name == "log":
        return scale_log(spec, dtype=dtype)

    if scale.name == "pcen":
        return scale_pcen(
            spec,
            time_constant=scale.time_constant,
            hop_length=scale.hop_length,
            gain=scale.gain,
            power=scale.power,
            bias=scale.bias,
        )

    return spec


def scale_pcen(
    spec: xr.DataArray,
    time_constant: float = 0.4,
    hop_length: int = 512,
    gain: float = 0.98,
    bias: float = 2,
    power: float = 0.5,
) -> xr.DataArray:
    samplerate = spec.attrs["original_samplerate"]
    t_frames = time_constant * samplerate / (float(hop_length) * 10)
    smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
    return audio.pcen(
        spec * (2**31),
        smooth=smoothing_constant,
        gain=gain,
        bias=bias,
        power=power,
    ).astype(spec.dtype)


def scale_log(
    spec: xr.DataArray,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    samplerate = spec.attrs["original_samplerate"]
    nfft = spec.attrs["nfft"]
    log_scaling = 2 / (samplerate * (np.abs(np.hanning(nfft)) ** 2).sum())
    return xr.DataArray(
        data=np.log1p(log_scaling * spec).astype(dtype),
        dims=spec.dims,
        coords=spec.coords,
        attrs=spec.attrs,
    )


def resize_spectrogram(
    spec: xr.DataArray,
    height: int = 128,
    resize_factor: Optional[float] = 0.5,
) -> xr.DataArray:
    resize_factor = resize_factor or 1
    current_width = spec.sizes["time"]
    return ops.resize(
        spec,
        time=int(resize_factor * current_width),
        frequency=height,
        dtype=np.float32,
    )


def adjust_spectrogram_width(
    spec: xr.DataArray,
    divide_factor: int = 32,
    time_period: float = 0.001,
) -> xr.DataArray:
    time_width = spec.sizes["time"]

    if time_width % divide_factor == 0:
        return spec

    target_size = int(
        np.ceil(spec.sizes["time"] / divide_factor) * divide_factor
    )
    extra_duration = (target_size - time_width) * time_period
    _, stop = arrays.get_dim_range(spec, dim="time")
    resized = ops.extend_dim(
        spec,
        dim="time",
        stop=stop + extra_duration,
    )
    return resized


def duration_to_spec_width(
    duration: float,
    samplerate: int,
    window_duration: float,
    window_overlap: float,
) -> int:
    samples = int(duration * samplerate)
    fft_len = int(window_duration * samplerate)
    fft_overlap = int(window_overlap * fft_len)
    hop_len = fft_len - fft_overlap
    width = (samples - fft_len + hop_len) / hop_len
    return int(np.floor(width))


def spec_width_to_samples(
    width: int,
    samplerate: int,
    window_duration: float,
    window_overlap: float,
) -> int:
    fft_len = int(window_duration * samplerate)
    fft_overlap = int(window_overlap * fft_len)
    hop_len = fft_len - fft_overlap
    return width * hop_len + fft_len - hop_len


def get_spectrogram_resolution(
    config: SpectrogramConfig,
) -> tuple[float, float]:
    max_freq = config.frequencies.max_freq
    min_freq = config.frequencies.min_freq
    assert config.size is not None

    spec_height = config.size.height
    resize_factor = config.size.resize_factor or 1
    freq_bin_width = (max_freq - min_freq) / spec_height
    hop_duration = config.stft.window_duration * (
        1 - config.stft.window_overlap
    )
    return freq_bin_width, hop_duration / resize_factor
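A worked example of the spectrogram-width arithmetic above, using the default STFT settings (2 ms window, 75 % overlap) and the 256 kHz target sample rate:

samplerate = 256_000
window_duration = 0.002    # seconds -> 512-sample FFT
window_overlap = 0.75      # -> 384-sample overlap, 128-sample hop

fft_len = int(window_duration * samplerate)        # 512
hop_len = fft_len - int(window_overlap * fft_len)  # 128

duration = 0.5  # half a second of audio
samples = int(duration * samplerate)               # 128000
width = (samples - fft_len + hop_len) // hop_len   # 997 time bins
print(fft_len, hop_len, width)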
batdetect2/preprocess/tensors.py (new file, +76 lines)

from typing import Union

import numpy as np
import torch
from torch.nn import functional as F


def extend_width(
    array: Union[np.ndarray, torch.Tensor],
    extra: int,
    axis: int = -1,
    value: float = 0,
) -> torch.Tensor:
    if not isinstance(array, torch.Tensor):
        array = torch.Tensor(array)

    dims = len(array.shape)
    axis = axis % dims
    pad = [
        [0, 0] if index != axis else [0, extra]
        for index in range(axis, dims)[::-1]
    ]
    return F.pad(
        array,
        [x for y in pad for x in y],
        value=value,
    )


def make_width_divisible(
    array: Union[np.ndarray, torch.Tensor],
    factor: int,
    axis: int = -1,
    value: float = 0,
) -> torch.Tensor:
    if not isinstance(array, torch.Tensor):
        array = torch.Tensor(array)

    width = array.shape[axis]

    if width % factor == 0:
        return array

    extra = (-width) % factor
    return extend_width(array, extra, axis=axis, value=value)


def adjust_width(
    array: Union[np.ndarray, torch.Tensor],
    width: int,
    axis: int = -1,
    value: float = 0,
) -> torch.Tensor:
    if not isinstance(array, torch.Tensor):
        array = torch.Tensor(array)

    dims = len(array.shape)
    axis = axis % dims
    current_width = array.shape[axis]

    if current_width == width:
        return array

    if current_width < width:
        return extend_width(
            array,
            extra=width - current_width,
            axis=axis,
            value=value,
        )

    slices = [
        slice(None, None) if index != axis else slice(None, width)
        for index in range(dims)
    ]
    return array[tuple(slices)]
batdetect2/terms.py (new file, +88 lines)

from inspect import getmembers
from typing import Optional

from pydantic import BaseModel
from soundevent import data, terms

__all__ = [
    "call_type",
    "individual",
    "get_term_from_info",
    "get_tag_from_info",
    "TermInfo",
    "TagInfo",
]


class TermInfo(BaseModel):
    label: Optional[str]
    name: Optional[str]
    uri: Optional[str]


class TagInfo(BaseModel):
    value: str
    term: Optional[TermInfo] = None
    key: Optional[str] = None
    label: Optional[str] = None


call_type = data.Term(
    name="soundevent:call_type",
    label="Call Type",
    definition="A broad categorization of animal vocalizations based on their intended function or purpose (e.g., social, distress, mating, territorial, echolocation).",
)

individual = data.Term(
    name="soundevent:individual",
    label="Individual",
    definition="An id for an individual animal. In the context of bioacoustic annotation, this term is used to label vocalizations that are attributed to a specific individual.",
)


ALL_TERMS = [
    *getmembers(terms, lambda x: isinstance(x, data.Term)),
    call_type,
    individual,
]


def get_term_from_info(term_info: TermInfo) -> data.Term:
    for term in ALL_TERMS:
        if term_info.name and term_info.name == term.name:
            return term

        if term_info.label and term_info.label == term.label:
            return term

        if term_info.uri and term_info.uri == term.uri:
            return term

    if term_info.name is None:
        if term_info.label is None:
            raise ValueError("At least one of name or label must be provided.")

        term_info.name = (
            f"soundevent:{term_info.label.lower().replace(' ', '_')}"
        )

    if term_info.label is None:
        term_info.label = term_info.name

    return data.Term(
        name=term_info.name,
        label=term_info.label,
        uri=term_info.uri,
        definition="Unknown",
    )


def get_tag_from_info(tag_info: TagInfo) -> data.Tag:
    if tag_info.term:
        term = get_term_from_info(tag_info.term)
    elif tag_info.key:
        term = data.term_from_key(tag_info.key)
    else:
        raise ValueError("Either term or key must be provided in tag info.")

    return data.Tag(term=term, value=tag_info.value)
@ -0,0 +1,48 @@
|
|||||||
|
from batdetect2.train.augmentations import (
|
||||||
|
AugmentationsConfig,
|
||||||
|
add_echo,
|
||||||
|
augment_example,
|
||||||
|
load_agumentation_config,
|
||||||
|
mask_frequency,
|
||||||
|
mask_time,
|
||||||
|
mix_examples,
|
||||||
|
scale_volume,
|
||||||
|
select_subclip,
|
||||||
|
warp_spectrogram,
|
||||||
|
)
|
||||||
|
from batdetect2.train.config import TrainingConfig, load_train_config
|
||||||
|
from batdetect2.train.dataset import (
|
||||||
|
LabeledDataset,
|
||||||
|
SubclipConfig,
|
||||||
|
TrainExample,
|
||||||
|
)
|
||||||
|
from batdetect2.train.labels import LabelConfig, load_label_config
|
||||||
|
from batdetect2.train.preprocess import preprocess_annotations
|
||||||
|
from batdetect2.train.targets import TargetConfig, load_target_config
|
||||||
|
from batdetect2.train.train import TrainerConfig, load_trainer_config, train
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AugmentationsConfig",
|
||||||
|
"LabelConfig",
|
||||||
|
"LabeledDataset",
|
||||||
|
"SubclipConfig",
|
||||||
|
"TargetConfig",
|
||||||
|
"TrainExample",
|
||||||
|
"TrainerConfig",
|
||||||
|
"TrainingConfig",
|
||||||
|
"add_echo",
|
||||||
|
"augment_example",
|
||||||
|
"load_agumentation_config",
|
||||||
|
"load_label_config",
|
||||||
|
"load_target_config",
|
||||||
|
"load_train_config",
|
||||||
|
"load_trainer_config",
|
||||||
|
"mask_frequency",
|
||||||
|
"mask_time",
|
||||||
|
"mix_examples",
|
||||||
|
"preprocess_annotations",
|
||||||
|
"scale_volume",
|
||||||
|
"select_subclip",
|
||||||
|
"train",
|
||||||
|
"warp_spectrogram",
|
||||||
|
]
|
@ -1,941 +0,0 @@
|
|||||||
"""Functions and dataloaders for training and testing the model."""
|
|
||||||
|
|
||||||
import copy
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
import librosa
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.utils.data
|
|
||||||
import torchaudio
|
|
||||||
|
|
||||||
import batdetect2.utils.audio_utils as au
|
|
||||||
from batdetect2.types import (
|
|
||||||
Annotation,
|
|
||||||
AudioLoaderAnnotationGroup,
|
|
||||||
AudioLoaderParameters,
|
|
||||||
FileAnnotation,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_gt_heatmaps(
|
|
||||||
spec_op_shape: Tuple[int, int],
|
|
||||||
sampling_rate: float,
|
|
||||||
ann: AudioLoaderAnnotationGroup,
|
|
||||||
class_names: List[str],
|
|
||||||
fft_win_length: float,
|
|
||||||
fft_overlap: float,
|
|
||||||
max_freq: float,
|
|
||||||
min_freq: float,
|
|
||||||
resize_factor: float,
|
|
||||||
target_sigma: float,
|
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AudioLoaderAnnotationGroup]:
|
|
||||||
"""Generate ground truth heatmaps from annotations.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
spec_op_shape : Tuple[int, int]
|
|
||||||
Shape of the input spectrogram.
|
|
||||||
sampling_rate : int
|
|
||||||
Sampling rate of the input audio in Hz.
|
|
||||||
ann : AnnotationGroup
|
|
||||||
Dictionary containing the annotation information.
|
|
||||||
params : HeatmapParameters
|
|
||||||
Parameters controlling the generation of the heatmaps.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
y_2d_det : np.ndarray
|
|
||||||
2D heatmap of the presence of an event.
|
|
||||||
y_2d_size : np.ndarray
|
|
||||||
2D heatmap of the size of the bounding box associated to event.
|
|
||||||
y_2d_classes : np.ndarray
|
|
||||||
3D array containing the ground-truth class probabilities for each
|
|
||||||
pixel.
|
|
||||||
ann_aug : AnnotationGroup
|
|
||||||
A dictionary containing the annotation information of the
|
|
||||||
annotations that are within the input spectrogram, augmented with
|
|
||||||
the x and y indices of their pixel location in the input spectrogram.
|
|
||||||
"""
|
|
||||||
# spec may be resized on input into the network
|
|
||||||
num_classes = len(class_names)
|
|
||||||
op_height = spec_op_shape[0]
|
|
||||||
op_width = spec_op_shape[1]
|
|
||||||
freq_per_bin = (max_freq - min_freq) / op_height
|
|
||||||
|
|
||||||
# start and end times
|
|
||||||
x_pos_start = au.time_to_x_coords(
|
|
||||||
ann["start_times"],
|
|
||||||
sampling_rate,
|
|
||||||
fft_win_length,
|
|
||||||
fft_overlap,
|
|
||||||
)
|
|
||||||
x_pos_start = (resize_factor * x_pos_start).astype(np.int32)
|
|
||||||
x_pos_end = au.time_to_x_coords(
|
|
||||||
ann["end_times"],
|
|
||||||
sampling_rate,
|
|
||||||
fft_win_length,
|
|
||||||
fft_overlap,
|
|
||||||
)
|
|
||||||
x_pos_end = (resize_factor * x_pos_end).astype(np.int32)
|
|
||||||
|
|
||||||
# location on y axis i.e. frequency
|
|
||||||
y_pos_low = (ann["low_freqs"] - min_freq) / freq_per_bin
|
|
||||||
y_pos_low = (op_height - y_pos_low).astype(np.int32)
|
|
||||||
y_pos_high = (ann["high_freqs"] - min_freq) / freq_per_bin
|
|
||||||
y_pos_high = (op_height - y_pos_high).astype(np.int32)
|
|
||||||
bb_widths = x_pos_end - x_pos_start
|
|
||||||
bb_heights = y_pos_low - y_pos_high
|
|
||||||
|
|
||||||
# Only include annotations that are within the input spectrogram
|
|
||||||
valid_inds = np.where(
|
|
||||||
(x_pos_start >= 0)
|
|
||||||
& (x_pos_start < op_width)
|
|
||||||
& (y_pos_low >= 0)
|
|
||||||
& (y_pos_low < (op_height - 1))
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
ann_aug: AudioLoaderAnnotationGroup = {
|
|
||||||
**ann,
|
|
||||||
"start_times": ann["start_times"][valid_inds],
|
|
||||||
"end_times": ann["end_times"][valid_inds],
|
|
||||||
"high_freqs": ann["high_freqs"][valid_inds],
|
|
||||||
"low_freqs": ann["low_freqs"][valid_inds],
|
|
||||||
"class_ids": ann["class_ids"][valid_inds],
|
|
||||||
"individual_ids": ann["individual_ids"][valid_inds],
|
|
||||||
"x_inds": x_pos_start[valid_inds],
|
|
||||||
"y_inds": y_pos_low[valid_inds],
|
|
||||||
}
|
|
||||||
|
|
||||||
# if the number of calls is only 1, then it is unique
|
|
||||||
# TODO would be better if we found these unique calls at the merging stage
|
|
||||||
if len(ann_aug["individual_ids"]) == 1:
|
|
||||||
ann_aug["individual_ids"][0] = 0
|
|
||||||
|
|
||||||
y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
|
|
||||||
y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)
|
|
||||||
|
|
||||||
# num classes and "background" class
|
|
||||||
y_2d_classes: np.ndarray = np.zeros(
|
|
||||||
(num_classes + 1, op_height, op_width), dtype=np.float32
|
|
||||||
)
|
|
||||||
|
|
||||||
# create 2D ground truth heatmaps
|
|
||||||
for ii in valid_inds:
|
|
||||||
draw_gaussian(
|
|
||||||
y_2d_det[0, :],
|
|
||||||
(x_pos_start[ii], y_pos_low[ii]),
|
|
||||||
target_sigma,
|
|
||||||
)
|
|
||||||
y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
|
|
||||||
y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]
|
|
||||||
|
|
||||||
cls_id = ann["class_ids"][ii]
|
|
||||||
if cls_id > -1:
|
|
||||||
draw_gaussian(
|
|
||||||
y_2d_classes[cls_id, :],
|
|
||||||
(x_pos_start[ii], y_pos_low[ii]),
|
|
||||||
target_sigma,
|
|
||||||
)
|
|
||||||
|
|
||||||
# be careful as this will have a 1.0 places where we have event but
|
|
||||||
# dont know gt class this will be masked in training anyway
|
|
||||||
y_2d_classes[num_classes, :] = 1.0 - y_2d_classes.sum(0)
|
|
||||||
y_2d_classes = y_2d_classes / y_2d_classes.sum(0)[np.newaxis, ...]
|
|
||||||
y_2d_classes[np.isnan(y_2d_classes)] = 0.0
|
|
||||||
|
|
||||||
return y_2d_det, y_2d_size, y_2d_classes, ann_aug
|
|
||||||
|
|
||||||
|
|
||||||
def draw_gaussian(
|
|
||||||
heatmap: np.ndarray,
|
|
||||||
center: Tuple[int, int],
|
|
||||||
sigmax: float,
|
|
||||||
sigmay: Optional[float] = None,
|
|
||||||
) -> bool:
|
|
||||||
"""Draw a 2D gaussian into the heatmap.
|
|
||||||
|
|
||||||
If the gaussian center is outside the heatmap, then the gaussian is not
|
|
||||||
drawn.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
heatmap : np.ndarray
|
|
||||||
The heatmap to draw into. Should be of shape (height, width).
|
|
||||||
center : Tuple[int, int]
|
|
||||||
The center of the gaussian in (x, y) format.
|
|
||||||
sigmax : float
|
|
||||||
The standard deviation of the gaussian in the x direction.
|
|
||||||
sigmay : Optional[float], optional
|
|
||||||
The standard deviation of the gaussian in the y direction. If None,
|
|
||||||
then sigmay = sigmax, by default None.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
bool
|
|
||||||
True if the gaussian was drawn, False if it was not (because
|
|
||||||
the center was outside the heatmap).
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# center is (x, y)
|
|
||||||
# this edits the heatmap inplace
|
|
||||||
|
|
||||||
if sigmay is None:
|
|
||||||
sigmay = sigmax
|
|
||||||
tmp_size = np.maximum(sigmax, sigmay) * 3
|
|
||||||
mu_x = int(center[0] + 0.5)
|
|
||||||
mu_y = int(center[1] + 0.5)
|
|
||||||
w, h = heatmap.shape[0], heatmap.shape[1]
|
|
||||||
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
|
|
||||||
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
|
|
||||||
|
|
||||||
if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0:
|
|
||||||
return False
|
|
||||||
|
|
||||||
size = 2 * tmp_size + 1
|
|
||||||
x = np.arange(0, size, 1, np.float32)
|
|
||||||
y = x[:, np.newaxis]
|
|
||||||
x0 = y0 = size // 2
|
|
||||||
# g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
|
|
||||||
g = np.exp(
|
|
||||||
-((x - x0) ** 2) / (2 * sigmax**2) - ((y - y0) ** 2) / (2 * sigmay**2)
|
|
||||||
)
|
|
||||||
g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
|
|
||||||
g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
|
|
||||||
img_x = max(0, ul[0]), min(br[0], h)
|
|
||||||
img_y = max(0, ul[1]), min(br[1], w)
|
|
||||||
heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]] = np.maximum(
|
|
||||||
heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]],
|
|
||||||
g[g_y[0] : g_y[1], g_x[0] : g_x[1]],
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def pad_aray(ip_array: np.ndarray, pad_size: int) -> np.ndarray:
|
|
||||||
"""Pad array with -1s."""
|
|
||||||
return np.hstack((ip_array, np.ones(pad_size, dtype=np.int32) * -1))
|
|
||||||
|
|
||||||
|
|
||||||
def warp_spec_aug(
|
|
||||||
spec: torch.Tensor,
|
|
||||||
ann: AudioLoaderAnnotationGroup,
|
|
||||||
stretch_squeeze_delta: float,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Warp spectrogram by randomly stretching and squeezing.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
spec: torch.Tensor
|
|
||||||
Spectrogram to warp.
|
|
||||||
ann: AnnotationGroup
|
|
||||||
Annotation group for the spectrogram. Must be provided to sync
|
|
||||||
the start and stop times with the spectrogram after warping.
|
|
||||||
stretch_squeeze_delta: float
|
|
||||||
Maximum amount to stretch or squeeze the spectrogram.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
torch.Tensor
|
|
||||||
Warped spectrogram.
|
|
||||||
|
|
||||||
Notes
|
|
||||||
-----
|
|
||||||
This function modifies the annotation group in place.
|
|
||||||
"""
|
|
||||||
# Augment spectrogram by randomly stretch and squeezing
|
|
||||||
# NOTE this also changes the start and stop time in place
|
|
||||||
|
|
||||||
delta = stretch_squeeze_delta
|
|
||||||
op_size = (spec.shape[1], spec.shape[2])
|
|
||||||
resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0
|
|
||||||
resize_amt = int(spec.shape[2] * resize_fract_r)
|
|
||||||
|
|
||||||
if resize_amt >= spec.shape[2]:
|
|
||||||
spec_r = torch.cat(
|
|
||||||
(
|
|
||||||
spec,
|
|
||||||
torch.zeros(
|
|
||||||
(1, spec.shape[1], resize_amt - spec.shape[2]),
|
|
||||||
dtype=spec.dtype,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
dim=2,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
spec_r = spec[:, :, :resize_amt]
|
|
||||||
|
|
||||||
# Resize the spectrogram
|
|
||||||
spec = F.interpolate(
|
|
||||||
spec_r.unsqueeze(0),
|
|
||||||
size=op_size,
|
|
||||||
mode="bilinear",
|
|
||||||
align_corners=False,
|
|
||||||
).squeeze(0)
|
|
||||||
|
|
||||||
# Update the start and stop times
|
|
||||||
ann["start_times"] *= 1.0 / resize_fract_r
|
|
||||||
ann["end_times"] *= 1.0 / resize_fract_r
|
|
||||||
|
|
||||||
return spec
|
|
||||||
|
|
||||||
|
|
||||||
def mask_time_aug(
|
|
||||||
spec: torch.Tensor,
|
|
||||||
mask_max_time_perc: float,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Mask out random blocks of time.
|
|
||||||
|
|
||||||
Will randomly mask out a block of time in the spectrogram. The block
|
|
||||||
will be between 0.0 and `mask_max_time_perc` of the total time.
|
|
||||||
A random number of blocks will be masked out between 1 and 3.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
spec: torch.Tensor
|
|
||||||
Spectrogram to mask.
|
|
||||||
mask_max_time_perc: float
|
|
||||||
Maximum percentage of time to mask out.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
torch.Tensor
|
|
||||||
Spectrogram with masked out time blocks.
|
|
||||||
|
|
||||||
Notes
|
|
||||||
-----
|
|
||||||
This function is based on the implementation in::
|
|
||||||
|
|
||||||
SpecAugment: A Simple Data Augmentation Method for Automatic Speech
|
|
||||||
Recognition
|
|
||||||
"""
|
|
||||||
fm = torchaudio.transforms.TimeMasking(
|
|
||||||
int(spec.shape[1] * mask_max_time_perc)
|
|
||||||
)
|
|
||||||
for _ in range(np.random.randint(1, 4)):
|
|
||||||
spec = fm(spec)
|
|
||||||
return spec
|
|
||||||
|
|
||||||
|
|
||||||
def mask_freq_aug(
|
|
||||||
spec: torch.Tensor,
|
|
||||||
mask_max_freq_perc: float,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Mask out random blocks of frequency.
|
|
||||||
|
|
||||||
Will randomly mask out a block of frequency in the spectrogram. The block
|
|
||||||
will be between 0.0 and `mask_max_freq_perc` of the total frequency.
|
|
||||||
A random number of blocks will be masked out between 1 and 3.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
spec: torch.Tensor
|
|
||||||
Spectrogram to mask.
|
|
||||||
mask_max_freq_perc: float
|
|
||||||
Maximum percentage of frequency to mask out.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
torch.Tensor
|
|
||||||
Spectrogram with masked out frequency blocks.
|
|
||||||
|
|
||||||
Notes
|
|
||||||
-----
|
|
||||||
This function is based on the implementation in::
|
|
||||||
|
|
||||||
SpecAugment: A Simple Data Augmentation Method for Automatic Speech
|
|
||||||
Recognition
|
|
||||||
"""
|
|
||||||
fm = torchaudio.transforms.FrequencyMasking(
|
|
||||||
int(spec.shape[1] * mask_max_freq_perc)
|
|
||||||
)
|
|
||||||
for _ in range(np.random.randint(1, 4)):
|
|
||||||
spec = fm(spec)
|
|
||||||
return spec
|
|
||||||
|
|
||||||
|
|
||||||
def scale_vol_aug(
|
|
||||||
spec: torch.Tensor,
|
|
||||||
spec_amp_scaling: float,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Scale the volume of the spectrogram.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
spec: torch.Tensor
|
|
||||||
Spectrogram to scale.
|
|
||||||
spec_amp_scaling: float
|
|
||||||
Maximum scaling factor.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
torch.Tensor
|
|
||||||
"""
|
|
||||||
return spec * np.random.random() * spec_amp_scaling
|
|
||||||
|
|
||||||
|
|
||||||
def echo_aug(
|
|
||||||
audio: np.ndarray,
|
|
||||||
sampling_rate: float,
|
|
||||||
echo_max_delay: float,
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""Add echo to audio.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
audio: np.ndarray
|
|
||||||
Audio to add echo to.
|
|
||||||
sampling_rate: float
|
|
||||||
Sampling rate of the audio.
|
|
||||||
echo_max_delay: float
|
|
||||||
Maximum delay of the echo in seconds.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
np.ndarray
|
|
||||||
Audio with echo added.
|
|
||||||
"""
|
|
||||||
sample_offset = (
|
|
||||||
int(echo_max_delay * np.random.random() * sampling_rate) + 1
|
|
||||||
)
|
|
||||||
# NOTE: This seems to be wrong, as the echo should be added to the
|
|
||||||
# end of the audio, not the beginning.
|
|
||||||
audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
|
|
||||||
return audio
|
|
||||||
|
|
||||||
|
|
||||||
def resample_aug(
|
|
||||||
audio: np.ndarray,
|
|
||||||
sampling_rate: float,
|
|
||||||
fft_win_length: float,
|
|
||||||
fft_overlap: float,
|
|
||||||
resize_factor: float,
|
|
||||||
spec_divide_factor: float,
|
|
||||||
spec_train_width: int,
|
|
||||||
aug_sampling_rates: List[int],
|
|
||||||
) -> Tuple[np.ndarray, float, float]:
|
|
||||||
"""Resample audio augmentation.
|
|
||||||
|
|
||||||
Will resample the audio to a random sampling rate from the list of
|
|
||||||
sampling rates in `aug_sampling_rates`.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
audio: np.ndarray
|
|
||||||
Audio to resample.
|
|
||||||
sampling_rate: float
|
|
||||||
Original sampling rate of the audio.
|
|
||||||
fft_win_length: float
|
|
||||||
Length of the FFT window in seconds.
|
|
||||||
fft_overlap: float
|
|
||||||
Amount of overlap between FFT windows.
|
|
||||||
resize_factor: float
|
|
||||||
Factor to resize the spectrogram by.
|
|
||||||
spec_divide_factor: float
|
|
||||||
Factor to divide the spectrogram by.
|
|
||||||
spec_train_width: int
|
|
||||||
Width of the spectrogram.
|
|
||||||
aug_sampling_rates: List[int]
|
|
||||||
List of sampling rates to resample to.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
audio : np.ndarray
|
|
||||||
Resampled audio.
|
|
||||||
sampling_rate : float
|
|
||||||
New sampling rate.
|
|
||||||
duration : float
|
|
||||||
Duration of the audio in seconds.
|
|
||||||
"""
|
|
||||||
sampling_rate_old = sampling_rate
|
|
||||||
sampling_rate = np.random.choice(aug_sampling_rates)
|
|
||||||
audio = librosa.resample(
|
|
||||||
audio,
|
|
||||||
orig_sr=sampling_rate_old,
|
|
||||||
target_sr=sampling_rate,
|
|
||||||
res_type="polyphase",
|
|
||||||
)
|
|
||||||
|
|
||||||
audio = au.pad_audio(
|
|
||||||
audio,
|
|
||||||
sampling_rate,
|
|
||||||
fft_win_length,
|
|
||||||
fft_overlap,
|
|
||||||
resize_factor,
|
|
||||||
spec_divide_factor,
|
|
||||||
spec_train_width,
|
|
||||||
)
|
|
||||||
duration = audio.shape[0] / float(sampling_rate)
|
|
||||||
return audio, sampling_rate, duration
|
|
||||||
|
|
||||||
|
|
||||||
def resample_audio(
|
|
||||||
num_samples: int,
|
|
||||||
sampling_rate: float,
|
|
||||||
audio2: np.ndarray,
|
|
||||||
sampling_rate2: float,
|
|
||||||
) -> Tuple[np.ndarray, float]:
|
|
||||||
"""Resample audio.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
num_samples: int
|
|
||||||
Expected number of samples for the output audio.
|
|
||||||
sampling_rate: float
|
|
||||||
Original sampling rate of the audio.
|
|
||||||
audio2: np.ndarray
|
|
||||||
Audio to resample.
|
|
||||||
sampling_rate2: float
|
|
||||||
Target sampling rate of the audio.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
audio2 : np.ndarray
|
|
||||||
Resampled audio.
|
|
||||||
sampling_rate2 : float
|
|
||||||
New sampling rate.
|
|
||||||
"""
|
|
||||||
# resample to target sampling rate
|
|
||||||
if sampling_rate != sampling_rate2:
|
|
||||||
audio2 = librosa.resample(
|
|
||||||
audio2,
|
|
||||||
orig_sr=sampling_rate2,
|
|
||||||
target_sr=sampling_rate,
|
|
||||||
res_type="polyphase",
|
|
||||||
)
|
|
||||||
sampling_rate2 = sampling_rate
|
|
||||||
|
|
||||||
# pad or trim to the correct length
|
|
||||||
if audio2.shape[0] < num_samples:
|
|
||||||
audio2 = np.hstack(
|
|
||||||
(
|
|
||||||
audio2,
|
|
||||||
np.zeros((num_samples - audio2.shape[0]), dtype=audio2.dtype),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif audio2.shape[0] > num_samples:
|
|
||||||
audio2 = audio2[:num_samples]
|
|
||||||
|
|
||||||
return audio2, sampling_rate2
|
|
||||||
|
|
||||||
|
|
||||||
def combine_audio_aug(
|
|
||||||
audio: np.ndarray,
|
|
||||||
sampling_rate: float,
|
|
||||||
ann: AudioLoaderAnnotationGroup,
|
|
||||||
audio2: np.ndarray,
|
|
||||||
sampling_rate2: float,
|
|
||||||
ann2: AudioLoaderAnnotationGroup,
|
|
||||||
) -> Tuple[np.ndarray, AudioLoaderAnnotationGroup]:
|
|
||||||
"""Combine two audio files.
|
|
||||||
|
|
||||||
Will combine two audio files by resampling them to the same sampling rate
|
|
||||||
and then combining them with a random weight. The annotations will be
|
|
||||||
combined by taking the union of the two sets of annotations.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
audio: np.ndarray
|
|
||||||
First Audio to combine.
|
|
||||||
sampling_rate: int
|
|
||||||
Sampling rate of the first audio.
|
|
||||||
ann: AnnotationGroup
|
|
||||||
Annotations for the first audio.
|
|
||||||
audio2: np.ndarray
|
|
||||||
Second Audio to combine.
|
|
||||||
sampling_rate2: int
|
|
||||||
Sampling rate of the second audio.
|
|
||||||
ann2: AnnotationGroup
|
|
||||||
Annotations for the second audio.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
audio : np.ndarray
|
|
||||||
Combined audio.
|
|
||||||
ann : AnnotationGroup
|
|
||||||
Combined annotations.
|
|
||||||
"""
|
|
||||||
# resample so they are the same
|
|
||||||
audio2, sampling_rate2 = resample_audio(
|
|
||||||
audio.shape[0],
|
|
||||||
sampling_rate,
|
|
||||||
audio2,
|
|
||||||
sampling_rate2,
|
|
||||||
)
|
|
||||||
|
|
||||||
# # set mean and std to be the same
|
|
||||||
# audio2 = (audio2 - audio2.mean())
|
|
||||||
# audio2 = (audio2/audio2.std())*audio.std()
|
|
||||||
# audio2 = audio2 + audio.mean()
|
|
||||||
|
|
||||||
if (
|
|
||||||
ann.get("annotated", False)
|
|
||||||
and (ann2.get("annotated", False))
|
|
||||||
and (sampling_rate2 == sampling_rate)
|
|
||||||
and (audio.shape[0] == audio2.shape[0])
|
|
||||||
):
|
|
||||||
comb_weight = 0.3 + np.random.random() * 0.4
|
|
||||||
audio = comb_weight * audio + (1 - comb_weight) * audio2
|
|
||||||
inds = np.argsort(np.hstack((ann["start_times"], ann2["start_times"])))
|
|
||||||
for kk in ann.keys():
|
|
||||||
# when combining calls from different files, assume they come
|
|
||||||
# from different individuals
|
|
||||||
if kk == "individual_ids":
|
|
||||||
if (ann[kk] > -1).sum() > 0:
|
|
||||||
ann2[kk][ann2[kk] > -1] += (
|
|
||||||
np.max(ann[kk][ann[kk] > -1]) + 1
|
|
||||||
)
|
|
||||||
|
|
||||||
if (kk != "class_id_file") and (kk != "annotated"):
|
|
||||||
ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds]
|
|
||||||
|
|
||||||
return audio, ann
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare_annotation(
|
|
||||||
annotation: Annotation,
|
|
||||||
class_names: List[str],
|
|
||||||
) -> Annotation:
|
|
||||||
try:
|
|
||||||
class_id = class_names.index(annotation["class"])
|
|
||||||
except ValueError:
|
|
||||||
class_id = -1
|
|
||||||
|
|
||||||
ann: Annotation = {
|
|
||||||
**annotation,
|
|
||||||
"class_id": class_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
if "individual" in ann:
|
|
||||||
ann["individual"] = int(ann["individual"]) # type: ignore
|
|
||||||
|
|
||||||
return ann
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare_file_annotation(
|
|
||||||
annotation: FileAnnotation,
|
|
||||||
class_names: List[str],
|
|
||||||
classes_to_ignore: List[str],
|
|
||||||
) -> AudioLoaderAnnotationGroup:
|
|
||||||
annotations = [
|
|
||||||
_prepare_annotation(ann, class_names)
|
|
||||||
for ann in annotation["annotation"]
|
|
||||||
if ann["class"] not in classes_to_ignore
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
|
||||||
class_id_file = class_names.index(annotation["class_name"])
|
|
||||||
except ValueError:
|
|
||||||
class_id_file = -1
|
|
||||||
|
|
||||||
ret: AudioLoaderAnnotationGroup = {
|
|
||||||
"id": annotation["id"],
|
|
||||||
"annotated": annotation["annotated"],
|
|
||||||
"duration": annotation["duration"],
|
|
||||||
"issues": annotation["issues"],
|
|
||||||
"time_exp": annotation["time_exp"],
|
|
||||||
"class_name": annotation["class_name"],
|
|
||||||
"notes": annotation["notes"],
|
|
||||||
"annotation": annotations,
|
|
||||||
"start_times": np.array([ann["start_time"] for ann in annotations]),
|
|
||||||
"end_times": np.array([ann["end_time"] for ann in annotations]),
|
|
||||||
"high_freqs": np.array([ann["high_freq"] for ann in annotations]),
|
|
||||||
"low_freqs": np.array([ann["low_freq"] for ann in annotations]),
|
|
||||||
"class_ids": np.array(
|
|
||||||
[ann.get("class_id", -1) for ann in annotations]
|
|
||||||
),
|
|
||||||
"individual_ids": np.array([ann["individual"] for ann in annotations]),
|
|
||||||
"class_id_file": class_id_file,
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
class AudioLoader(torch.utils.data.Dataset):
|
|
||||||
"""Main AudioLoader for training and testing."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
data_anns_ip: List[FileAnnotation],
|
|
||||||
params: AudioLoaderParameters,
|
|
||||||
dataset_name: Optional[str] = None,
|
|
||||||
is_train: bool = False,
|
|
||||||
return_spec_for_viz: bool = False,
|
|
||||||
):
|
|
||||||
self.is_train = is_train
|
|
||||||
self.params = params
|
|
||||||
self.return_spec_for_viz = return_spec_for_viz
|
|
||||||
self.data_anns: List[AudioLoaderAnnotationGroup] = [
|
|
||||||
_prepare_file_annotation(
|
|
||||||
ann,
|
|
||||||
params["class_names"],
|
|
||||||
params["classes_to_ignore"],
|
|
||||||
)
|
|
||||||
for ann in data_anns_ip
|
|
||||||
]
|
|
||||||
|
|
||||||
ann_cnt = [len(aa["annotation"]) for aa in self.data_anns]
|
|
||||||
self.max_num_anns = 2 * np.max(
|
|
||||||
ann_cnt
|
|
||||||
) # x2 because we may be combining files during training
|
|
||||||
|
|
||||||
print("\n")
|
|
||||||
if dataset_name is not None:
|
|
||||||
print("Dataset : " + dataset_name)
|
|
||||||
if self.is_train:
|
|
||||||
print("Split type : train")
|
|
||||||
else:
|
|
||||||
print("Split type : test")
|
|
||||||
print("Num files : " + str(len(self.data_anns)))
|
|
||||||
print("Num calls : " + str(np.sum(ann_cnt)))
|
|
||||||
|
|
||||||
def get_file_and_anns(
|
|
||||||
self,
|
|
||||||
index: Optional[int] = None,
|
|
||||||
) -> Tuple[np.ndarray, float, float, AudioLoaderAnnotationGroup]:
|
|
||||||
"""Get an audio file and its annotations.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
index : int, optional
|
|
||||||
Index of the file to be loaded. If None, a random file is chosen.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
audio_raw : np.ndarray
|
|
||||||
Loaded audio file.
|
|
||||||
sampling_rate : float
|
|
||||||
Sampling rate of the audio file.
|
|
||||||
duration : float
|
|
||||||
Duration of the audio file in seconds.
|
|
||||||
ann : AnnotationGroup
|
|
||||||
AnnotationGroup object containing the annotations for the audio file.
|
|
||||||
"""
|
|
||||||
# if no file specified, choose random one
|
|
||||||
if index is None:
|
|
||||||
index = np.random.randint(0, len(self.data_anns))
|
|
||||||
|
|
||||||
audio_file = self.data_anns[index]["file_path"]
|
|
||||||
sampling_rate, audio_raw = au.load_audio(
|
|
||||||
audio_file,
|
|
||||||
self.data_anns[index]["time_exp"],
|
|
||||||
self.params["target_samp_rate"],
|
|
||||||
self.params["scale_raw_audio"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# copy annotation
|
|
||||||
ann = copy.deepcopy(self.data_anns[index])
|
|
||||||
# ann["annotated"] = self.data_anns[index]["annotated"]
|
|
||||||
# ann["class_id_file"] = self.data_anns[index]["class_id_file"]
|
|
||||||
# keys = [
|
|
||||||
# "start_times",
|
|
||||||
# "end_times",
|
|
||||||
# "high_freqs",
|
|
||||||
# "low_freqs",
|
|
||||||
# "class_ids",
|
|
||||||
# "individual_ids",
|
|
||||||
# ]
|
|
||||||
# for kk in keys:
|
|
||||||
# ann[kk] = self.data_anns[index][kk].copy()
|
|
||||||
|
|
||||||
# if train then grab a random crop
|
|
||||||
if self.is_train:
|
|
||||||
nfft = int(self.params["fft_win_length"] * sampling_rate)
|
|
||||||
noverlap = int(self.params["fft_overlap"] * nfft)
|
|
||||||
length_samples = (
|
|
||||||
self.params["spec_train_width"] * (nfft - noverlap) + noverlap
|
|
||||||
)
|
|
||||||
|
|
||||||
if audio_raw.shape[0] - length_samples > 0:
|
|
||||||
sample_crop = np.random.randint(
|
|
||||||
audio_raw.shape[0] - length_samples
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sample_crop = 0
|
|
||||||
audio_raw = audio_raw[sample_crop : sample_crop + length_samples]
|
|
||||||
ann["start_times"] = ann["start_times"] - sample_crop / float(
|
|
||||||
sampling_rate
|
|
||||||
)
|
|
||||||
ann["end_times"] = ann["end_times"] - sample_crop / float(
|
|
||||||
sampling_rate
|
|
||||||
)
|
|
||||||
|
|
||||||
# pad audio
|
|
||||||
if self.is_train:
|
|
||||||
op_spec_target_size = self.params["spec_train_width"]
|
|
||||||
else:
|
|
||||||
op_spec_target_size = None
|
|
||||||
audio_raw = au.pad_audio(
|
|
||||||
audio_raw,
|
|
||||||
sampling_rate,
|
|
||||||
self.params["fft_win_length"],
|
|
||||||
self.params["fft_overlap"],
|
|
||||||
self.params["resize_factor"],
|
|
||||||
self.params["spec_divide_factor"],
|
|
||||||
op_spec_target_size,
|
|
||||||
)
|
|
||||||
duration = audio_raw.shape[0] / float(sampling_rate)
|
|
||||||
|
|
||||||
# sort based on time
|
|
||||||
inds = np.argsort(ann["start_times"])
|
|
||||||
for kk in ann.keys():
|
|
||||||
if (kk != "class_id_file") and (kk != "annotated"):
|
|
||||||
ann[kk] = ann[kk][inds]
|
|
||||||
|
|
||||||
return audio_raw, sampling_rate, duration, ann
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
"""Get an item from the dataset."""
|
|
||||||
# load audio file
|
|
||||||
audio, sampling_rate, duration, ann = self.get_file_and_anns(index)
|
|
||||||
|
|
||||||
# augment on raw audio
|
|
||||||
if self.is_train and self.params["augment_at_train"]:
|
|
||||||
# augment - combine with random audio file
|
|
||||||
if (
|
|
||||||
self.params["augment_at_train_combine"]
|
|
||||||
and np.random.random() < self.params["aug_prob"]
|
|
||||||
):
|
|
||||||
(
|
|
||||||
audio2,
|
|
||||||
sampling_rate2,
|
|
||||||
_,
|
|
||||||
ann2,
|
|
||||||
) = self.get_file_and_anns()
|
|
||||||
audio, ann = combine_audio_aug(
|
|
||||||
audio, sampling_rate, ann, audio2, sampling_rate2, ann2
|
|
||||||
)
|
|
||||||
|
|
||||||
# simulate echo by adding delayed copy of the file
|
|
||||||
if np.random.random() < self.params["aug_prob"]:
|
|
||||||
audio = echo_aug(
|
|
||||||
audio,
|
|
||||||
sampling_rate,
|
|
||||||
echo_max_delay=self.params["echo_max_delay"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# resample the audio
|
|
||||||
# if np.random.random() < self.params["aug_prob"]:
|
|
||||||
# audio, sampling_rate, duration = resample_aug(
|
|
||||||
# audio, sampling_rate, self.params
|
|
||||||
# )
|
|
||||||
|
|
||||||
# create spectrogram
|
|
||||||
spec, _ = au.generate_spectrogram(
|
|
||||||
audio,
|
|
||||||
sampling_rate,
|
|
||||||
params=dict(
|
|
||||||
fft_win_length=self.params["fft_win_length"],
|
|
||||||
fft_overlap=self.params["fft_overlap"],
|
|
||||||
max_freq=self.params["max_freq"],
|
|
||||||
min_freq=self.params["min_freq"],
|
|
||||||
spec_scale=self.params["spec_scale"],
|
|
||||||
denoise_spec_avg=self.params["denoise_spec_avg"],
|
|
||||||
max_scale_spec=self.params["max_scale_spec"],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
rsf = self.params["resize_factor"]
|
|
||||||
spec_op_shape = (
|
|
||||||
int(self.params["spec_height"] * rsf),
|
|
||||||
int(spec.shape[1] * rsf),
|
|
||||||
)
|
|
||||||
|
|
||||||
# resize the spec
|
|
||||||
spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0)
|
|
||||||
spec = F.interpolate(
|
|
||||||
spec,
|
|
||||||
size=spec_op_shape,
|
|
||||||
mode="bilinear",
|
|
||||||
align_corners=False,
|
|
||||||
).squeeze(0)
|
|
||||||
|
|
||||||
# augment spectrogram
|
|
||||||
if self.is_train and self.params["augment_at_train"]:
|
|
||||||
if np.random.random() < self.params["aug_prob"]:
|
|
||||||
spec = scale_vol_aug(
|
|
||||||
spec,
|
|
||||||
spec_amp_scaling=self.params["spec_amp_scaling"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if np.random.random() < self.params["aug_prob"]:
|
|
||||||
spec = warp_spec_aug(
|
|
||||||
spec,
|
|
||||||
ann,
|
|
||||||
stretch_squeeze_delta=self.params["stretch_squeeze_delta"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if np.random.random() < self.params["aug_prob"]:
|
|
||||||
spec = mask_time_aug(
|
|
||||||
spec,
|
|
||||||
mask_max_time_perc=self.params["mask_max_time_perc"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if np.random.random() < self.params["aug_prob"]:
|
|
||||||
spec = mask_freq_aug(
|
|
||||||
spec,
|
|
||||||
mask_max_freq_perc=self.params["mask_max_freq_perc"],
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = {}
|
|
||||||
outputs["spec"] = spec
|
|
||||||
if self.return_spec_for_viz:
|
|
||||||
outputs["spec_for_viz"] = torch.from_numpy(spec_for_viz).unsqueeze(
|
|
||||||
0
|
|
||||||
)
|
|
||||||
|
|
||||||
# create ground truth heatmaps
|
|
||||||
(
|
|
||||||
outputs["y_2d_det"],
|
|
||||||
outputs["y_2d_size"],
|
|
||||||
outputs["y_2d_classes"],
|
|
||||||
ann_aug,
|
|
||||||
) = generate_gt_heatmaps(
|
|
||||||
spec_op_shape,
|
|
||||||
sampling_rate,
|
|
||||||
ann,
|
|
||||||
class_names=self.params["class_names"],
|
|
||||||
fft_win_length=self.params["fft_win_length"],
|
|
||||||
fft_overlap=self.params["fft_overlap"],
|
|
||||||
max_freq=self.params["max_freq"],
|
|
||||||
min_freq=self.params["min_freq"],
|
|
||||||
resize_factor=self.params["resize_factor"],
|
|
||||||
target_sigma=self.params["target_sigma"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# hack to get around requirement that all vectors are the same length
|
|
||||||
# in the output batch
|
|
||||||
pad_size = self.max_num_anns - len(ann_aug["individual_ids"])
|
|
||||||
outputs["is_valid"] = pad_aray(
|
|
||||||
np.ones(len(ann_aug["individual_ids"])), pad_size
|
|
||||||
)
|
|
||||||
keys = [
|
|
||||||
"class_ids",
|
|
||||||
"individual_ids",
|
|
||||||
"x_inds",
|
|
||||||
"y_inds",
|
|
||||||
"start_times",
|
|
||||||
"end_times",
|
|
||||||
"low_freqs",
|
|
||||||
"high_freqs",
|
|
||||||
]
|
|
||||||
for kk in keys:
|
|
||||||
outputs[kk] = pad_aray(ann_aug[kk], pad_size)
|
|
||||||
|
|
||||||
# convert to pytorch
|
|
||||||
for kk in outputs.keys():
|
|
||||||
if type(outputs[kk]) != torch.Tensor:
|
|
||||||
outputs[kk] = torch.from_numpy(outputs[kk])
|
|
||||||
|
|
||||||
# scalars
|
|
||||||
outputs["class_id_file"] = ann["class_id_file"]
|
|
||||||
outputs["annotated"] = ann["annotated"]
|
|
||||||
outputs["duration"] = duration
|
|
||||||
outputs["sampling_rate"] = sampling_rate
|
|
||||||
outputs["file_id"] = index
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
"""Denotes the total number of samples."""
|
|
||||||
return len(self.data_anns)
|
|
@ -1,136 +1,214 @@
|
|||||||
from functools import wraps
|
from typing import Callable, Optional, Union
|
||||||
from typing import Callable, List, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from pydantic import Field
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent import arrays, data
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram
|
||||||
|
from batdetect2.preprocess.arrays import adjust_width
|
||||||
|
|
||||||
Augmentation = Callable[[xr.Dataset], xr.Dataset]
|
Augmentation = Callable[[xr.Dataset], xr.Dataset]
|
||||||
|
|
||||||
|
|
||||||
AUGMENTATION_PROBABILITY = 0.2
|
__all__ = [
|
||||||
MAX_DELAY = 0.005
|
"AugmentationsConfig",
|
||||||
STRETCH_SQUEEZE_DELTA = 0.04
|
"load_agumentation_config",
|
||||||
MASK_MAX_TIME_PERC: float = 0.05
|
"select_subclip",
|
||||||
MASK_MAX_FREQ_PERC: float = 0.10
|
"mix_examples",
|
||||||
|
"add_echo",
|
||||||
|
"scale_volume",
|
||||||
|
"warp_spectrogram",
|
||||||
|
"mask_time",
|
||||||
|
"mask_frequency",
|
||||||
|
"augment_example",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def maybe_apply(
|
class BaseAugmentationConfig(BaseConfig):
|
||||||
augmentation: Callable,
|
enable: bool = True
|
||||||
prob: float = AUGMENTATION_PROBABILITY,
|
probability: float = 0.2
|
||||||
) -> Callable:
|
|
||||||
"""Apply an augmentation with a given probability."""
|
|
||||||
|
|
||||||
@wraps(augmentation)
|
|
||||||
def _augmentation(x):
|
|
||||||
if np.random.rand() > prob:
|
|
||||||
return x
|
|
||||||
return augmentation(x)
|
|
||||||
|
|
||||||
return _augmentation
|
|
||||||
|
|
||||||
|
|
||||||
def select_random_subclip(
|
def select_subclip(
|
||||||
train_example: xr.Dataset,
|
example: xr.Dataset,
|
||||||
|
start_time: Optional[float] = None,
|
||||||
duration: Optional[float] = None,
|
duration: Optional[float] = None,
|
||||||
proportion: float = 0.9,
|
width: Optional[int] = None,
|
||||||
|
random: bool = False,
|
||||||
) -> xr.Dataset:
|
) -> xr.Dataset:
|
||||||
"""Select a random subclip from a clip."""
|
"""Select a random subclip from a clip."""
|
||||||
|
step = arrays.get_dim_step(example, "time") # type: ignore
|
||||||
|
start, stop = arrays.get_dim_range(example, "time") # type: ignore
|
||||||
|
|
||||||
time_coords = train_example.coords["time"]
|
if width is None:
|
||||||
|
if duration is None:
|
||||||
|
raise ValueError("Either duration or width must be provided")
|
||||||
|
|
||||||
start_time = time_coords.attrs.get("min", time_coords.min())
|
width = int(np.floor(duration / step))
|
||||||
end_time = time_coords.attrs.get("max", time_coords.max())
|
|
||||||
|
|
||||||
if duration is None:
|
if duration is None:
|
||||||
duration = (end_time - start_time) * proportion
|
duration = width * step
|
||||||
|
|
||||||
start_time = np.random.uniform(start_time, end_time - duration)
|
if start_time is None:
|
||||||
return train_example.sel(time=slice(start_time, start_time + duration))
|
if random:
|
||||||
|
start_time = np.random.uniform(start, max(stop - duration, start))
|
||||||
|
else:
|
||||||
|
start_time = start
|
||||||
|
|
||||||
|
if start_time + duration > stop:
|
||||||
|
return example
|
||||||
|
|
||||||
|
start_index = arrays.get_coord_index(
|
||||||
|
example, # type: ignore
|
||||||
|
"time",
|
||||||
|
start_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
end_index = start_index + width - 1
|
||||||
|
|
||||||
|
start_time = example.time.values[start_index]
|
||||||
|
end_time = example.time.values[end_index]
|
||||||
|
|
||||||
|
return example.sel(
|
||||||
|
time=slice(start_time, end_time),
|
||||||
|
audio_time=slice(start_time, end_time + step),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def combine_audio(
|
class MixAugmentationConfig(BaseAugmentationConfig):
|
||||||
audio1: xr.DataArray,
|
min_weight: float = 0.3
|
||||||
audio2: xr.DataArray,
|
max_weight: float = 0.7
|
||||||
alpha: Optional[float] = None,
|
|
||||||
min_alpha: float = 0.3,
|
|
||||||
max_alpha: float = 0.7,
|
def mix_examples(
|
||||||
) -> xr.DataArray:
|
example: xr.Dataset,
|
||||||
|
other: xr.Dataset,
|
||||||
|
weight: Optional[float] = None,
|
||||||
|
min_weight: float = 0.3,
|
||||||
|
max_weight: float = 0.7,
|
||||||
|
config: Optional[PreprocessingConfig] = None,
|
||||||
|
) -> xr.Dataset:
|
||||||
"""Combine two audio clips."""
|
"""Combine two audio clips."""
|
||||||
|
config = config or PreprocessingConfig()
|
||||||
|
|
||||||
if alpha is None:
|
if weight is None:
|
||||||
alpha = np.random.uniform(min_alpha, max_alpha)
|
weight = np.random.uniform(min_weight, max_weight)
|
||||||
|
|
||||||
return alpha * audio1 + (1 - alpha) * audio2.data
|
audio1 = example["audio"]
|
||||||
|
audio2 = adjust_width(other["audio"].values, len(audio1))
|
||||||
|
|
||||||
|
combined = weight * audio1 + (1 - weight) * audio2
|
||||||
|
|
||||||
|
spectrogram = compute_spectrogram(
|
||||||
|
combined.rename({"audio_time": "time"}),
|
||||||
|
config=config.spectrogram,
|
||||||
|
).data
|
||||||
|
|
||||||
|
# NOTE: The subclip's spectrogram might be slightly longer than the
|
||||||
|
# spectrogram computed from the subclip's audio. This is due to a
|
||||||
|
# simplification in the subclip process: It doesn't account for the
|
||||||
|
# spectrogram parameters to precisely determine the corresponding audio
|
||||||
|
# samples. To work around this, we pad the computed spectrogram with zeros
|
||||||
|
# as needed.
|
||||||
|
previous_width = len(example["time"])
|
||||||
|
spectrogram = adjust_width(spectrogram, previous_width)
|
||||||
|
|
||||||
|
detection_heatmap = xr.apply_ufunc(
|
||||||
|
np.maximum,
|
||||||
|
example["detection"],
|
||||||
|
adjust_width(other["detection"].values, previous_width),
|
||||||
|
)
|
||||||
|
|
||||||
|
class_heatmap = xr.apply_ufunc(
|
||||||
|
np.maximum,
|
||||||
|
example["class"],
|
||||||
|
adjust_width(other["class"].values, previous_width),
|
||||||
|
)
|
||||||
|
|
||||||
|
size_heatmap = example["size"] + adjust_width(
|
||||||
|
other["size"].values, previous_width
|
||||||
|
)
|
||||||
|
|
||||||
|
return xr.Dataset(
|
||||||
|
{
|
||||||
|
"audio": combined,
|
||||||
|
"spectrogram": xr.DataArray(
|
||||||
|
data=spectrogram,
|
||||||
|
dims=example["spectrogram"].dims,
|
||||||
|
coords=example["spectrogram"].coords,
|
||||||
|
),
|
||||||
|
"detection": detection_heatmap,
|
||||||
|
"class": class_heatmap,
|
||||||
|
"size": size_heatmap,
|
||||||
|
},
|
||||||
|
attrs=example.attrs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# def random_mix(
|
class EchoAugmentationConfig(BaseAugmentationConfig):
|
||||||
# audio: xr.DataArray,
|
max_delay: float = 0.005
|
||||||
# clip: data.ClipAnnotation,
|
min_weight: float = 0.0
|
||||||
# provider: Optional[ClipProvider] = None,
|
max_weight: float = 1.0
|
||||||
# alpha: Optional[float] = None,
|
|
||||||
# min_alpha: float = 0.3,
|
|
||||||
# max_alpha: float = 0.7,
|
|
||||||
# join_annotations: bool = True,
|
|
||||||
# ) -> Tuple[xr.DataArray, data.ClipAnnotation]:
|
|
||||||
# """Mix two audio clips."""
|
|
||||||
# if provider is None:
|
|
||||||
# raise ValueError("No audio provider given.")
|
|
||||||
#
|
|
||||||
# try:
|
|
||||||
# other_audio, other_clip = provider(clip)
|
|
||||||
# except (StopIteration, ValueError):
|
|
||||||
# raise ValueError("No more audio sources available.")
|
|
||||||
#
|
|
||||||
# new_audio = combine_audio(
|
|
||||||
# audio,
|
|
||||||
# other_audio,
|
|
||||||
# alpha=alpha,
|
|
||||||
# min_alpha=min_alpha,
|
|
||||||
# max_alpha=max_alpha,
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# if join_annotations:
|
|
||||||
# clip = clip.model_copy(
|
|
||||||
# update=dict(
|
|
||||||
# sound_events=clip.sound_events + other_clip.sound_events,
|
|
||||||
# )
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# return new_audio, clip
|
|
||||||
|
|
||||||
|
|
||||||
def add_echo(
|
def add_echo(
|
||||||
train_example: xr.Dataset,
|
example: xr.Dataset,
|
||||||
delay: Optional[float] = None,
|
delay: Optional[float] = None,
|
||||||
alpha: Optional[float] = None,
|
weight: Optional[float] = None,
|
||||||
min_alpha: float = 0.0,
|
min_weight: float = 0.1,
|
||||||
max_alpha: float = 1.0,
|
max_weight: float = 1.0,
|
||||||
max_delay: float = MAX_DELAY,
|
max_delay: float = 0.005,
|
||||||
|
config: Optional[PreprocessingConfig] = None,
|
||||||
) -> xr.Dataset:
|
) -> xr.Dataset:
|
||||||
"""Add a delay to the audio."""
|
"""Add a delay to the audio."""
|
||||||
|
config = config or PreprocessingConfig()
|
||||||
|
|
||||||
if delay is None:
|
if delay is None:
|
||||||
delay = np.random.uniform(0, max_delay)
|
delay = np.random.uniform(0, max_delay)
|
||||||
|
|
||||||
if alpha is None:
|
if weight is None:
|
||||||
alpha = np.random.uniform(min_alpha, max_alpha)
|
weight = np.random.uniform(min_weight, max_weight)
|
||||||
|
|
||||||
spec = train_example["spectrogram"]
|
audio = example["audio"]
|
||||||
|
step = arrays.get_dim_step(audio, "audio_time")
|
||||||
|
audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0)
|
||||||
|
audio = audio + weight * audio_delay
|
||||||
|
|
||||||
time_coords = spec.coords["time"]
|
spectrogram = compute_spectrogram(
|
||||||
start_time = time_coords.attrs["min"]
|
audio.rename({"audio_time": "time"}),
|
||||||
end_time = time_coords.attrs["max"]
|
config=config.spectrogram,
|
||||||
step = (end_time - start_time) / time_coords.size
|
).data
|
||||||
|
|
||||||
spec_delay = spec.shift(time=int(delay / step), fill_value=0)
|
# NOTE: The subclip's spectrogram might be slightly longer than the
|
||||||
|
# spectrogram computed from the subclip's audio. This is due to a
|
||||||
|
# simplification in the subclip process: It doesn't account for the
|
||||||
|
# spectrogram parameters to precisely determine the corresponding audio
|
||||||
|
# samples. To work around this, we pad the computed spectrogram with zeros
|
||||||
|
# as needed.
|
||||||
|
spectrogram = adjust_width(
|
||||||
|
spectrogram,
|
||||||
|
example["spectrogram"].sizes["time"],
|
||||||
|
)
|
||||||
|
|
||||||
return train_example.assign(spectrogram=spec + alpha * spec_delay)
|
return example.assign(
|
||||||
|
audio=audio,
|
||||||
|
spectrogram=xr.DataArray(
|
||||||
|
data=spectrogram,
|
||||||
|
dims=example["spectrogram"].dims,
|
||||||
|
coords=example["spectrogram"].coords,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VolumeAugmentationConfig(BaseAugmentationConfig):
|
||||||
|
min_scaling: float = 0.0
|
||||||
|
max_scaling: float = 2.0
|
||||||
|
|
||||||
|
|
||||||
def scale_volume(
|
def scale_volume(
|
||||||
train_example: xr.Dataset,
|
example: xr.Dataset,
|
||||||
factor: Optional[float] = None,
|
factor: Optional[float] = None,
|
||||||
max_scaling: float = 2,
|
max_scaling: float = 2,
|
||||||
min_scaling: float = 0,
|
min_scaling: float = 0,
|
||||||
@ -139,106 +217,227 @@ def scale_volume(
|
|||||||
if factor is None:
|
if factor is None:
|
||||||
factor = np.random.uniform(min_scaling, max_scaling)
|
factor = np.random.uniform(min_scaling, max_scaling)
|
||||||
|
|
||||||
return train_example.assign(
|
return example.assign(spectrogram=example["spectrogram"] * factor)
|
||||||
spectrogram=train_example["spectrogram"] * factor
|
|
||||||
)
|
|
||||||
|
class WarpAugmentationConfig(BaseAugmentationConfig):
|
||||||
|
delta: float = 0.04
|
||||||
|
|
||||||
|
|
||||||
def warp_spectrogram(
|
def warp_spectrogram(
|
||||||
train_example: xr.Dataset,
|
example: xr.Dataset,
|
||||||
factor: Optional[float] = None,
|
factor: Optional[float] = None,
|
||||||
delta: float = STRETCH_SQUEEZE_DELTA,
|
delta: float = 0.04,
|
||||||
) -> xr.Dataset:
|
) -> xr.Dataset:
|
||||||
"""Warp a spectrogram."""
|
"""Warp a spectrogram."""
|
||||||
if factor is None:
|
if factor is None:
|
||||||
factor = np.random.uniform(1 - delta, 1 + delta)
|
factor = np.random.uniform(1 - delta, 1 + delta)
|
||||||
|
|
||||||
time_coords = train_example.coords["time"]
|
start_time, end_time = arrays.get_dim_range(example, "time") # type: ignore
|
||||||
start_time = time_coords.attrs["min"]
|
|
||||||
end_time = time_coords.attrs["max"]
|
|
||||||
duration = end_time - start_time
|
duration = end_time - start_time
|
||||||
|
|
||||||
new_time = np.linspace(
|
new_time = np.linspace(
|
||||||
start_time,
|
start_time,
|
||||||
start_time + duration * factor,
|
start_time + duration * factor,
|
||||||
train_example.time.size,
|
example.time.size,
|
||||||
)
|
)
|
||||||
|
|
||||||
return train_example.interp(time=new_time)
|
spectrogram = (
|
||||||
|
example["spectrogram"]
|
||||||
|
.interp(
|
||||||
|
coords={"time": new_time},
|
||||||
|
method="linear",
|
||||||
|
kwargs=dict(
|
||||||
|
fill_value=0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.clip(min=0)
|
||||||
|
)
|
||||||
|
|
||||||
|
detection = example["detection"].interp(
|
||||||
|
time=new_time,
|
||||||
|
method="nearest",
|
||||||
|
kwargs=dict(
|
||||||
|
fill_value=0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
classification = example["class"].interp(
|
||||||
|
time=new_time,
|
||||||
|
method="nearest",
|
||||||
|
kwargs=dict(
|
||||||
|
fill_value=0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
size = example["size"].interp(
|
||||||
|
time=new_time,
|
||||||
|
method="nearest",
|
||||||
|
kwargs=dict(
|
||||||
|
fill_value=0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return example.assign(
|
||||||
|
{
|
||||||
|
"time": new_time,
|
||||||
|
"spectrogram": spectrogram,
|
||||||
|
"detection": detection,
|
||||||
|
"class": classification,
|
||||||
|
"size": size,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def mask_axis(
|
def mask_axis(
|
||||||
train_example: xr.Dataset,
|
array: xr.DataArray,
|
||||||
dim: str,
|
dim: str,
|
||||||
start: float,
|
start: float,
|
||||||
end: float,
|
end: float,
|
||||||
mask_all: bool = False,
|
mask_value: Union[float, Callable[[xr.DataArray], float]] = np.mean,
|
||||||
mask_value: float = 0,
|
) -> xr.DataArray:
|
||||||
) -> xr.Dataset:
|
if dim not in array.dims:
|
||||||
if dim not in train_example.dims:
|
|
||||||
raise ValueError(f"Axis {dim} not found in array")
|
raise ValueError(f"Axis {dim} not found in array")
|
||||||
|
|
||||||
coord = train_example.coords[dim]
|
coord = array.coords[dim]
|
||||||
condition = (coord < start) | (coord > end)
|
condition = (coord < start) | (coord > end)
|
||||||
|
|
||||||
if mask_all:
|
if callable(mask_value):
|
||||||
return train_example.where(condition, other=mask_value)
|
mask_value = mask_value(array)
|
||||||
|
|
||||||
return train_example.assign(
|
return array.where(condition, other=mask_value)
|
||||||
spectrogram=train_example.spectrogram.where(
|
|
||||||
condition, other=mask_value
|
|
||||||
)
|
class TimeMaskAugmentationConfig(BaseAugmentationConfig):
|
||||||
)
|
max_perc: float = 0.05
|
||||||
|
max_masks: int = 3
|
||||||
|
|
||||||
|
|
||||||
def mask_time(
|
def mask_time(
|
||||||
train_example: xr.Dataset,
|
example: xr.Dataset,
|
||||||
max_time_mask: float = MASK_MAX_TIME_PERC,
|
max_perc: float = 0.05,
|
||||||
max_num_masks: int = 3,
|
max_mask: int = 3,
|
||||||
) -> xr.Dataset:
|
) -> xr.Dataset:
|
||||||
"""Mask a random section of the time axis."""
|
"""Mask a random section of the time axis."""
|
||||||
|
num_masks = np.random.randint(1, max_mask + 1)
|
||||||
|
start_time, end_time = arrays.get_dim_range(example, "time") # type: ignore
|
||||||
|
|
||||||
num_masks = np.random.randint(1, max_num_masks + 1)
|
spectrogram = example["spectrogram"]
|
||||||
|
|
||||||
time_coord = train_example.coords["time"]
|
|
||||||
start_time = time_coord.attrs.get("min", time_coord.min())
|
|
||||||
end_time = time_coord.attrs.get("max", time_coord.max())
|
|
||||||
|
|
||||||
for _ in range(num_masks):
|
for _ in range(num_masks):
|
||||||
mask_size = np.random.uniform(0, max_time_mask)
|
mask_size = np.random.uniform(0, max_perc) * (end_time - start_time)
|
||||||
start = np.random.uniform(start_time, end_time - mask_size)
|
start = np.random.uniform(start_time, end_time - mask_size)
|
||||||
end = start + mask_size
|
end = start + mask_size
|
||||||
train_example = mask_axis(train_example, "time", start, end)
|
spectrogram = mask_axis(spectrogram, "time", start, end)
|
||||||
|
|
||||||
return train_example
|
return example.assign(spectrogram=spectrogram)
|
||||||
|
|
||||||
|
|
||||||
|
class FrequencyMaskAugmentationConfig(BaseAugmentationConfig):
|
||||||
|
max_perc: float = 0.10
|
||||||
|
max_masks: int = 3
|
||||||
|
|
||||||
|
|
||||||
def mask_frequency(
|
def mask_frequency(
|
||||||
train_example: xr.Dataset,
|
example: xr.Dataset,
|
||||||
max_freq_mask: float = MASK_MAX_FREQ_PERC,
|
max_perc: float = 0.10,
|
||||||
max_num_masks: int = 3,
|
max_masks: int = 3,
|
||||||
) -> xr.Dataset:
|
) -> xr.Dataset:
|
||||||
"""Mask a random section of the frequency axis."""
|
"""Mask a random section of the frequency axis."""
|
||||||
|
num_masks = np.random.randint(1, max_masks + 1)
|
||||||
|
min_freq, max_freq = arrays.get_dim_range(example, "frequency") # type: ignore
|
||||||
|
|
||||||
num_masks = np.random.randint(1, max_num_masks + 1)
|
spectrogram = example["spectrogram"]
|
||||||
|
|
||||||
freq_coord = train_example.coords["frequency"]
|
|
||||||
min_freq = freq_coord.min()
|
|
||||||
max_freq = freq_coord.max()
|
|
||||||
|
|
||||||
for _ in range(num_masks):
|
for _ in range(num_masks):
|
||||||
mask_size = np.random.uniform(0, max_freq_mask)
|
mask_size = np.random.uniform(0, max_perc) * (max_freq - min_freq)
|
||||||
start = np.random.uniform(min_freq, max_freq - mask_size)
|
start = np.random.uniform(min_freq, max_freq - mask_size)
|
||||||
end = start + mask_size
|
end = start + mask_size
|
||||||
train_example = mask_axis(train_example, "frequency", start, end)
|
spectrogram = mask_axis(spectrogram, "frequency", start, end)
|
||||||
|
|
||||||
return train_example
|
return example.assign(spectrogram=spectrogram)
|
||||||
|
|
||||||
|
|
||||||
AUGMENTATIONS: List[Augmentation] = [
|
class AugmentationsConfig(BaseConfig):
|
||||||
select_random_subclip,
|
mix: MixAugmentationConfig = Field(default_factory=MixAugmentationConfig)
|
||||||
add_echo,
|
echo: EchoAugmentationConfig = Field(
|
||||||
scale_volume,
|
default_factory=EchoAugmentationConfig
|
||||||
mask_time,
|
)
|
||||||
mask_frequency,
|
volume: VolumeAugmentationConfig = Field(
|
||||||
]
|
default_factory=VolumeAugmentationConfig
|
||||||
|
)
|
||||||
|
warp: WarpAugmentationConfig = Field(
|
||||||
|
default_factory=WarpAugmentationConfig
|
||||||
|
)
|
||||||
|
time_mask: TimeMaskAugmentationConfig = Field(
|
||||||
|
default_factory=TimeMaskAugmentationConfig
|
||||||
|
)
|
||||||
|
frequency_mask: FrequencyMaskAugmentationConfig = Field(
|
||||||
|
default_factory=FrequencyMaskAugmentationConfig
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_agumentation_config(
|
||||||
|
path: data.PathLike, field: Optional[str] = None
|
||||||
|
) -> AugmentationsConfig:
|
||||||
|
return load_config(path, schema=AugmentationsConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
|
def should_apply(config: BaseAugmentationConfig) -> bool:
|
||||||
|
if not config.enable:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return np.random.uniform() < config.probability
|
||||||
|
|
||||||
|
|
||||||
|
def augment_example(
|
||||||
|
example: xr.Dataset,
|
||||||
|
config: AugmentationsConfig,
|
||||||
|
preprocessing_config: Optional[PreprocessingConfig] = None,
|
||||||
|
others: Optional[Callable[[], xr.Dataset]] = None,
|
||||||
|
) -> xr.Dataset:
|
||||||
|
if should_apply(config.mix) and (others is not None):
|
||||||
|
other = others()
|
||||||
|
example = mix_examples(
|
||||||
|
example,
|
||||||
|
other,
|
||||||
|
min_weight=config.mix.min_weight,
|
||||||
|
max_weight=config.mix.max_weight,
|
||||||
|
config=preprocessing_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_apply(config.echo):
|
||||||
|
example = add_echo(
|
||||||
|
example,
|
||||||
|
max_delay=config.echo.max_delay,
|
||||||
|
min_weight=config.echo.min_weight,
|
||||||
|
max_weight=config.echo.max_weight,
|
||||||
|
config=preprocessing_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_apply(config.volume):
|
||||||
|
example = scale_volume(
|
||||||
|
example,
|
||||||
|
max_scaling=config.volume.max_scaling,
|
||||||
|
min_scaling=config.volume.min_scaling,
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_apply(config.warp):
|
||||||
|
example = warp_spectrogram(
|
||||||
|
example,
|
||||||
|
delta=config.warp.delta,
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_apply(config.time_mask):
|
||||||
|
example = mask_time(
|
||||||
|
example,
|
||||||
|
max_perc=config.time_mask.max_perc,
|
||||||
|
max_mask=config.time_mask.max_masks,
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_apply(config.frequency_mask):
|
||||||
|
example = mask_frequency(
|
||||||
|
example,
|
||||||
|
max_perc=config.frequency_mask.max_perc,
|
||||||
|
max_masks=config.frequency_mask.max_masks,
|
||||||
|
)
|
||||||
|
|
||||||
|
return example
|
||||||
|
batdetect2/train/config.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from typing import Optional

from pydantic import Field
from soundevent.data import PathLike

from batdetect2.configs import BaseConfig, load_config
from batdetect2.train.losses import LossConfig

__all__ = [
    "OptimizerConfig",
    "TrainingConfig",
    "load_train_config",
]


class OptimizerConfig(BaseConfig):
    learning_rate: float = 1e-3
    t_max: int = 100


class TrainingConfig(BaseConfig):
    batch_size: int = 32
    loss: LossConfig = Field(default_factory=LossConfig)
    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)


def load_train_config(
    path: PathLike,
    field: Optional[str] = None,
) -> TrainingConfig:
    return load_config(path, schema=TrainingConfig, field=field)

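A short sketch of loading this configuration from a project file; the file name and the "train" field name are assumptions, not fixed by the module above:

    from batdetect2.train.config import load_train_config

    config = load_train_config("config.yaml", field="train")
    print(config.batch_size, config.optimizer.learning_rate)
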
@@ -1,12 +1,21 @@
 import os
 from pathlib import Path
-from typing import Callable, Dict, NamedTuple, Optional, Sequence, Union
+from typing import NamedTuple, Optional, Sequence, Union

+import numpy as np
 import torch
 import xarray as xr
+from pydantic import Field
 from soundevent import data
 from torch.utils.data import Dataset

+from batdetect2.configs import BaseConfig
+from batdetect2.preprocess.tensors import adjust_width
+from batdetect2.train.augmentations import (
+    AugmentationsConfig,
+    augment_example,
+    select_subclip,
+)
 from batdetect2.train.preprocess import PreprocessingConfig

 __all__ = [
@@ -26,74 +35,113 @@ class TrainExample(NamedTuple):
     idx: torch.Tensor


-def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
-    return list(Path(directory).glob(f"*{extension}"))
+class SubclipConfig(BaseConfig):
+    duration: Optional[float] = None
+    width: int = 512
+    random: bool = False
+
+
+class DatasetConfig(BaseConfig):
+    subclip: SubclipConfig = Field(default_factory=SubclipConfig)
+    augmentation: AugmentationsConfig = Field(
+        default_factory=AugmentationsConfig
+    )


 class LabeledDataset(Dataset):
     def __init__(
         self,
         filenames: Sequence[PathLike],
-        transform: Optional[Callable[[xr.Dataset], xr.Dataset]] = None,
+        subclip: Optional[SubclipConfig] = None,
+        augmentation: Optional[AugmentationsConfig] = None,
+        preprocessing: Optional[PreprocessingConfig] = None,
     ):
         self.filenames = filenames
-        self.transform = transform
+        self.subclip = subclip
+        self.augmentation = augmentation
+        self.preprocessing = preprocessing or PreprocessingConfig()

     def __len__(self):
         return len(self.filenames)

     def __getitem__(self, idx) -> TrainExample:
-        data = self.load(idx)
+        dataset = self.get_dataset(idx)
+
+        if self.subclip:
+            dataset = select_subclip(
+                dataset,
+                duration=self.subclip.duration,
+                width=self.subclip.width,
+                random=self.subclip.random,
+            )
+
+        if self.augmentation:
+            dataset = augment_example(
+                dataset,
+                self.augmentation,
+                preprocessing_config=self.preprocessing,
+                others=self.get_random_example,
+            )
+
         return TrainExample(
-            spec=data["spectrogram"],
-            detection_heatmap=data["detection"],
-            class_heatmap=data["class"],
-            size_heatmap=data["size"],
+            spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0),
+            detection_heatmap=self.to_tensor(dataset["detection"]),
+            class_heatmap=self.to_tensor(dataset["class"]),
+            size_heatmap=self.to_tensor(dataset["size"]),
             idx=torch.tensor(idx),
         )

     @classmethod
-    def from_directory(cls, directory: PathLike, extension: str = ".nc"):
-        return cls(get_files(directory, extension))
+    def from_directory(
+        cls,
+        directory: PathLike,
+        extension: str = ".nc",
+        subclip: Optional[SubclipConfig] = None,
+        augmentation: Optional[AugmentationsConfig] = None,
+        preprocessing: Optional[PreprocessingConfig] = None,
+    ):
+        return cls(
+            get_files(directory, extension),
+            subclip=subclip,
+            augmentation=augmentation,
+            preprocessing=preprocessing,
+        )

-    def load(self, idx) -> Dict[str, torch.Tensor]:
+    def get_random_example(self) -> xr.Dataset:
+        idx = np.random.randint(0, len(self))
         dataset = self.get_dataset(idx)
-        return {
-            "spectrogram": torch.tensor(
-                dataset["spectrogram"].values
-            ).unsqueeze(0),
-            "detection": torch.tensor(dataset["detection"].values),
-            "class": torch.tensor(dataset["class"].values),
-            "size": torch.tensor(dataset["size"].values),
-        }

-    def apply_augmentation(self, dataset: xr.Dataset) -> xr.Dataset:
-        if self.transform is not None:
-            return self.transform(dataset)
+        if self.subclip:
+            dataset = select_subclip(
+                dataset,
+                duration=self.subclip.duration,
+                width=self.subclip.width,
+                random=self.subclip.random,
+            )

         return dataset

-    def get_dataset(self, idx):
+    def get_dataset(self, idx) -> xr.Dataset:
         return xr.open_dataset(self.filenames[idx])

-    def get_spectrogram(self, idx):
-        return xr.open_dataset(self.filenames[idx])["spectrogram"]
+    def get_clip_annotation(self, idx) -> data.ClipAnnotation:
+        return data.ClipAnnotation.model_validate_json(
+            self.get_dataset(idx).attrs["clip_annotation"]
+        )

-    def get_detection_mask(self, idx):
-        return xr.open_dataset(self.filenames[idx])["detection"]
+    def to_tensor(
+        self,
+        array: xr.DataArray,
+        dtype=np.float32,
+    ) -> torch.Tensor:
+        tensor = torch.tensor(array.values.astype(dtype))

-    def get_class_mask(self, idx):
-        return xr.open_dataset(self.filenames[idx])["class"]
+        if not self.subclip:
+            return tensor

-    def get_size_mask(self, idx):
-        return xr.open_dataset(self.filenames[idx])["size"]
+        width = self.subclip.width
+        return adjust_width(tensor, width)

-    def get_clip_annotation(self, idx):
-        filename = self.filenames[idx]
-        dataset = xr.open_dataset(filename)
-        clip_annotation = dataset.attrs["clip_annotation"]
-        return data.ClipAnnotation.model_validate_json(clip_annotation)

-    def get_preprocessing_configuration(self, idx):
-        config = xr.open_dataset(self.filenames[idx]).attrs["configuration"]
-        return PreprocessingConfig.model_validate_json(config)
+def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
+    return list(Path(directory).glob(f"*{extension}"))

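A hedged sketch of the new dataset API; the directory of preprocessed .nc files is an assumption:

    from torch.utils.data import DataLoader

    from batdetect2.train.dataset import LabeledDataset, SubclipConfig

    dataset = LabeledDataset.from_directory(
        "train_examples/",  # assumed directory of preprocessed .nc files
        subclip=SubclipConfig(duration=0.5, random=True),
    )
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    batch = next(iter(loader))  # a TrainExample of spectrogram and heatmap tensors
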
@@ -1,27 +1,43 @@
-from typing import Tuple
+from collections.abc import Iterable
+from typing import Callable, List, Optional, Sequence, Tuple

 import numpy as np
 import xarray as xr
+from pydantic import Field
 from scipy.ndimage import gaussian_filter
-from soundevent import data, geometry, arrays
+from soundevent import arrays, data, geometry
 from soundevent.geometry.operations import Positions
-from soundevent.types import ClassMapper
+
+from batdetect2.configs import BaseConfig, load_config

 __all__ = [
-    "ClassMapper",
+    "HeatmapsConfig",
+    "LabelConfig",
     "generate_heatmaps",
+    "load_label_config",
 ]


-TARGET_SIGMA = 3.0
+class HeatmapsConfig(BaseConfig):
+    position: Positions = "bottom-left"
+    sigma: float = 3.0
+    time_scale: float = 1000.0
+    frequency_scale: float = 1 / 859.375
+
+
+class LabelConfig(BaseConfig):
+    heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)


 def generate_heatmaps(
-    clip_annotation: data.ClipAnnotation,
+    sound_events: Sequence[data.SoundEventAnnotation],
     spec: xr.DataArray,
-    class_mapper: ClassMapper,
-    target_sigma: float = TARGET_SIGMA,
+    class_names: List[str],
+    encoder: Callable[[Iterable[data.Tag]], Optional[str]],
+    target_sigma: float = 3.0,
     position: Positions = "bottom-left",
+    time_scale: float = 1000.0,
+    frequency_scale: float = 1 / 859.375,
     dtype=np.float32,
 ) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
     shape = dict(zip(spec.dims, spec.shape))
@@ -31,20 +47,13 @@ def generate_heatmaps(
             "Spectrogram must have time and frequency dimensions."
         )

-    time_duration = arrays.get_dim_width(spec, dim="time")
-    freq_bandwidth = arrays.get_dim_width(spec, dim="frequency")
-
-    # Compute the size factors
-    time_scale = 1 / time_duration
-    frequency_scale = 1 / freq_bandwidth
-
     # Initialize heatmaps
     detection_heatmap = xr.zeros_like(spec, dtype=dtype)
     class_heatmap = xr.DataArray(
-        data=np.zeros((class_mapper.num_classes, *spec.shape), dtype=dtype),
+        data=np.zeros((len(class_names), *spec.shape), dtype=dtype),
         dims=["category", *spec.dims],
         coords={
-            "category": class_mapper.class_labels,
+            "category": [*class_names],
             **spec.coords,
         },
     )
@@ -57,9 +66,8 @@ def generate_heatmaps(
         },
     )

-    for sound_event_annotation in clip_annotation.sound_events:
+    for sound_event_annotation in sound_events:
         geom = sound_event_annotation.sound_event.geometry

         if geom is None:
             continue

@@ -67,23 +75,29 @@ def generate_heatmaps(
         time, frequency = geometry.get_geometry_point(geom, position=position)

         # Set 1.0 at the position of the sound event in the detection heatmap
-        detection_heatmap = arrays.set_value_at_pos(
-            detection_heatmap,
-            1.0,
-            time=time,
-            frequency=frequency,
-        )
+        try:
+            detection_heatmap = arrays.set_value_at_pos(
+                detection_heatmap,
+                1.0,
+                time=time,
+                frequency=frequency,
+            )
+        except KeyError:
+            # Skip the sound event if the position is outside the spectrogram
+            continue

         # Set the size of the sound event at the position in the size heatmap
         start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
             geom
         )

         size = np.array(
             [
                 (end_time - start_time) * time_scale,
                 (high_freq - low_freq) * frequency_scale,
             ]
         )

         size_heatmap = arrays.set_value_at_pos(
             size_heatmap,
             size,
@@ -92,14 +106,12 @@ def generate_heatmaps(
         )

         # Get the class name of the sound event
-        class_name = class_mapper.transform(sound_event_annotation)
+        class_name = encoder(sound_event_annotation.tags)

         if class_name is None:
             # If the label is None skip the sound event
             continue

-        # Set 1.0 at the position and category of the sound event in the class
-        # heatmap
         class_heatmap = arrays.set_value_at_pos(
             class_heatmap,
             1.0,
@@ -130,3 +142,9 @@ def generate_heatmaps(
     ).fillna(0.0)

     return detection_heatmap, class_heatmap, size_heatmap
+
+
+def load_label_config(
+    path: data.PathLike, field: Optional[str] = None
+) -> LabelConfig:
+    return load_config(path, schema=LabelConfig, field=field)

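An untested sketch of calling generate_heatmaps directly on a synthetic spectrogram; the coordinate ranges and class names are arbitrary assumptions:

    import numpy as np
    import xarray as xr

    from batdetect2.train.labels import generate_heatmaps

    spec = xr.DataArray(
        np.zeros((128, 256), dtype=np.float32),
        dims=["frequency", "time"],
        coords={
            "frequency": np.linspace(10_000, 120_000, 128),
            "time": np.linspace(0.0, 0.5, 256),
        },
    )

    detection, classes, size = generate_heatmaps(
        [],  # no annotated sound events
        spec,
        ["Myotis daubentonii", "Pipistrellus pipistrellus"],
        lambda tags: None,  # trivial encoder that never assigns a class
    )
    print(detection.shape, classes.shape, size.shape)
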
batdetect2/train/legacy/train.py (new file, 82 lines)
@@ -0,0 +1,82 @@
from typing import Callable, NamedTuple, Optional

import torch
from soundevent import data
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from batdetect2.models.typing import DetectionModel
from batdetect2.train.dataset import LabeledDataset


class TrainInputs(NamedTuple):
    spec: torch.Tensor
    detection_heatmap: torch.Tensor
    class_heatmap: torch.Tensor
    size_heatmap: torch.Tensor


def train_loop(
    model: DetectionModel,
    train_dataset: LabeledDataset[TrainInputs],
    validation_dataset: LabeledDataset[TrainInputs],
    device: Optional[torch.device] = None,
    num_epochs: int = 100,
    learning_rate: float = 1e-4,
):
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    validation_loader = DataLoader(validation_dataset, batch_size=32)

    model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = CosineAnnealingLR(
        optimizer,
        num_epochs * len(train_loader),
    )

    for epoch in range(num_epochs):
        train_loss = train_single_epoch(
            model,
            train_loader,
            optimizer,
            device,
            scheduler,
        )


def train_single_epoch(
    model: DetectionModel,
    train_loader: DataLoader,
    optimizer: Adam,
    device: torch.device,
    scheduler: CosineAnnealingLR,
):
    model.train()
    train_loss = tu.AverageMeter()

    for batch in train_loader:
        optimizer.zero_grad()

        spec = batch.spec.to(device)
        detection_heatmap = batch.detection_heatmap.to(device)
        class_heatmap = batch.class_heatmap.to(device)
        size_heatmap = batch.size_heatmap.to(device)

        outputs = model(spec)

        loss = loss_fun(
            outputs,
            gt_det,
            gt_size,
            gt_class,
            det_criterion,
            params,
            class_inv_freq,
        )

        train_loss.update(loss.item(), data.shape[0])
        loss.backward()
        optimizer.step()
        scheduler.step()

@@ -16,7 +16,6 @@ def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):
-

 def split_diff(ann_dir, wav_dir, load_extra=True):
     train_sets = []
     if load_extra:
         train_sets.append(
@@ -144,7 +143,6 @@ def split_diff(ann_dir, wav_dir, load_extra=True):
-

 def split_same(ann_dir, wav_dir, load_extra=True):
     train_sets = []
     if load_extra:
         train_sets.append(
@@ -69,7 +69,8 @@ def get_genus_mapping(class_names: List[str]) -> Tuple[List[str], List[int]]:


 def standardize_low_freq(
-    data: List[types.FileAnnotation], class_of_interest: str,
+    data: List[types.FileAnnotation],
+    class_of_interest: str,
 ) -> List[types.FileAnnotation]:
     # address the issue of highly variable low frequency annotations
     # this often happens for contstant frequency calls
@@ -1,56 +0,0 @@
-import pytorch_lightning as L
-from torch import Tensor, optim
-
-from batdetect2.models.typing import DetectionModel, ModelOutput
-from batdetect2.train import losses
-from batdetect2.train.dataset import TrainExample
-
-__all__ = [
-    "LitDetectorModel",
-]
-
-
-class LitDetectorModel(L.LightningModule):
-    model: DetectionModel
-
-    def __init__(self, model: DetectionModel, learning_rate: float = 1e-3):
-        super().__init__()
-        self.model = model
-        self.learning_rate = learning_rate
-
-    def compute_loss(
-        self,
-        outputs: ModelOutput,
-        batch: TrainExample,
-    ) -> Tensor:
-        detection_loss = losses.focal_loss(
-            outputs.detection_probs,
-            batch.detection_heatmap,
-        )
-
-        size_loss = losses.bbox_size_loss(
-            outputs.size_preds,
-            batch.size_heatmap,
-        )
-
-        valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
-        classification_loss = losses.focal_loss(
-            outputs.class_probs,
-            batch.class_heatmap,
-            valid_mask=valid_mask,
-        )
-
-        return detection_loss + size_loss + classification_loss
-
-    def training_step(self, batch: TrainExample, batch_idx: int):  # type: ignore
-        outputs: ModelOutput = self.model(batch.spec)
-        loss = self.compute_loss(outputs, batch)
-        self.log("train_loss", loss)
-        return loss
-
-    def configure_optimizers(self):
-        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
-        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
-        return [optimizer], [scheduler]

@@ -1,7 +1,23 @@
-from typing import Optional
+from typing import NamedTuple, Optional

 import torch
 import torch.nn.functional as F
+from pydantic import Field
+
+from batdetect2.configs import BaseConfig
+from batdetect2.models.typing import ModelOutput
+from batdetect2.train.dataset import TrainExample
+
+__all__ = [
+    "bbox_size_loss",
+    "compute_loss",
+    "focal_loss",
+    "mse_loss",
+]
+
+
+class SizeLossConfig(BaseConfig):
+    weight: float = 0.1


 def bbox_size_loss(
@@ -17,6 +33,11 @@ def bbox_size_loss(
     )


+class FocalLossConfig(BaseConfig):
+    beta: float = 4
+    alpha: float = 2
+
+
 def focal_loss(
     pred: torch.Tensor,
     gt: torch.Tensor,
@@ -44,7 +65,7 @@
     )

     if weights is not None:
-        pos_loss = pos_loss * weights
+        pos_loss = pos_loss * torch.tensor(weights)
         # neg_loss = neg_loss*weights

     if valid_mask is not None:
@@ -75,3 +96,71 @@ def mse_loss(
     else:
         op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
     return op
+
+
+class DetectionLossConfig(BaseConfig):
+    weight: float = 1.0
+    focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
+
+
+class ClassificationLossConfig(BaseConfig):
+    weight: float = 2.0
+    focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
+    class_weights: Optional[list[float]] = None
+
+
+class LossConfig(BaseConfig):
+    detection: DetectionLossConfig = Field(default_factory=DetectionLossConfig)
+    size: SizeLossConfig = Field(default_factory=SizeLossConfig)
+    classification: ClassificationLossConfig = Field(
+        default_factory=ClassificationLossConfig
+    )
+
+
+class Losses(NamedTuple):
+    detection: torch.Tensor
+    size: torch.Tensor
+    classification: torch.Tensor
+    total: torch.Tensor
+
+
+def compute_loss(
+    batch: TrainExample,
+    outputs: ModelOutput,
+    conf: LossConfig,
+    class_weights: Optional[torch.Tensor] = None,
+) -> Losses:
+    detection_loss = focal_loss(
+        outputs.detection_probs,
+        batch.detection_heatmap,
+        beta=conf.detection.focal.beta,
+        alpha=conf.detection.focal.alpha,
+    )
+
+    size_loss = bbox_size_loss(
+        outputs.size_preds,
+        batch.size_heatmap,
+    )
+
+    valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
+    classification_loss = focal_loss(
+        outputs.class_probs,
+        batch.class_heatmap,
+        weights=class_weights,
+        valid_mask=valid_mask,
+        beta=conf.classification.focal.beta,
+        alpha=conf.classification.focal.alpha,
+    )
+
+    total = (
+        detection_loss * conf.detection.weight
+        + size_loss * conf.size.weight
+        + classification_loss * conf.classification.weight
+    )
+
+    return Losses(
+        detection=detection_loss,
+        size=size_loss,
+        classification=classification_loss,
+        total=total,
+    )

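The loss terms operate on plain tensors, so they can be exercised with random data of matching shapes; this is a hedged sketch and the shapes are assumptions:

    import torch

    from batdetect2.train.losses import LossConfig, bbox_size_loss, focal_loss

    conf = LossConfig()
    pred = torch.rand(1, 1, 128, 256)
    target = (torch.rand(1, 1, 128, 256) > 0.99).float()

    detection = focal_loss(
        pred,
        target,
        beta=conf.detection.focal.beta,
        alpha=conf.detection.focal.alpha,
    )
    size = bbox_size_loss(torch.rand(1, 2, 128, 256), torch.rand(1, 2, 128, 256))
    total = detection * conf.detection.weight + size * conf.size.weight
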
@@ -1,20 +1,28 @@
 """Module for preprocessing data for training."""

 import os
-import warnings
 from functools import partial
+from multiprocessing import Pool
 from pathlib import Path
 from typing import Callable, Optional, Sequence, Union
-from tqdm.auto import tqdm
-from multiprocessing import Pool

 import xarray as xr
+from pydantic import Field
 from soundevent import data
+from tqdm.auto import tqdm

-from batdetect2.data.labels import TARGET_SIGMA, ClassMapper, generate_heatmaps
-from batdetect2.data.preprocessing import (
-    preprocess_audio_clip,
+from batdetect2.configs import BaseConfig
+from batdetect2.preprocess import (
     PreprocessingConfig,
+    compute_spectrogram,
+    load_clip_audio,
+)
+from batdetect2.train.labels import LabelConfig, generate_heatmaps
+from batdetect2.train.targets import (
+    TargetConfig,
+    build_encoder,
+    build_sound_event_filter,
+    get_class_names,
 )

 PathLike = Union[Path, str, os.PathLike]
@@ -22,31 +30,76 @@ FilenameFn = Callable[[data.ClipAnnotation], str]

 __all__ = [
     "preprocess_annotations",
+    "preprocess_single_annotation",
+    "generate_train_example",
+    "TrainPreprocessingConfig",
 ]


+class TrainPreprocessingConfig(BaseConfig):
+    preprocessing: PreprocessingConfig = Field(
+        default_factory=PreprocessingConfig
+    )
+    target: TargetConfig = Field(default_factory=TargetConfig)
+    labels: LabelConfig = Field(default_factory=LabelConfig)
+
+
 def generate_train_example(
     clip_annotation: data.ClipAnnotation,
-    class_mapper: ClassMapper,
-    preprocessing_config: PreprocessingConfig = PreprocessingConfig(),
-    target_sigma: float = TARGET_SIGMA,
+    preprocessing_config: Optional[PreprocessingConfig] = None,
+    target_config: Optional[TargetConfig] = None,
+    label_config: Optional[LabelConfig] = None,
 ) -> xr.Dataset:
     """Generate a training example."""
-    spectrogram = preprocess_audio_clip(
-        clip_annotation.clip,
-        config=preprocessing_config,
+    config = TrainPreprocessingConfig(
+        preprocessing=preprocessing_config or PreprocessingConfig(),
+        target=target_config or TargetConfig(),
+        labels=label_config or LabelConfig(),
     )

+    wave = load_clip_audio(
+        clip_annotation.clip,
+        config=config.preprocessing.audio,
+    )
+
+    spectrogram = compute_spectrogram(
+        wave,
+        config=config.preprocessing.spectrogram,
+    )
+
+    filter_fn = build_sound_event_filter(
+        include=config.target.include,
+        exclude=config.target.exclude,
+    )
+
+    selected_events = [
+        event for event in clip_annotation.sound_events if filter_fn(event)
+    ]
+
+    encoder = build_encoder(
+        config.target.classes,
+        replacement_rules=config.target.replace,
+    )
+    class_names = get_class_names(config.target.classes)
+
     detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
-        clip_annotation,
+        selected_events,
         spectrogram,
-        class_mapper,
-        target_sigma=target_sigma,
+        class_names,
+        encoder,
+        target_sigma=config.labels.heatmaps.sigma,
+        position=config.labels.heatmaps.position,
+        time_scale=config.labels.heatmaps.time_scale,
+        frequency_scale=config.labels.heatmaps.frequency_scale,
     )

     dataset = xr.Dataset(
         {
+            # NOTE: Need to rename the time dimension to avoid conflicts with
+            # the spectrogram time dimension, otherwise xarray will interpolate
+            # the spectrogram and the heatmaps to the same temporal resolution
+            # as the waveform.
+            "audio": wave.rename({"time": "audio_time"}),
             "spectrogram": spectrogram,
             "detection": detection_heatmap,
             "class": class_heatmap,
@@ -56,9 +109,12 @@ def generate_train_example(

     return dataset.assign_attrs(
         title=f"Training example for {clip_annotation.uuid}",
-        preprocessing_configuration=preprocessing_config.model_dump_json(),
-        target_sigma=target_sigma,
-        clip_annotation=clip_annotation.model_dump_json(),
+        config=config.model_dump_json(),
+        clip_annotation=clip_annotation.model_dump_json(
+            exclude_none=True,
+            exclude_defaults=True,
+            exclude_unset=True,
+        ),
     )

@@ -77,36 +133,54 @@ def save_to_file(
     )


-def load_config(path: PathLike, **kwargs) -> PreprocessingConfig:
-    """Load configuration from file."""
-    path = Path(path)
-
-    if not path.is_file():
-        warnings.warn(f"Config file not found: {path}. Using default config.")
-        return PreprocessingConfig(**kwargs)
-
-    try:
-        return PreprocessingConfig.model_validate_json(path.read_text())
-    except ValueError as e:
-        warnings.warn(
-            f"Failed to load config file: {e}. Using default config."
-        )
-        return PreprocessingConfig(**kwargs)
-
-
 def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
     return f"{clip_annotation.uuid}.nc"


+def preprocess_annotations(
+    clip_annotations: Sequence[data.ClipAnnotation],
+    output_dir: PathLike,
+    filename_fn: FilenameFn = _get_filename,
+    replace: bool = False,
+    preprocessing_config: Optional[PreprocessingConfig] = None,
+    target_config: Optional[TargetConfig] = None,
+    label_config: Optional[LabelConfig] = None,
+    max_workers: Optional[int] = None,
+) -> None:
+    """Preprocess annotations and save to disk."""
+    output_dir = Path(output_dir)
+
+    if not output_dir.is_dir():
+        output_dir.mkdir(parents=True)
+
+    with Pool(max_workers) as pool:
+        list(
+            tqdm(
+                pool.imap_unordered(
+                    partial(
+                        preprocess_single_annotation,
+                        output_dir=output_dir,
+                        filename_fn=filename_fn,
+                        replace=replace,
+                        preprocessing_config=preprocessing_config,
+                        target_config=target_config,
+                        label_config=label_config,
+                    ),
+                    clip_annotations,
+                ),
+                total=len(clip_annotations),
+            )
+        )
+
+
 def preprocess_single_annotation(
     clip_annotation: data.ClipAnnotation,
     output_dir: PathLike,
-    config: PreprocessingConfig,
-    class_mapper: ClassMapper,
+    preprocessing_config: Optional[PreprocessingConfig] = None,
+    target_config: Optional[TargetConfig] = None,
+    label_config: Optional[LabelConfig] = None,
     filename_fn: FilenameFn = _get_filename,
     replace: bool = False,
-    target_sigma: float = TARGET_SIGMA,
 ) -> None:
     output_dir = Path(output_dir)

@@ -119,50 +193,16 @@ def preprocess_single_annotation(
     if path.is_file() and replace:
         path.unlink()

-    sample = generate_train_example(
-        clip_annotation,
-        class_mapper,
-        preprocessing_config=config,
-        target_sigma=target_sigma,
-    )
+    try:
+        sample = generate_train_example(
+            clip_annotation,
+            preprocessing_config=preprocessing_config,
+            target_config=target_config,
+            label_config=label_config,
+        )
+    except Exception as error:
+        raise RuntimeError(
+            f"Failed to process annotation: {clip_annotation.uuid}"
+        ) from error

     save_to_file(sample, path)
-
-
-def preprocess_annotations(
-    clip_annotations: Sequence[data.ClipAnnotation],
-    output_dir: PathLike,
-    class_mapper: ClassMapper,
-    target_sigma: float = TARGET_SIGMA,
-    filename_fn: FilenameFn = _get_filename,
-    replace: bool = False,
-    config: Optional[PreprocessingConfig] = None,
-    max_workers: Optional[int] = None,
-) -> None:
-    """Preprocess annotations and save to disk."""
-    output_dir = Path(output_dir)
-
-    if config is None:
-        config = PreprocessingConfig()
-
-    if not output_dir.is_dir():
-        output_dir.mkdir(parents=True)
-
-    with Pool(max_workers) as pool:
-        list(
-            tqdm(
-                pool.imap_unordered(
-                    partial(
-                        preprocess_single_annotation,
-                        output_dir=output_dir,
-                        config=config,
-                        class_mapper=class_mapper,
-                        filename_fn=filename_fn,
-                        replace=replace,
-                        target_sigma=target_sigma,
-                    ),
-                    clip_annotations,
-                ),
-                total=len(clip_annotations),
-            )
-        )

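A hedged sketch of the new preprocessing entry point; the annotation list and output directory are assumptions:

    from batdetect2.train.preprocess import preprocess_annotations

    clip_annotations = [...]  # assumed: a sequence of soundevent ClipAnnotation objects

    preprocess_annotations(
        clip_annotations,
        output_dir="train_examples/",
        max_workers=4,
    )
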
batdetect2/train/targets.py (new file, 181 lines)
@@ -0,0 +1,181 @@
from collections.abc import Iterable
from functools import partial
from pathlib import Path
from typing import Callable, List, Optional, Set

from pydantic import Field
from soundevent import data

from batdetect2.configs import BaseConfig, load_config
from batdetect2.terms import TagInfo, get_tag_from_info

__all__ = [
    "TargetConfig",
    "load_target_config",
    "build_encoder",
    "build_decoder",
    "filter_sound_event",
]


class ReplaceConfig(BaseConfig):
    """Configuration for replacing tags."""

    original: TagInfo
    replacement: TagInfo


class TargetConfig(BaseConfig):
    """Configuration for target generation."""

    classes: List[TagInfo] = Field(
        default_factory=lambda: [
            TagInfo(key="class", value=value) for value in DEFAULT_SPECIES_LIST
        ]
    )
    generic_class: Optional[TagInfo] = Field(
        default_factory=lambda: TagInfo(key="class", value="Bat")
    )

    include: Optional[List[TagInfo]] = Field(
        default_factory=lambda: [TagInfo(key="event", value="Echolocation")]
    )

    exclude: Optional[List[TagInfo]] = Field(
        default_factory=lambda: [
            TagInfo(key="class", value=""),
            TagInfo(key="class", value=" "),
            TagInfo(key="class", value="Unknown"),
        ]
    )

    replace: Optional[List[ReplaceConfig]] = None


def build_sound_event_filter(
    include: Optional[List[TagInfo]] = None,
    exclude: Optional[List[TagInfo]] = None,
) -> Callable[[data.SoundEventAnnotation], bool]:
    include_tags = (
        {get_tag_from_info(tag) for tag in include} if include else None
    )
    exclude_tags = (
        {get_tag_from_info(tag) for tag in exclude} if exclude else None
    )
    return partial(
        filter_sound_event,
        include=include_tags,
        exclude=exclude_tags,
    )


def get_tag_label(tag_info: TagInfo) -> str:
    return tag_info.label if tag_info.label else tag_info.value


def get_class_names(classes: List[TagInfo]) -> List[str]:
    return sorted({get_tag_label(tag) for tag in classes})


def build_replacer(
    rules: List[ReplaceConfig],
) -> Callable[[data.Tag], data.Tag]:
    mapping = {
        get_tag_from_info(rule.original): get_tag_from_info(rule.replacement)
        for rule in rules
    }

    def replacer(tag: data.Tag) -> data.Tag:
        return mapping.get(tag, tag)

    return replacer


def build_encoder(
    classes: List[TagInfo],
    replacement_rules: Optional[List[ReplaceConfig]] = None,
) -> Callable[[Iterable[data.Tag]], Optional[str]]:
    target_tags = set([get_tag_from_info(tag) for tag in classes])

    tag_mapping = {
        tag: get_tag_label(tag_info)
        for tag, tag_info in zip(target_tags, classes)
    }

    replacer = (
        build_replacer(replacement_rules) if replacement_rules else lambda x: x
    )

    def encoder(
        tags: Iterable[data.Tag],
    ) -> Optional[str]:
        sanitized_tags = {replacer(tag) for tag in tags}

        intersection = sanitized_tags & target_tags

        if not intersection:
            return None

        first = intersection.pop()
        return tag_mapping[first]

    return encoder


def build_decoder(
    classes: List[TagInfo],
) -> Callable[[str], List[data.Tag]]:
    target_tags = set([get_tag_from_info(tag) for tag in classes])
    tag_mapping = {
        get_tag_label(tag_info): tag
        for tag, tag_info in zip(target_tags, classes)
    }

    def decoder(label: str) -> List[data.Tag]:
        tag = tag_mapping.get(label)
        return [tag] if tag else []

    return decoder


def filter_sound_event(
    sound_event_annotation: data.SoundEventAnnotation,
    include: Optional[Set[data.Tag]] = None,
    exclude: Optional[Set[data.Tag]] = None,
) -> bool:
    tags = set(sound_event_annotation.tags)

    if include is not None and not tags & include:
        return False

    if exclude is not None and tags & exclude:
        return False

    return True


def load_target_config(
    path: Path, field: Optional[str] = None
) -> TargetConfig:
    return load_config(path, schema=TargetConfig, field=field)


DEFAULT_SPECIES_LIST = [
    "Barbastellus barbastellus",
    "Eptesicus serotinus",
    "Myotis alcathoe",
    "Myotis bechsteinii",
    "Myotis brandtii",
    "Myotis daubentonii",
    "Myotis mystacinus",
    "Myotis nattereri",
    "Nyctalus leisleri",
    "Nyctalus noctula",
    "Pipistrellus nathusii",
    "Pipistrellus pipistrellus",
    "Pipistrellus pygmaeus",
    "Plecotus auritus",
    "Plecotus austriacus",
    "Rhinolophus ferrumequinum",
    "Rhinolophus hipposideros",
]

@@ -1,82 +1,68 @@
-from typing import Callable, NamedTuple, Optional
+from typing import Optional, Union

-import torch
-from soundevent import data
-from torch.optim import Adam
-from torch.optim.lr_scheduler import CosineAnnealingLR
+from lightning import LightningModule
+from lightning.pytorch import Trainer
+from soundevent.data import PathLike
 from torch.utils.data import DataLoader

-from batdetect2.data.datasets import ClipAnnotationDataset
-from batdetect2.models.typing import DetectionModel
+from batdetect2.configs import BaseConfig, load_config
+from batdetect2.train.dataset import LabeledDataset

+__all__ = [
+    "train",
+    "TrainerConfig",
+    "load_trainer_config",
+]

-class TrainInputs(NamedTuple):
-    spec: torch.Tensor
-    detection_heatmap: torch.Tensor
-    class_heatmap: torch.Tensor
-    size_heatmap: torch.Tensor

+class TrainerConfig(BaseConfig):
+    accelerator: str = "auto"
+    accumulate_grad_batches: int = 1
+    deterministic: bool = True
+    check_val_every_n_epoch: int = 1
+    devices: Union[str, int] = "auto"
+    enable_checkpointing: bool = True
+    gradient_clip_val: Optional[float] = None
+    limit_train_batches: Optional[Union[int, float]] = None
+    limit_test_batches: Optional[Union[int, float]] = None
+    limit_val_batches: Optional[Union[int, float]] = None
+    log_every_n_steps: Optional[int] = None
+    max_epochs: Optional[int] = None
+    min_epochs: Optional[int] = 100
+    max_steps: Optional[int] = None
+    min_steps: Optional[int] = None
+    max_time: Optional[str] = None
+    precision: Optional[str] = None
+    reload_dataloaders_every_n_epochs: Optional[int] = None
+    val_check_interval: Optional[Union[int, float]] = None

-def train_loop(
-    model: DetectionModel,
-    train_dataset: ClipAnnotationDataset[TrainInputs],
-    validation_dataset: ClipAnnotationDataset[TrainInputs],
-    device: Optional[torch.device] = None,
-    num_epochs: int = 100,
-    learning_rate: float = 1e-4,
-):
-    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
-    validation_loader = DataLoader(validation_dataset, batch_size=32)
-
-    model.to(device)
-
-    optimizer = Adam(model.parameters(), lr=learning_rate)
-    scheduler = CosineAnnealingLR(
-        optimizer,
-        num_epochs * len(train_loader),
-    )
-
-    for epoch in range(num_epochs):
-        train_loss = train_single_epoch(
-            model,
-            train_loader,
-            optimizer,
-            device,
-            scheduler,
-        )
-
-
-def train_single_epoch(
-    model: DetectionModel,
-    train_loader: DataLoader,
-    optimizer: Adam,
-    device: torch.device,
-    scheduler: CosineAnnealingLR,
-):
-    model.train()
-    train_loss = tu.AverageMeter()
-
-    for batch in train_loader:
-        optimizer.zero_grad()
-
-        spec = batch.spec.to(device)
-        detection_heatmap = batch.detection_heatmap.to(device)
-        class_heatmap = batch.class_heatmap.to(device)
-        size_heatmap = batch.size_heatmap.to(device)
-
-        outputs = model(spec)
-
-        loss = loss_fun(
-            outputs,
-            gt_det,
-            gt_size,
-            gt_class,
-            det_criterion,
-            params,
-            class_inv_freq,
-        )
-
-        train_loss.update(loss.item(), data.shape[0])
-        loss.backward()
-        optimizer.step()
-        scheduler.step()
+def load_trainer_config(path: PathLike, field: Optional[str] = None):
+    return load_config(path, schema=TrainerConfig, field=field)
+
+
+def train(
+    module: LightningModule,
+    train_dataset: LabeledDataset,
+    trainer_config: Optional[TrainerConfig] = None,
+    dev_run: bool = False,
+    overfit_batches: bool = False,
+    profiler: Optional[str] = None,
+):
+    trainer_config = trainer_config or TrainerConfig()
+    trainer = Trainer(
+        **trainer_config.model_dump(
+            exclude_unset=True,
+            exclude_none=True,
+        ),
+        fast_dev_run=dev_run,
+        overfit_batches=overfit_batches,
+        profiler=profiler,
+    )
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=module.config.train.batch_size,
+        shuffle=True,
+        num_workers=7,
+    )
+    trainer.fit(module, train_dataloaders=train_loader)

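A hedged sketch of kicking off training with the Lightning-based entry point; the import path of the training module and the module object itself are assumptions (train reads the batch size from module.config.train.batch_size):

    from batdetect2.train.dataset import LabeledDataset
    from batdetect2.train.train import TrainerConfig, train  # assumed module path

    module = ...  # assumed: a LightningModule wrapping the BatDetect2 detector

    dataset = LabeledDataset.from_directory("train_examples/")
    trainer_config = TrainerConfig(max_epochs=200)
    train(module, dataset, trainer_config=trainer_config)
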
@@ -1,13 +1,12 @@
 """Types used in the code base."""

 from typing import Any, List, NamedTuple, Optional

 import numpy as np
 import torch

-try:
-    from typing import TypedDict
-except ImportError:
-    from typing_extensions import TypedDict
+from typing import TypedDict


 try:
@@ -15,9 +14,8 @@ try:
 except ImportError:
     from typing_extensions import Protocol

 try:
-    from typing import NotRequired
+    from typing import NotRequired  # type: ignore
 except ImportError:
     from typing_extensions import NotRequired

@@ -597,8 +595,7 @@ class FeatureExtractor(Protocol):
         self,
         prediction: Prediction,
         **kwargs: Any,
-    ) -> float:
-        ...
+    ) -> float: ...


 class DatasetDict(TypedDict):

@ -6,6 +6,8 @@ import librosa.core.spectrum
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from batdetect2.detector import parameters
|
||||||
|
|
||||||
from . import wavfile
|
from . import wavfile
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -15,20 +17,44 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
|
def time_to_x_coords(
|
||||||
nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor
|
time_in_file: float,
|
||||||
noverlap = np.floor(fft_overlap * nfft)
|
samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
|
||||||
return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)
|
window_duration: float = parameters.FFT_WIN_LENGTH_S,
|
||||||
|
window_overlap: float = parameters.FFT_OVERLAP,
|
||||||
|
) -> float:
|
||||||
|
nfft = np.floor(window_duration * samplerate) # int() uses floor
|
||||||
|
noverlap = np.floor(window_overlap * nfft)
|
||||||
|
return (time_in_file * samplerate - noverlap) / (nfft - noverlap)
|
||||||
|
|
||||||
|
|
||||||
# NOTE this is also defined in post_process
|
def x_coords_to_time(
|
||||||
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
|
x_pos: int,
|
||||||
nfft = np.floor(fft_win_length * sampling_rate)
|
samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
|
||||||
noverlap = np.floor(fft_overlap * nfft)
|
window_duration: float = parameters.FFT_WIN_LENGTH_S,
|
||||||
return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
|
window_overlap: float = parameters.FFT_OVERLAP,
|
||||||
|
) -> float:
|
||||||
|
n_fft = np.floor(window_duration * samplerate)
|
||||||
|
n_overlap = np.floor(window_overlap * n_fft)
|
||||||
|
n_step = n_fft - n_overlap
|
||||||
|
return ((x_pos * n_step) + n_overlap) / samplerate
|
||||||
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
|
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
|
||||||
|
|
||||||
|
|
||||||
|
def x_coord_to_sample(
|
||||||
|
x_pos: int,
|
||||||
|
samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
|
||||||
|
window_duration: float = parameters.FFT_WIN_LENGTH_S,
|
||||||
|
window_overlap: float = parameters.FFT_OVERLAP,
|
||||||
|
resize_factor: float = parameters.RESIZE_FACTOR,
|
||||||
|
) -> int:
|
||||||
|
n_fft = np.floor(window_duration * samplerate)
|
||||||
|
n_overlap = np.floor(window_overlap * n_fft)
|
||||||
|
n_step = n_fft - n_overlap
|
||||||
|
x_pos = int(x_pos / resize_factor)
|
||||||
|
return int((x_pos * n_step) + n_overlap)
|
||||||
|
|
||||||
|
|
||||||
def generate_spectrogram(
|
def generate_spectrogram(
|
||||||
audio,
|
audio,
|
||||||
sampling_rate,
|
sampling_rate,
|
||||||
@ -64,7 +90,7 @@ def generate_spectrogram(
|
|||||||
np.abs(
|
np.abs(
|
||||||
np.hanning(
|
np.hanning(
|
||||||
int(params["fft_win_length"] * sampling_rate)
|
int(params["fft_win_length"] * sampling_rate)
|
||||||
)
|
).astype(np.float32)
|
||||||
)
|
)
|
||||||
** 2
|
** 2
|
||||||
).sum()
|
).sum()
|
||||||
@ -74,7 +100,7 @@ def generate_spectrogram(
|
|||||||
# log_scaling = (1.0 / sampling_rate)*10e4
|
# log_scaling = (1.0 / sampling_rate)*10e4
|
||||||
spec = np.log1p(log_scaling * spec)
|
spec = np.log1p(log_scaling * spec)
|
||||||
elif params["spec_scale"] == "pcen":
|
elif params["spec_scale"] == "pcen":
|
||||||
spec = pcen(spec , sampling_rate)
|
spec = pcen(spec, sampling_rate)
|
||||||
|
|
||||||
elif params["spec_scale"] == "none":
|
elif params["spec_scale"] == "none":
|
||||||
pass
|
pass
|
||||||
@ -194,55 +220,118 @@ def load_audio(
|
|||||||
return sampling_rate, audio_raw
|
return sampling_rate, audio_raw
|
||||||
|
|
||||||
|
|
||||||
|
def compute_spectrogram_width(
|
||||||
|
length: int,
|
||||||
|
samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
|
||||||
|
window_duration: float = parameters.FFT_WIN_LENGTH_S,
|
||||||
|
window_overlap: float = parameters.FFT_OVERLAP,
|
||||||
|
resize_factor: float = parameters.RESIZE_FACTOR,
|
||||||
|
) -> int:
|
||||||
|
n_fft = int(window_duration * samplerate)
|
||||||
|
n_overlap = int(window_overlap * n_fft)
|
||||||
|
n_step = n_fft - n_overlap
|
||||||
|
width = (length - n_overlap) // n_step
|
||||||
|
return int(width * resize_factor)
|
||||||
|
|
||||||
|
|
||||||
def pad_audio(
|
def pad_audio(
|
||||||
audio_raw,
|
audio: np.ndarray,
|
||||||
fs,
|
samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
|
||||||
ms,
|
window_duration: float = parameters.FFT_WIN_LENGTH_S,
|
||||||
overlap_perc,
|
window_overlap: float = parameters.FFT_OVERLAP,
|
||||||
resize_factor,
|
resize_factor: float = parameters.RESIZE_FACTOR,
|
||||||
divide_factor,
|
divide_factor: int = parameters.SPEC_DIVIDE_FACTOR,
|
||||||
fixed_width=None,
|
fixed_width: Optional[int] = None,
|
||||||
):
|
):
|
||||||
# Adds zeros to the end of the raw data so that the generated sepctrogram
|
"""Pad audio to be evenly divisible by `divide_factor`.
|
||||||
# will be evenly divisible by `divide_factor`
|
|
||||||
# Also deals with very short audio clips and fixed_width during training
|
|
||||||
|
|
||||||
# This code could be clearer, clean up
|
This function pads the audio signal with zeros to ensure that the
|
||||||
nfft = int(ms * fs)
|
generated spectrogram length will be evenly divisible by `divide_factor`.
|
||||||
noverlap = int(overlap_perc * nfft)
|
This is important for the model to work correctly.
|
||||||
step = nfft - noverlap
|
|
||||||
min_size = int(divide_factor * (1.0 / resize_factor))
|
|
||||||
spec_width = (audio_raw.shape[0] - noverlap) // step
|
|
||||||
spec_width_rs = spec_width * resize_factor
|
|
||||||
|
|
||||||
if fixed_width is not None and spec_width < fixed_width:
|
This `divide_factor` comes from the model architecture as it downscales
|
||||||
# too small
|
the spectrogram by this factor, so the input must be divisible by this
|
||||||
# used during training to ensure all the batches are the same size
|
integer number.
|
||||||
-        diff = fixed_width * step + noverlap - audio_raw.shape[0]
-        audio_raw = np.hstack(
-            (audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
+    Parameters
+    ----------
+    audio : np.ndarray
+        The audio signal.
+    samplerate : int
+        The sampling rate of the audio signal.
+    window_size : float
+        The window size in seconds used for the spectrogram computation.
+    window_overlap : float
+        The overlap between windows in the spectrogram computation.
+    resize_factor : float
+        This factor is used to resize the spectrogram after the STFT
+        computation. Default is 0.5 which means that the spectrogram will be
+        reduced by half. Important to take into account for the final size of
+        the spectrogram.
+    divide_factor : int
+        The factor by which the spectrogram will be divided.
+    fixed_width : int, optional
+        If provided, the audio will be padded or cut so that the resulting
+        spectrogram width will be equal to this value.
+
+    Returns
+    -------
+    np.ndarray
+        The padded audio signal.
+    """
+    spec_width = compute_spectrogram_width(
+        audio.shape[0],
+        samplerate=samplerate,
+        window_duration=window_duration,
+        window_overlap=window_overlap,
+        resize_factor=resize_factor,
+    )
+
+    if fixed_width:
+        target_samples = x_coord_to_sample(
+            fixed_width,
+            samplerate=samplerate,
+            window_duration=window_duration,
+            window_overlap=window_overlap,
+            resize_factor=resize_factor,
         )
-
-    elif fixed_width is not None and spec_width > fixed_width:
-        # too big
-        # used during training to ensure all the batches are the same size
-        diff = fixed_width * step + noverlap - audio_raw.shape[0]
-        audio_raw = audio_raw[:diff]
-
-    elif (
-        spec_width_rs < min_size
-        or (np.floor(spec_width_rs) % divide_factor) != 0
-    ):
-        # need to be at least min_size
-        div_amt = np.ceil(spec_width_rs / float(divide_factor))
-        div_amt = np.maximum(1, div_amt)
-        target_size = int(div_amt * divide_factor * (1.0 / resize_factor))
-        diff = target_size * step + noverlap - audio_raw.shape[0]
-        audio_raw = np.hstack(
-            (audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
+
+        if spec_width < fixed_width:
+            # need to be at least min_size
+            diff = target_samples - audio.shape[0]
+            return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
+
+        if spec_width > fixed_width:
+            return audio[:target_samples]
+
+        return audio
+
+    min_width = int(divide_factor / resize_factor)
+
+    if spec_width < min_width:
+        target_samples = x_coord_to_sample(
+            min_width,
+            samplerate=samplerate,
+            window_duration=window_duration,
+            window_overlap=window_overlap,
+            resize_factor=resize_factor,
         )
-
-    return audio_raw
+        diff = target_samples - audio.shape[0]
+        return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
+
+    if (spec_width % divide_factor) == 0:
+        return audio
+
+    target_width = int(np.ceil(spec_width / divide_factor)) * divide_factor
+    target_samples = x_coord_to_sample(
+        target_width,
+        samplerate=samplerate,
+        window_duration=window_duration,
+        window_overlap=window_overlap,
+        resize_factor=resize_factor,
+    )
+    diff = target_samples - audio.shape[0]
+    return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
 
 
 def gen_mag_spectrogram(x, fs, ms, overlap_perc):
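The rewritten padding above delegates the width arithmetic to compute_spectrogram_width and x_coord_to_sample, whose definitions fall outside this hunk. A minimal sketch of how such helpers can be written, assuming the usual STFT frame-count arithmetic; this is an illustration, not the repository's exact implementation:

import numpy as np


def compute_spectrogram_width(
    n_samples: int,
    samplerate: int,
    window_duration: float,
    window_overlap: float,
    resize_factor: float,
) -> int:
    # Sketch: count the STFT frames that fit in the clip, then apply the
    # optional resize of the time axis. Assumes hop = nfft - noverlap.
    nfft = int(window_duration * samplerate)
    noverlap = int(window_overlap * nfft)
    step = nfft - noverlap
    n_frames = (n_samples - noverlap) // step
    return int(n_frames * resize_factor)


def x_coord_to_sample(
    x: int,
    samplerate: int,
    window_duration: float,
    window_overlap: float,
    resize_factor: float,
) -> int:
    # Sketch of the inverse mapping: how many audio samples are needed for the
    # spectrogram to reach column `x` after resizing.
    nfft = int(window_duration * samplerate)
    noverlap = int(window_overlap * nfft)
    step = nfft - noverlap
    return int(x / resize_factor) * step + noverlap

With helpers of this shape, padding the audio to target_samples and recomputing the width lands exactly on the requested number of columns, which is what the divide_factor branch of the new code relies on.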
@@ -8,6 +8,11 @@ import pandas as pd
 import torch
 import torch.nn.functional as F
 
+try:
+    from numpy.exceptions import AxisError
+except ImportError:
+    from numpy import AxisError  # type: ignore
+
 import batdetect2.detector.compute_features as feats
 import batdetect2.detector.post_process as pp
 import batdetect2.utils.audio_utils as au
@@ -80,6 +85,7 @@ def load_model(
     model_path: str = DEFAULT_MODEL_PATH,
     load_weights: bool = True,
     device: Union[torch.device, str, None] = None,
+    weights_only: bool = True,
 ) -> Tuple[DetectionModel, ModelParameters]:
     """Load model from file.
 
@@ -100,7 +106,11 @@ def load_model(
     if not os.path.isfile(model_path):
         raise FileNotFoundError("Model file not found.")
 
-    net_params = torch.load(model_path, map_location=device)
+    net_params = torch.load(
+        model_path,
+        map_location=device,
+        weights_only=weights_only,
+    )
 
     params = net_params["params"]
 
@@ -242,7 +252,7 @@ def format_single_result(
         )
         class_name = class_names[np.argmax(class_overall)]
         annotations = get_annotations_from_preds(predictions, class_names)
-    except (np.AxisError, ValueError):
+    except (AxisError, ValueError):
         # No detections
         class_overall = np.zeros(len(class_names))
         class_name = "None"
@@ -399,7 +409,7 @@ def save_results_to_file(results, op_path: str) -> None:
 
 def compute_spectrogram(
     audio: np.ndarray,
-    sampling_rate: float,
+    sampling_rate: int,
     params: SpectrogramParameters,
     device: torch.device,
 ) -> Tuple[float, torch.Tensor]:
@@ -617,7 +627,7 @@ def process_spectrogram(
 
 def _process_audio_array(
     audio: np.ndarray,
-    sampling_rate: float,
+    sampling_rate: int,
     model: DetectionModel,
     config: ProcessingConfiguration,
     device: torch.device,
@@ -738,9 +748,7 @@ def process_file(
 
     # Get original sampling rate
    file_samp_rate = librosa.get_samplerate(audio_file)
-    orig_samp_rate = file_samp_rate * float(
-        config.get("time_expansion", 1.0) or 1.0
-    )
+    orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
 
     # load audio file
     sampling_rate, audio_full = au.load_audio(
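The new weights_only argument is forwarded to torch.load, which then uses a restricted unpickler that only rebuilds tensors and plain containers. A short usage sketch of that flag; the checkpoint path here is a placeholder, not a file shipped with the project:

import torch

# Safer default: refuse to unpickle arbitrary Python objects.
checkpoint = torch.load(
    "path/to/checkpoint.pth.tar",  # placeholder path
    map_location=torch.device("cpu"),
    weights_only=True,
)

# Older checkpoints that pickle custom classes may need the permissive mode;
# only fall back to this for files from a trusted source.
checkpoint = torch.load(
    "path/to/checkpoint.pth.tar",  # placeholder path
    map_location=torch.device("cpu"),
    weights_only=False,
)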
@@ -417,7 +417,9 @@ def plot_confusion_matrix(
     cm_norm = cm.sum(1)
 
     valid_inds = np.where(cm_norm > 0)[0]
-    cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
+    cm[valid_inds, :] = (
+        cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
+    )
     cm[np.where(cm_norm == -0)[0], :] = np.nan
 
     if verbose:
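The reformatted assignment above is the usual row-wise normalisation of a confusion matrix, skipping rows with no ground-truth samples. A small self-contained numpy example of the same pattern, with toy values only:

import numpy as np

# Rows are true classes, columns are predicted classes (toy data).
cm = np.array(
    [
        [8.0, 2.0, 0.0],
        [1.0, 9.0, 0.0],
        [0.0, 0.0, 0.0],  # a class that never occurs in the ground truth
    ]
)

cm_norm = cm.sum(1)

# Normalise only the rows that contain samples; broadcasting the column
# vector cm_norm[valid_inds][..., np.newaxis] divides each row by its sum.
valid_inds = np.where(cm_norm > 0)[0]
cm[valid_inds, :] = (
    cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
)

# Empty rows become NaN so they can be masked out when plotting.
cm[np.where(cm_norm == 0)[0], :] = np.nan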
|
@ -155,9 +155,9 @@ class InteractivePlotter:
|
|||||||
|
|
||||||
# draw bounding box around call
|
# draw bounding box around call
|
||||||
self.ax[1].patches[0].remove()
|
self.ax[1].patches[0].remove()
|
||||||
spec_width_orig = self.spec_slices[self.current_id].shape[1] / (
|
spec_width_orig = self.spec_slices[self.current_id].shape[
|
||||||
1.0 + 2.0 * self.spec_pad
|
1
|
||||||
)
|
] / (1.0 + 2.0 * self.spec_pad)
|
||||||
xx = w_diff + self.spec_pad * spec_width_orig
|
xx = w_diff + self.spec_pad * spec_width_orig
|
||||||
ww = spec_width_orig
|
ww = spec_width_orig
|
||||||
yy = self.call_info[self.current_id]["low_freq"] / 1000
|
yy = self.call_info[self.current_id]["low_freq"] / 1000
|
||||||
@ -183,7 +183,9 @@ class InteractivePlotter:
|
|||||||
round(self.call_info[self.current_id]["start_time"], 3)
|
round(self.call_info[self.current_id]["start_time"], 3)
|
||||||
)
|
)
|
||||||
+ ", prob="
|
+ ", prob="
|
||||||
+ str(round(self.call_info[self.current_id]["det_prob"], 3))
|
+ str(
|
||||||
|
round(self.call_info[self.current_id]["det_prob"], 3)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.ax[0].set_xlabel(info_str)
|
self.ax[0].set_xlabel(info_str)
|
||||||
|
|
||||||
|
@@ -8,6 +8,7 @@ Functions
 `write`: Write a numpy array as a WAV file.
 
 """
+
 
 from __future__ import absolute_import, division, print_function
 
 import os
@@ -156,7 +157,6 @@ def read(filename, mmap=False):
     fid = open(filename, "rb")
 
     try:
-
         # some files seem to have the size recorded in the header greater than
         # the actual file size.
         fid.seek(0, os.SEEK_END)
743	notebooks/Augmentations.ipynb	Normal file
File diff suppressed because one or more lines are too long
1166	notebooks/Migrations.ipynb	Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -6,11 +6,11 @@
 (cell "cfb0b360-a204-4c27-a18f-3902e8758879": execution metadata only; timestamps updated from 2024-07-16T00:27:20 to 2024-11-19T17:33:02)
@@ -25,11 +25,11 @@
 (cell "326c5432-94e6-4abf-a332-fe902559461b": execution metadata only; timestamps updated from 2024-07-16T00:27:20 to 2024-11-19T17:33:02)
@@ -37,7 +37,7 @@
    "name": "stderr",
    "output_type": "stream",
    "text": [
-    "/home/santiago/Software/bat_detectors/batdetect2/.venv/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+    "/home/santiago/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
     "  from .autonotebook import tqdm as notebook_tqdm\n"
    ]
   }
@@ -45,26 +45,35 @@
   "source": [
    "from pathlib import Path\n",
    "from typing import List, Optional\n",
+   "import torch\n",
    "\n",
    "import pytorch_lightning as pl\n",
-   "from soundevent import data\n",
-   "from torch.utils.data import DataLoader\n",
-   "\n",
-   "from batdetect2.data.labels import ClassMapper\n",
-   "from batdetect2.models.detectors import DetectorModel\n",
+   "from batdetect2.train.modules import DetectorModel\n",
    "from batdetect2.train.augmentations import (\n",
    "    add_echo,\n",
    "    select_random_subclip,\n",
    "    warp_spectrogram,\n",
    ")\n",
    "from batdetect2.train.dataset import LabeledDataset, get_files\n",
-   "from batdetect2.train.preprocess import PreprocessingConfig"
+   "from batdetect2.train.preprocess import PreprocessingConfig\n",
+   "from soundevent import data\n",
+   "import matplotlib.pyplot as plt\n",
+   "from soundevent.types import ClassMapper\n",
+   "from torch.utils.data import DataLoader"
   ]
  },
 {
  "cell_type": "markdown",
- "id": "fa202af2-5c0d-4b5d-91a3-097ef5cd4272",
- "metadata": {},
+ "id": "9402a473-0b25-4123-9fa8-ad1f71a4237a",
+ "metadata": {
+  "execution": {
+   "iopub.execute_input": "2024-11-18T22:39:12.395329Z",
+   "iopub.status.busy": "2024-11-18T22:39:12.393444Z",
+   "iopub.status.idle": "2024-11-18T22:39:12.405938Z",
+   "shell.execute_reply": "2024-11-18T22:39:12.402980Z",
+   "shell.execute_reply.started": "2024-11-18T22:39:12.395236Z"
+  }
+ },
  "source": [
   "## Training Datasets"
  ]
@@ -75,11 +84,11 @@
 (cell "cfd97d83-8c2b-46c8-9eae-cea59f53bc61": execution metadata only; timestamps updated from 2024-07-16T00:27:25 to 2024-11-19T17:33:09)
@@ -93,11 +102,11 @@
 (cell "d5131ae9-2efd-4758-b6e5-189a6d90789b": execution metadata only; timestamps updated from 2024-07-16T00:27:25 to 2024-11-19T17:33:09)
@@ -111,11 +120,11 @@
 (cell "bc733d3d-7829-4e90-896d-a0dc76b33288": execution metadata only; timestamps updated from 2024-07-16T00:27:25 to 2024-11-19T17:33:09)
@@ -129,11 +138,11 @@
 (cell "dfbb94ab-7b12-4689-9c15-4dc34cd17cb2": execution metadata only; timestamps updated from 2024-07-16T00:27:26 to 2024-11-19T17:33:09)
@@ -152,11 +161,11 @@
 (cell "e2eedaa9-6be3-481a-8786-7618515d98f8": execution metadata only; timestamps updated from 2024-07-16T00:27:26 to 2024-11-19T17:33:09)
@@ -168,7 +177,6 @@
    "        \"Myotis mystacinus\",\n",
    "        \"Pipistrellus pipistrellus\",\n",
    "        \"Rhinolophus ferrumequinum\",\n",
-   "        \"social\",\n",
    "    ]\n",
    "\n",
    "    def encode(self, x: data.SoundEventAnnotation) -> Optional[str]:\n",
@@ -197,11 +205,11 @@
 (cell "1ff6072c-511e-42fe-a74f-282f269b80f0": execution metadata only; timestamps updated from 2024-07-16T00:27:26 to 2024-11-19T17:33:09)
@@ -215,11 +223,11 @@
 (cell "3a763ee6-15bc-4105-a409-f06e0ad21a06": execution metadata only; timestamps updated from 2024-07-16T00:27:26 to 2024-11-19T17:33:09)
@@ -229,7 +237,6 @@
    "text": [
     "GPU available: False, used: False\n",
     "TPU available: False, using: 0 TPU cores\n",
-    "IPU available: False, using: 0 IPUs\n",
     "HPU available: False, using: 0 HPUs\n"
    ]
   }
@@ -248,11 +255,11 @@
 (cell "0b86d49d-3314-4257-94f5-f964855be385": execution metadata only; timestamps updated from 2024-07-16T00:27:26 to 2024-11-19T17:33:09)
@@ -261,37 +268,67 @@
   "output_type": "stream",
   "text": [
    "\n",
-   "  | Name              | Type      | Params\n",
-   "------------------------------------------------\n",
-   "0 | feature_extractor | Net2DFast | 119 K \n",
-   "1 | classifier        | Conv2d    | 54    \n",
-   "2 | bbox              | Conv2d    | 18    \n",
-   "------------------------------------------------\n",
+   "  | Name              | Type      | Params | Mode \n",
+   "--------------------------------------------------------\n",
+   "0 | feature_extractor | Net2DFast | 119 K  | train\n",
+   "1 | classifier        | Conv2d    | 54     | train\n",
+   "2 | bbox              | Conv2d    | 18     | train\n",
+   "--------------------------------------------------------\n",
    "119 K     Trainable params\n",
    "448       Non-trainable params\n",
    "119 K     Total params\n",
-   "0.480     Total estimated model params size (MB)\n"
+   "0.480     Total estimated model params size (MB)\n",
+   "32        Modules in train mode\n",
+   "0         Modules in eval mode\n"
   ]
  },
 {
   "name": "stdout",
   "output_type": "stream",
   "text": [
-   "Epoch 1: 100%|██████████| 1/1 [00:00<00:00, 1.59it/s, v_num=13]"
+   "Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s]class heatmap shape torch.Size([3, 4, 128, 512])\n",
+   "class props shape torch.Size([3, 5, 128, 512])\n"
  ]
 },
 {
-  "name": "stderr",
-  "output_type": "stream",
-  "text": [
-   "`Trainer.fit` stopped: `max_epochs=2` reached.\n"
-  ]
- },
-{
-  "name": "stdout",
-  "output_type": "stream",
-  "text": [
-   "Epoch 1: 100%|██████████| 1/1 [00:00<00:00, 1.54it/s, v_num=13]\n"
+  "ename": "RuntimeError",
+  "evalue": "The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1",
+  "output_type": "error",
+  "traceback": [
+   (ANSI-escaped traceback suppressed because the lines are too long; key frames:
+    Cell In[10], line 1: trainer.fit(detector, train_dataloaders=train_dataloader)
+    -> pytorch_lightning trainer / fit-loop / optimizer-step internals
+    -> batdetect2/train/modules.py:167 in DetectorModel.training_step: loss = self.compute_loss(outputs, batch)
+    -> batdetect2/train/modules.py:150 in DetectorModel.compute_loss: classification_loss = losses.focal_loss(
|
||||||
|
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/losses.py:38\u001b[0m, in \u001b[0;36mfocal_loss\u001b[0;34m(pred, gt, weights, valid_mask, eps, beta, alpha)\u001b[0m\n\u001b[1;32m 35\u001b[0m pos_inds \u001b[38;5;241m=\u001b[39m gt\u001b[38;5;241m.\u001b[39meq(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[1;32m 36\u001b[0m neg_inds \u001b[38;5;241m=\u001b[39m gt\u001b[38;5;241m.\u001b[39mlt(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[0;32m---> 38\u001b[0m pos_loss \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpred\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43meps\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpow\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpred\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpos_inds\u001b[49m\n\u001b[1;32m 39\u001b[0m neg_loss \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 40\u001b[0m torch\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m pred \u001b[38;5;241m+\u001b[39m eps)\n\u001b[1;32m 41\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpow(pred, alpha)\n\u001b[1;32m 42\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpow(\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m gt, beta)\n\u001b[1;32m 43\u001b[0m \u001b[38;5;241m*\u001b[39m neg_inds\n\u001b[1;32m 44\u001b[0m )\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m weights \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
||||||
|
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
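Editorial note: the RuntimeError captured above is a plain shape mismatch, the model's class probability map and the target class heatmap disagree along the class dimension (5 vs. 4), so the element-wise products inside focal_loss cannot broadcast. A minimal sketch that reproduces the same failure with hypothetical shapes (not the project's actual tensor sizes):

import torch

# Hypothetical shapes: the model predicts 5 classes, the heatmap only encodes 4.
pred = torch.rand(1, 5, 128, 512)   # (batch, classes, freq, time)
gt = torch.zeros(1, 4, 128, 512)

pos_inds = gt.eq(1).float()
# Fails with "The size of tensor a (5) must match the size of tensor b (4)
# at non-singleton dimension 1", because pred and pos_inds differ in dim 1.
pos_loss = torch.log(pred + 1e-5) * torch.pow(1 - pred, 2) * pos_inds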
@@ -301,15 +338,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": null,
    "id": "2f6924db-e520-49a1-bbe8-6c4956e46314",
    "metadata": {
     "execution": {
-     "iopub.execute_input": "2024-07-16T00:27:28.832222Z",
-     "iopub.status.busy": "2024-07-16T00:27:28.831642Z",
-     "iopub.status.idle": "2024-07-16T00:27:29.000595Z",
-     "shell.execute_reply": "2024-07-16T00:27:28.998078Z",
-     "shell.execute_reply.started": "2024-07-16T00:27:28.832157Z"
+     "iopub.status.busy": "2024-11-19T17:33:10.811729Z",
+     "iopub.status.idle": "2024-11-19T17:33:10.811955Z",
+     "shell.execute_reply": "2024-11-19T17:33:10.811858Z",
+     "shell.execute_reply.started": "2024-11-19T17:33:10.811849Z"
     }
    },
    "outputs": [],
@@ -319,44 +355,54 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": null,
    "id": "23943e13-6875-49b8-9f18-2ba6528aa673",
    "metadata": {
     "execution": {
-     "iopub.execute_input": "2024-07-16T00:27:29.004279Z",
-     "iopub.status.busy": "2024-07-16T00:27:29.003486Z",
-     "iopub.status.idle": "2024-07-16T00:27:29.595626Z",
-     "shell.execute_reply": "2024-07-16T00:27:29.594734Z",
-     "shell.execute_reply.started": "2024-07-16T00:27:29.004200Z"
+     "iopub.status.busy": "2024-11-19T17:33:10.812924Z",
+     "iopub.status.idle": "2024-11-19T17:33:10.813260Z",
+     "shell.execute_reply": "2024-11-19T17:33:10.813104Z",
+     "shell.execute_reply.started": "2024-11-19T17:33:10.813087Z"
     }
    },
    "outputs": [],
    "source": [
-    "predictions = detector.compute_clip_predictions(clip_annotation.clip)"
+    "spec = detector.compute_spectrogram(clip_annotation.clip)\n",
+    "outputs = detector(torch.tensor(spec.values).unsqueeze(0).unsqueeze(0))"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": null,
+   "id": "dd1fe346-0873-4b14-ae1b-92ef1f4f27a5",
+   "metadata": {
+    "execution": {
+     "iopub.status.busy": "2024-11-19T17:33:10.814343Z",
+     "iopub.status.idle": "2024-11-19T17:33:10.814806Z",
+     "shell.execute_reply": "2024-11-19T17:33:10.814628Z",
+     "shell.execute_reply.started": "2024-11-19T17:33:10.814611Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "_, ax= plt.subplots(figsize=(15, 5))\n",
+    "spec.plot(ax=ax, add_colorbar=False)\n",
+    "ax.pcolormesh(spec.time, spec.frequency, outputs.detection_probs.detach().squeeze())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
    "id": "eadd36ef-a04a-4665-b703-cec84cf1673b",
    "metadata": {
     "execution": {
-     "iopub.execute_input": "2024-07-16T00:28:47.178783Z",
-     "iopub.status.busy": "2024-07-16T00:28:47.178143Z",
-     "iopub.status.idle": "2024-07-16T00:28:47.246613Z",
-     "shell.execute_reply": "2024-07-16T00:28:47.245496Z",
-     "shell.execute_reply.started": "2024-07-16T00:28:47.178729Z"
+     "iopub.status.busy": "2024-11-19T17:33:10.815603Z",
+     "iopub.status.idle": "2024-11-19T17:33:10.816065Z",
+     "shell.execute_reply": "2024-11-19T17:33:10.815894Z",
+     "shell.execute_reply.started": "2024-11-19T17:33:10.815877Z"
     }
    },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Num predicted soundevents: 50\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "print(f\"Num predicted soundevents: {len(predictions.sound_events)}\")"
    ]
@@ -364,7 +410,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "d3883c04-d91a-4d1d-b677-196c0179dde1",
+   "id": "e4e54f3e-6ddc-4fe5-8ce0-b527ff6f18ae",
    "metadata": {},
    "outputs": [],
    "source": []
@@ -386,7 +432,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.9.18"
+  "version": "3.12.5"
  }
 },
 "nbformat": 4,
140  pyproject.toml
@@ -1,103 +1,97 @@
-[tool]
-rye = { dev-dependencies = [
-    "ipykernel>=6.29.4",
-    "setuptools>=69.5.1",
-    "pytest>=8.1.1",
-] }
-[tool.pdm]
-[tool.pdm.dev-dependencies]
-dev = [
-    "pytest>=7.2.2",
-]
-
 [project]
 name = "batdetect2"
-version = "1.0.8"
+version = "1.1.1"
 description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
 authors = [
     { "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
-    { "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" }
+    { "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" },
 ]
 dependencies = [
+    "click>=8.1.7",
     "librosa>=0.10.1",
     "matplotlib>=3.7.1",
     "numpy>=1.23.5",
     "pandas>=1.5.3",
     "scikit-learn>=1.2.2",
     "scipy>=1.10.1",
-    "torch>=1.13.1",
-    "torchaudio",
-    "torchvision",
-    "soundevent[audio,geometry,plot]>=2.0.1",
+    "torch>=1.13.1,<2.5.0",
+    "torchaudio>=1.13.1,<2.5.0",
+    "torchvision>=0.14.0",
+    "soundevent[audio,geometry,plot]>=2.3",
     "click>=8.1.7",
     "netcdf4>=1.6.5",
     "tqdm>=4.66.2",
     "pytorch-lightning>=2.2.2",
     "cf-xarray>=0.9.0",
     "onnx>=1.16.0",
     "lightning[extra]>=2.2.2",
     "tensorboard>=2.16.2",
+    "omegaconf>=2.3.0",
+    "pyyaml>=6.0.2",
+    "hydra-core>=1.3.2",
+    "numba>=0.60",
 ]
-requires-python = ">=3.9"
+requires-python = ">=3.9,<3.13"
 readme = "README.md"
 license = { text = "CC-by-nc-4" }
 classifiers = [
     "Development Status :: 4 - Beta",
     "Intended Audience :: Science/Research",
     "Natural Language :: English",
     "Operating System :: OS Independent",
-    "Programming Language :: Python :: 3.8",
     "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
     "Topic :: Scientific/Engineering :: Artificial Intelligence",
     "Topic :: Software Development :: Libraries :: Python Modules",
     "Topic :: Multimedia :: Sound/Audio :: Analysis",
 ]
 keywords = [
     "bat",
     "echolocation",
     "deep learning",
     "audio",
     "machine learning",
     "classification",
     "detection",
 ]
 
 [build-system]
-requires = ["pdm-pep517>=1.0.0"]
-build-backend = "pdm.pep517.api"
+requires = ["hatchling"]
+build-backend = "hatchling.build"
 
 [project.scripts]
 batdetect2 = "batdetect2.cli:cli"
 
-[tool.black]
-line-length = 79
-
-[tool.isort]
-profile = "black"
-line_length = 79
+[tool.uv]
+dev-dependencies = [
+    "debugpy>=1.8.8",
+    "hypothesis>=6.118.7",
+    "pytest>=7.2.2",
+    "ruff>=0.7.3",
+    "ipykernel>=6.29.4",
+    "setuptools>=69.5.1",
+    "basedpyright>=1.28.4",
+]
 
 [tool.ruff]
 line-length = 79
+target-version = "py39"
 
-[[tool.mypy.overrides]]
-module = [
-    "librosa",
-    "pandas",
-]
-ignore_missing_imports = true
+[tool.ruff.format]
+docstring-code-format = true
+docstring-code-line-length = 79
 
-[tool.pylsp-mypy]
-enabled = false
-live_mode = true
-strict = true
+[tool.ruff.lint]
+select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
 
-[tool.pydocstyle]
+[tool.ruff.lint.pydocstyle]
 convention = "numpy"
 
 [tool.pyright]
-include = [
-    "bat_detect",
-    "tests",
-]
+include = ["batdetect2", "tests"]
 venvPath = "."
 venv = ".venv"
+pythonVersion = "3.9"
+pythonPlatform = "All"
@@ -16,7 +16,6 @@ import batdetect2.train.train_utils as tu
 import batdetect2.utils.audio_utils as au
 
 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "audio_path", type=str, help="Input directory for audio"
@@ -65,7 +64,9 @@ if __name__ == "__main__":
     else:
         # load uk data - special case
         print("\nLoading:", args["uk_split"], "\n")
-        dataset_name = "uk_" + args["uk_split"]  # should be uk_diff, or uk_same
+        dataset_name = (
+            "uk_" + args["uk_split"]
+        )  # should be uk_diff, or uk_same
         datasets, _ = ts.get_train_test_data(
             args["ann_file"],
             args["audio_path"],
@@ -33,7 +33,6 @@ def filter_anns(anns, start_time, stop_time):
 
 
 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument("audio_file", type=str, help="Path to audio file")
     parser.add_argument("model_path", type=str, help="Path to BatDetect model")
@@ -143,7 +142,9 @@ if __name__ == "__main__":
 
     # run model and filter detections so only keep ones in relevant time range
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    results = du.process_file(args_cmd["audio_file"], model, run_config, device)
+    results = du.process_file(
+        args_cmd["audio_file"], model, run_config, device
+    )
     pred_anns = filter_anns(
         results["pred_dict"]["annotation"],
         args_cmd["start_time"],
@@ -25,7 +25,9 @@ import batdetect2.utils.plot_utils as viz
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("audio_file", type=str, help="Path to input audio file")
+    parser.add_argument(
+        "audio_file", type=str, help="Path to input audio file"
+    )
     parser.add_argument(
         "model_path", type=str, help="Path to trained BatDetect model"
     )
@@ -198,7 +198,6 @@ def save_summary_image(
     )
     ii = 0
     for row in ax:
-
         if type(row) != np.ndarray:
             row = np.array([row])
 
@@ -215,7 +214,9 @@
             )
             col.grid(color="w", alpha=0.3, linewidth=0.3)
             col.set_xticks([])
-            col.title.set_text(str(ii + 1) + " " + species_names[order[ii]])
+            col.title.set_text(
+                str(ii + 1) + " " + species_names[order[ii]]
+            )
             col.tick_params(axis="both", which="major", labelsize=7)
             ii += 1
 
109  tests/conftest.py  Normal file
@@ -0,0 +1,109 @@
import uuid
from pathlib import Path
from typing import Callable, List, Optional

import numpy as np
import pytest
import soundfile as sf
from soundevent import data


@pytest.fixture
def example_data_dir() -> Path:
    pkg_dir = Path(__file__).parent.parent
    example_data_dir = pkg_dir / "example_data"
    assert example_data_dir.exists()
    return example_data_dir


@pytest.fixture
def example_audio_dir(example_data_dir: Path) -> Path:
    example_audio_dir = example_data_dir / "audio"
    assert example_audio_dir.exists()
    return example_audio_dir


@pytest.fixture
def example_anns_dir(example_data_dir: Path) -> Path:
    example_anns_dir = example_data_dir / "anns"
    assert example_anns_dir.exists()
    return example_anns_dir


@pytest.fixture
def example_audio_files(example_audio_dir: Path) -> List[Path]:
    audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]"))
    assert len(audio_files) == 3
    return audio_files


@pytest.fixture
def data_dir() -> Path:
    dir = Path(__file__).parent / "data"
    assert dir.exists()
    return dir


@pytest.fixture
def contrib_dir(data_dir) -> Path:
    dir = data_dir / "contrib"
    assert dir.exists()
    return dir


@pytest.fixture
def wav_factory(tmp_path: Path):
    def _wav_factory(
        path: Optional[Path] = None,
        duration: float = 0.3,
        channels: int = 1,
        samplerate: int = 441_000,
        bit_depth: int = 16,
    ) -> Path:
        path = path or tmp_path / f"{uuid.uuid4()}.wav"
        frames = int(samplerate * duration)
        shape = (frames, channels)
        subtype = f"PCM_{bit_depth}"

        if bit_depth == 16:
            dtype = np.int16
        elif bit_depth == 32:
            dtype = np.int32
        else:
            raise ValueError(f"Unsupported bit depth: {bit_depth}")

        wav = np.random.uniform(
            low=np.iinfo(dtype).min,
            high=np.iinfo(dtype).max,
            size=shape,
        ).astype(dtype)
        sf.write(str(path), wav, samplerate, subtype=subtype)
        return path

    return _wav_factory


@pytest.fixture
def recording_factory(wav_factory: Callable[..., Path]):
    def _recording_factory(
        tags: Optional[list[data.Tag]] = None,
        path: Optional[Path] = None,
        recording_id: Optional[uuid.UUID] = None,
        duration: float = 1,
        channels: int = 1,
        samplerate: int = 256_000,
        time_expansion: float = 1,
    ) -> data.Recording:
        path = path or wav_factory(
            duration=duration,
            channels=channels,
            samplerate=samplerate,
        )
        return data.Recording.from_file(
            path=path,
            uuid=recording_id or uuid.uuid4(),
            time_expansion=time_expansion,
            tags=tags or [],
        )

    return _recording_factory
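Editorial note: the wav_factory and recording_factory fixtures above synthesize random WAV files and soundevent Recording objects on demand; a minimal sketch of how a test might consume them (hypothetical test, for illustration only, not part of this change):

def test_recording_factory_builds_a_short_recording(recording_factory):
    # Request a 0.5 s, single-channel recording at 256 kHz from the fixture.
    recording = recording_factory(duration=0.5, samplerate=256_000)
    assert recording.samplerate == 256_000
    assert recording.duration > 0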
BIN  tests/data/contrib/jeff37/0166_20240531_223911.wav  Executable file  (binary file not shown)
BIN  tests/data/contrib/jeff37/0166_20240602_225340.wav  Executable file  (binary file not shown)
BIN  tests/data/contrib/jeff37/0166_20240603_033731.wav  Executable file  (binary file not shown)
BIN  tests/data/contrib/jeff37/0166_20240603_033937.wav  Executable file  (binary file not shown)
BIN  tests/data/contrib/jeff37/0166_20240604_233500.wav  Executable file  (binary file not shown)
BIN  tests/data/contrib/padpadpadpad/Audiomoth.WAV  Normal file  (binary file not shown)
BIN  tests/data/contrib/padpadpadpad/AudiomothNoBatCalls.WAV  Normal file  (binary file not shown)
BIN  tests/data/contrib/padpadpadpad/Echometer.wav  Normal file  (binary file not shown)
BIN  tests/data/regression/20170701_213954-MYOMYS-LR_0_0.5.wav.npz  Normal file  (binary file not shown)
BIN  tests/data/regression/20180530_213516-EPTSER-LR_0_0.5.wav.npz  Normal file  (binary file not shown)
BIN  tests/data/regression/20180627_215323-RHIFER-LR_0_0.5.wav.npz  Normal file  (binary file not shown)
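Editorial note: the new tests/data/regression/*.npz archives presumably store reference outputs for regression tests; that purpose is an assumption, as is the content of the archives. A minimal sketch of inspecting one of them with numpy:

import numpy as np

# Load one of the regression fixtures added here and list the arrays it contains.
archive = np.load("tests/data/regression/20180530_213516-EPTSER-LR_0_0.5.wav.npz")
for key in archive.files:
    print(key, archive[key].shape)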
@@ -1,14 +1,13 @@
 """Test bat detect module API."""
 
-from pathlib import Path
-
 import os
 from glob import glob
+from pathlib import Path
 
 import numpy as np
+import soundfile as sf
 import torch
 from torch import nn
-import soundfile as sf
 
 from batdetect2 import api
 
@@ -267,7 +266,6 @@ def test_process_file_with_spec_slices():
     assert len(results["spec_slices"]) == len(detections)
 
 
-
 def test_process_file_with_empty_predictions_does_not_fail(
     tmp_path: Path,
 ):
Some files were not shown because too many files have changed in this diff.