Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-30 15:12:06 +02:00)

Compare commits: 71 commits, f63307757c ... 98c6da6d42
Commits in this comparison:

98c6da6d42, fe8d044af2, 62fa38557e, 213b6dfd29, bfa6049adc, e383a33cbf,
7689580a24, 1338ae7431, c2c4ac53fd, ff00da9a9a, 22cf47ed39, d9f7304a0f,
451093f2da, 30d3a2c92e, c17a25fa75, 29f7862153, dbc3ff9364, 150305a273,
904e8f23ea, 48e009fa9d, f7d6516550, e9e1f7ce2f, 113223be02, 35c916482c,
09bd8cf423, f6cdd4e87e, 9cf159efff, 36c90a600f, 1f0fb14d89, ee884da8b0,
6a9e33c729, 2100a3e483, 1d3cd2e305, d5753b95bb, 69f59ff559, 1a11174bc4,
c5c9476e52, 270b3f212d, f61d1d8c72, 4627ddd739, 3477d7b5b4, 394c66a2ee,
d085b3212c, c393c5c29b, 3c22ff28a7, 7dc28695b2, 505cca2dea, 7906842a16,
a4b22d6590, 25e0a53ad1, 039c002796, c97a87b2a4, d93d8284d0, 697b5dbddb,
d5bf8f5ad8, fcbccbe012, 4917641e2c, 39c3918103, 1ac3808fee, 9e0ad7fd78,
95bb0985e7, cb088359ae, c5030123aa, 1c1fbd8019, c65fe1c9f9, d05bec880a,
8597ef0a1c, 2d8a7b67f8, 68351d2224, 3f34164028, d84b7795f6
.bumpversion.cfg (new file, 8 lines)
@@ -0,0 +1,8 @@
[bumpversion]
current_version = 1.1.1
commit = True
tag = True

[bumpversion:file:batdetect2/__init__.py]

[bumpversion:file:pyproject.toml]
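With this configuration in place, a version bump rewrites current_version here, the __version__ string in batdetect2/__init__.py and the version in pyproject.toml, then commits and tags the result. A minimal sketch of driving it from Python follows; it assumes the bump2version tool is installed, and the "patch" part name is standard bump2version usage rather than something defined in this changeset.

# Sketch: trigger a version bump from Python. Assumes the bump2version
# CLI is installed and .bumpversion.cfg (above) sits in the working directory.
import subprocess

def bump(part: str = "patch") -> None:
    # part is one of "major", "minor" or "patch"; bump2version rewrites the
    # files listed in .bumpversion.cfg, then commits and tags per the config.
    subprocess.run(["bump2version", part], check=True)

if __name__ == "__main__":
    bump("patch")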
.github/workflows/python-package.yml (vendored; 35 lines changed)
@@ -1,34 +1,29 @@
 # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
 
 name: Python package
 
 on:
   push:
-    branches: [ "main" ]
+    branches: ["main"]
   pull_request:
-    branches: [ "main" ]
+    branches: ["main"]
 
 jobs:
   build:
 
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
 
     steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          python -m pip install pytest
-          pip install .
-      - name: Test with pytest
-        run: |
-          pytest
+      - uses: actions/checkout@v4
+      - name: Install uv
+        uses: astral-sh/setup-uv@v3
+        with:
+          enable-cache: true
+          cache-dependency-glob: "uv.lock"
+      - name: Set up Python ${{ matrix.python-version }}
+        run: uv python install ${{ matrix.python-version }}
+      - name: Install the project
+        run: uv sync --all-extras --dev
+      - name: Test with pytest
+        run: uv run pytest
.github/workflows/python-publish.yml (vendored; 41 lines changed)
@@ -1,11 +1,3 @@
-# This workflow will upload a Python Package using Twine when a release is created
-# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
-
-# This workflow uses actions that are not certified by GitHub.
-# They are provided by a third-party and are governed by
-# separate terms of service, privacy policy, and support
-# documentation.
-
 name: Upload Python Package
 
 on:
@@ -17,23 +9,22 @@ permissions:
 
 jobs:
   deploy:
 
     runs-on: ubuntu-latest
 
     steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python
-        uses: actions/setup-python@v3
-        with:
-          python-version: '3.x'
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install build
-      - name: Build package
-        run: python -m build
-      - name: Publish package
-        uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
-        with:
-          user: __token__
-          password: ${{ secrets.PYPI_API_TOKEN }}
+      - uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v3
+        with:
+          python-version: "3.x"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install build
+      - name: Build package
+        run: python -m build
+      - name: Publish package
+        uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
+        with:
+          user: __token__
+          password: ${{ secrets.PYPI_API_TOKEN }}
.gitignore (vendored; 4 lines changed)
@@ -103,13 +103,11 @@ experiments/*
 .ipynb_checkpoints
 *.ipynb
 
-# Bump2version
-.bumpversion.cfg
 
 # DO Include
 !batdetect2_notebook.ipynb
 !batdetect2/models/checkpoints/*.pth.tar
-!tests/data/*.wav
 !notebooks/*.ipynb
+!tests/data/**/*.wav
 notebooks/lightning_logs
 example_data/preprocessed
batdetect2/__init__.py
@@ -1 +1,6 @@
-__version__ = '1.0.8'
+import logging
+
+numba_logger = logging.getLogger("numba")
+numba_logger.setLevel(logging.WARNING)
+
+__version__ = "1.1.1"
batdetect2/cli/__init__.py
@@ -1,9 +1,13 @@
 from batdetect2.cli.base import cli
 from batdetect2.cli.compat import detect
+from batdetect2.cli.data import data
+from batdetect2.cli.train import train
 
 __all__ = [
     "cli",
     "detect",
+    "data",
+    "train",
 ]
batdetect2/cli/ascii.py (new file, 22 lines)
@@ -0,0 +1,22 @@
BATDETECT_ASCII_ART = """ .
=#%: .%%#
:%%%: .%%%%.
%%%%.-===::%%%%*
=%%%%+++++++%%%#.
-: .%%%#====+++#%%# .-
.+***= . =++. : .=*+#%*= :***.
=+****+++==:%+#=+% *##%%%%*=##*#**-=
++***+**+=: ##.. +##%%########**++
.++*****#*+- :*:++ ##%#%%%%%####**++
.++***+**++++- :#%%%%%####*##***+=
.+++***+#+++*########%%%##%#+*****++:
.=++++++*+++##%##%%####%%##*:+****+=
=++++++====*#%%#%###%%###- +***+++.
.+*++++= =+==##########= :****++.
=++*+:. .:=#####= .++**++-
.****: . -+**++=
*###= .****==
.#*#- **#*:
-### -*##.
+*= *#*
"""
batdetect2/cli/base.py
@@ -1,18 +1,14 @@
"""BatDetect2 command line interface."""

import os

import click

# from batdetect2.cli.ascii import BATDETECT_ASCII_ART

__all__ = [
    "cli",
]


CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))


INFO_STR = """
BatDetect2 - Detection and Classification
Assumes audio files are mono, not stereo.
@@ -25,3 +21,4 @@ BatDetect2 - Detection and Classification
def cli():
    """BatDetect2 - Bat Call Detection and Classification."""
    click.echo(INFO_STR)
    # click.echo(BATDETECT_ASCII_ART)
batdetect2/cli/compat.py
@@ -1,14 +1,11 @@
 """BatDetect2 command line interface."""
 
 import click
 
 from batdetect2 import api
+from batdetect2.cli.base import cli
 from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
 from batdetect2.types import ProcessingConfiguration
 from batdetect2.utils.detector_utils import save_results_to_file
 
-from batdetect2.cli.base import cli
 
 
 @cli.command()
 @click.argument(
@@ -114,10 +111,9 @@ def detect(
 ):
                 results_path = audio_file.replace(audio_dir, ann_dir)
                 save_results_to_file(results, results_path)
-        except (RuntimeError, ValueError, LookupError) as err:
+        except (RuntimeError, ValueError, LookupError, EOFError) as err:
             error_files.append(audio_file)
-            click.secho(f"Error processing file!: {err}", fg="red")
-            raise err
+            click.secho(f"Error processing file {audio_file}: {err}", fg="red")
 
     click.echo(f"\nResults saved to: {ann_dir}")
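The widened exception handling above means a corrupt or truncated recording (which can raise EOFError) is now reported and skipped instead of aborting the whole run. A sketch of exercising the command in-process with click's test runner; the example paths and the trailing detection-threshold argument are assumptions for illustration, not taken from this changeset.

# Sketch: run the `detect` command in-process with click's CliRunner.
# The directories and the trailing detection threshold are illustrative.
from click.testing import CliRunner

from batdetect2.cli import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    ["detect", "example_data/audio", "example_data/anns", "0.3"],
)
print(result.exit_code)
print(result.output)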
batdetect2/cli/data.py (new file, 40 lines)
@@ -0,0 +1,40 @@
from pathlib import Path
from typing import Optional

import click

from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config

__all__ = ["data"]


@cli.group()
def data(): ...


@data.command()
@click.argument(
    "dataset_config",
    type=click.Path(exists=True),
)
@click.option(
    "--field",
    type=str,
    help="If the dataset info is in a nested field please specify here.",
)
@click.option(
    "--base-dir",
    type=click.Path(exists=True),
    help="The base directory to which all recording and annotations paths are relative to.",
)
def summary(
    dataset_config: Path,
    field: Optional[str] = None,
    base_dir: Optional[Path] = None,
):
    base_dir = base_dir or Path.cwd()
    dataset = load_dataset_from_config(
        dataset_config, field=field, base_dir=base_dir
    )
    print(f"Number of annotated clips: {len(dataset.clip_annotations)}")
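As a usage sketch, the new `data summary` sub-command can be exercised the same way; the dataset configuration file named below is hypothetical.

# Sketch: invoke the new `data summary` command in-process.
# "dataset.yaml" is a hypothetical dataset configuration file.
from click.testing import CliRunner

from batdetect2.cli import cli

runner = CliRunner()
result = runner.invoke(cli, ["data", "summary", "dataset.yaml"])
print(result.output)  # expected to report the number of annotated clips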
batdetect2/cli/train.py (new file, 188 lines)
@@ -0,0 +1,188 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
|
||||
from batdetect2.cli.base import cli
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.preprocess import (
|
||||
load_preprocessing_config,
|
||||
)
|
||||
from batdetect2.train import (
|
||||
load_label_config,
|
||||
load_target_config,
|
||||
preprocess_annotations,
|
||||
)
|
||||
|
||||
__all__ = ["train"]
|
||||
|
||||
|
||||
@cli.group()
|
||||
def train(): ...
|
||||
|
||||
|
||||
@train.command()
|
||||
@click.argument(
|
||||
"dataset_config",
|
||||
type=click.Path(exists=True),
|
||||
)
|
||||
@click.argument(
|
||||
"output",
|
||||
type=click.Path(),
|
||||
)
|
||||
@click.option(
|
||||
"--dataset-field",
|
||||
type=str,
|
||||
help=(
|
||||
"Specifies the key to access the dataset information within the "
|
||||
"dataset configuration file, if the information is nested inside a "
|
||||
"dictionary. If the dataset information is at the top level of the "
|
||||
"config file, you don't need to specify this."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--base-dir",
|
||||
type=click.Path(exists=True),
|
||||
help=(
|
||||
"The main directory where your audio recordings and annotation "
|
||||
"files are stored. This helps the program find your data, "
|
||||
"especially if the paths in your dataset configuration file "
|
||||
"are relative."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--preprocess-config",
|
||||
type=click.Path(exists=True),
|
||||
help=(
|
||||
"Path to the preprocessing configuration file. This file tells "
|
||||
"the program how to prepare your audio data before training, such "
|
||||
"as resampling or applying filters."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--preprocess-config-field",
|
||||
type=str,
|
||||
help=(
|
||||
"If the preprocessing settings are inside a nested dictionary "
|
||||
"within the preprocessing configuration file, specify the key "
|
||||
"here to access them. If the preprocessing settings are at the "
|
||||
"top level, you don't need to specify this."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--label-config",
|
||||
type=click.Path(exists=True),
|
||||
help=(
|
||||
"Path to the label generation configuration file. This file "
|
||||
"contains settings for how to create labels from your "
|
||||
"annotations, which the model uses to learn."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--label-config-field",
|
||||
type=str,
|
||||
help=(
|
||||
"If the label generation settings are inside a nested dictionary "
|
||||
"within the label configuration file, specify the key here. If "
|
||||
"the settings are at the top level, leave this blank."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--target-config",
|
||||
type=click.Path(exists=True),
|
||||
help=(
|
||||
"Path to the training target configuration file. This file "
|
||||
"specifies what sounds the model should learn to predict."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--target-config-field",
|
||||
type=str,
|
||||
help=(
|
||||
"If the target settings are inside a nested dictionary "
|
||||
"within the target configuration file, specify the key here. "
|
||||
"If the settings are at the top level, you don't need to specify this."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--force",
|
||||
is_flag=True,
|
||||
help=(
|
||||
"If a preprocessed file already exists, this option tells the "
|
||||
"program to overwrite it with the new preprocessed data. Use "
|
||||
"this if you want to re-do the preprocessing even if the files "
|
||||
"already exist."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
help=(
|
||||
"The maximum number of computer cores to use when processing "
|
||||
"your audio data. Using more cores can speed up the preprocessing, "
|
||||
"but don't use more than your computer has available. By default, "
|
||||
"the program will use all available cores."
|
||||
),
|
||||
)
|
||||
def preprocess(
|
||||
dataset_config: Path,
|
||||
output: Path,
|
||||
base_dir: Optional[Path] = None,
|
||||
preprocess_config: Optional[Path] = None,
|
||||
target_config: Optional[Path] = None,
|
||||
label_config: Optional[Path] = None,
|
||||
force: bool = False,
|
||||
num_workers: Optional[int] = None,
|
||||
target_config_field: Optional[str] = None,
|
||||
preprocess_config_field: Optional[str] = None,
|
||||
label_config_field: Optional[str] = None,
|
||||
dataset_field: Optional[str] = None,
|
||||
):
|
||||
output = Path(output)
|
||||
base_dir = base_dir or Path.cwd()
|
||||
|
||||
preprocess = (
|
||||
load_preprocessing_config(
|
||||
preprocess_config,
|
||||
field=preprocess_config_field,
|
||||
)
|
||||
if preprocess_config
|
||||
else None
|
||||
)
|
||||
|
||||
target = (
|
||||
load_target_config(
|
||||
target_config,
|
||||
field=target_config_field,
|
||||
)
|
||||
if target_config
|
||||
else None
|
||||
)
|
||||
|
||||
label = (
|
||||
load_label_config(
|
||||
label_config,
|
||||
field=label_config_field,
|
||||
)
|
||||
if label_config
|
||||
else None
|
||||
)
|
||||
|
||||
dataset = load_dataset_from_config(
|
||||
dataset_config,
|
||||
field=dataset_field,
|
||||
base_dir=base_dir,
|
||||
)
|
||||
|
||||
if not output.exists():
|
||||
output.mkdir(parents=True)
|
||||
|
||||
preprocess_annotations(
|
||||
dataset.clip_annotations,
|
||||
output_dir=output,
|
||||
replace=force,
|
||||
preprocessing_config=preprocess,
|
||||
label_config=label,
|
||||
target_config=target,
|
||||
max_workers=num_workers,
|
||||
)
|
batdetect2/compat/__init__.py (new, empty file)
@@ -9,15 +9,16 @@ import numpy as np
|
||||
from pydantic import BaseModel, Field
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
from soundevent.types import ClassMapper
|
||||
|
||||
from batdetect2 import types
|
||||
from batdetect2.data.labels import ClassMapper
|
||||
|
||||
PathLike = Union[Path, str, os.PathLike]
|
||||
|
||||
__all__ = [
|
||||
"convert_to_annotation_group",
|
||||
"load_annotation_project",
|
||||
"load_file_annotation",
|
||||
"annotation_to_sound_event",
|
||||
]
|
||||
|
||||
SPECIES_TAG_KEY = "species"
|
||||
@@ -195,18 +196,30 @@ def annotation_to_sound_event(
|
||||
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key=label_key, value=annotation.label),
|
||||
data.Tag(key=event_key, value=annotation.event),
|
||||
data.Tag(key=individual_key, value=str(annotation.individual)),
|
||||
data.Tag(
|
||||
term=data.term_from_key(label_key),
|
||||
value=annotation.label,
|
||||
),
|
||||
data.Tag(
|
||||
term=data.term_from_key(event_key),
|
||||
value=annotation.event,
|
||||
),
|
||||
data.Tag(
|
||||
term=data.term_from_key(individual_key),
|
||||
value=str(annotation.individual),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def file_annotation_to_clip(
|
||||
file_annotation: FileAnnotation,
|
||||
audio_dir: PathLike = Path.cwd(),
|
||||
audio_dir: Optional[PathLike] = None,
|
||||
label_key: str = "class",
|
||||
) -> data.Clip:
|
||||
"""Convert file annotation to recording."""
|
||||
audio_dir = audio_dir or Path.cwd()
|
||||
|
||||
full_path = Path(audio_dir) / file_annotation.id
|
||||
|
||||
if not full_path.exists():
|
||||
@@ -215,6 +228,12 @@ def file_annotation_to_clip(
|
||||
recording = data.Recording.from_file(
|
||||
full_path,
|
||||
time_expansion=file_annotation.time_exp,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=data.term_from_key(label_key),
|
||||
value=file_annotation.label,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
return data.Clip(
|
||||
@@ -241,7 +260,11 @@ def file_annotation_to_clip_annotation(
|
||||
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
|
||||
clip=clip,
|
||||
notes=notes,
|
||||
tags=[data.Tag(key=label_key, value=file_annotation.label)],
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=data.term_from_key(label_key), value=file_annotation.label
|
||||
)
|
||||
],
|
||||
sound_events=[
|
||||
annotation_to_sound_event(
|
||||
annotation,
|
||||
@@ -281,52 +304,3 @@ def list_file_annotations(path: PathLike) -> List[Path]:
|
||||
"""List all annotations in a directory."""
|
||||
path = Path(path)
|
||||
return [file for file in path.glob("*.json")]
|
||||
|
||||
|
||||
def load_annotation_project(
|
||||
path: PathLike,
|
||||
name: Optional[str] = None,
|
||||
audio_dir: PathLike = Path.cwd(),
|
||||
) -> data.AnnotationProject:
|
||||
"""Convert annotations to annotation project."""
|
||||
paths = list_file_annotations(path)
|
||||
|
||||
if name is None:
|
||||
name = str(path)
|
||||
|
||||
annotations = []
|
||||
tasks = []
|
||||
|
||||
for p in paths:
|
||||
try:
|
||||
file_annotation = load_file_annotation(p)
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
|
||||
try:
|
||||
clip = file_annotation_to_clip(
|
||||
file_annotation,
|
||||
audio_dir=audio_dir,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
|
||||
annotations.append(
|
||||
file_annotation_to_clip_annotation(
|
||||
file_annotation,
|
||||
clip,
|
||||
)
|
||||
)
|
||||
|
||||
tasks.append(
|
||||
file_annotation_to_annotation_task(
|
||||
file_annotation,
|
||||
clip,
|
||||
)
|
||||
)
|
||||
|
||||
return data.AnnotationProject(
|
||||
name=name,
|
||||
clip_annotations=annotations,
|
||||
tasks=tasks,
|
||||
)
|
batdetect2/compat/params.py (new file, 151 lines)
@@ -0,0 +1,151 @@
|
||||
from batdetect2.preprocess import (
|
||||
AmplitudeScaleConfig,
|
||||
AudioConfig,
|
||||
FrequencyConfig,
|
||||
LogScaleConfig,
|
||||
PcenScaleConfig,
|
||||
PreprocessingConfig,
|
||||
ResampleConfig,
|
||||
Scales,
|
||||
SpecSizeConfig,
|
||||
SpectrogramConfig,
|
||||
STFTConfig,
|
||||
)
|
||||
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
|
||||
from batdetect2.terms import TagInfo
|
||||
from batdetect2.train.preprocess import (
|
||||
HeatmapsConfig,
|
||||
TargetConfig,
|
||||
TrainPreprocessingConfig,
|
||||
)
|
||||
|
||||
|
||||
def get_spectrogram_scale(scale: str) -> Scales:
|
||||
if scale == "pcen":
|
||||
return PcenScaleConfig()
|
||||
if scale == "log":
|
||||
return LogScaleConfig()
|
||||
return AmplitudeScaleConfig()
|
||||
|
||||
|
||||
def get_preprocessing_config(params: dict) -> PreprocessingConfig:
|
||||
return PreprocessingConfig(
|
||||
audio=AudioConfig(
|
||||
resample=ResampleConfig(
|
||||
samplerate=params["target_samp_rate"],
|
||||
mode="poly",
|
||||
),
|
||||
scale=params["scale_raw_audio"],
|
||||
center=params["scale_raw_audio"],
|
||||
duration=None,
|
||||
),
|
||||
spectrogram=SpectrogramConfig(
|
||||
stft=STFTConfig(
|
||||
window_duration=params["fft_win_length"],
|
||||
window_overlap=params["fft_overlap"],
|
||||
window_fn="hann",
|
||||
),
|
||||
frequencies=FrequencyConfig(
|
||||
min_freq=params["min_freq"],
|
||||
max_freq=params["max_freq"],
|
||||
),
|
||||
scale=get_spectrogram_scale(params["spec_scale"]),
|
||||
denoise=params["denoise_spec_avg"],
|
||||
size=SpecSizeConfig(
|
||||
height=params["spec_height"],
|
||||
resize_factor=params["resize_factor"],
|
||||
),
|
||||
max_scale=params["max_scale_spec"],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_training_preprocessing_config(
|
||||
params: dict,
|
||||
) -> TrainPreprocessingConfig:
|
||||
generic = params["generic_class"][0]
|
||||
preprocessing = get_preprocessing_config(params)
|
||||
|
||||
freq_bin_width, time_bin_width = get_spectrogram_resolution(
|
||||
preprocessing.spectrogram
|
||||
)
|
||||
|
||||
return TrainPreprocessingConfig(
|
||||
preprocessing=preprocessing,
|
||||
target=TargetConfig(
|
||||
classes=[
|
||||
TagInfo(key="class", value=class_name, label=class_name)
|
||||
for class_name in params["class_names"]
|
||||
],
|
||||
generic_class=TagInfo(
|
||||
key="class",
|
||||
value=generic,
|
||||
label=generic,
|
||||
),
|
||||
include=[
|
||||
TagInfo(key="event", value=event)
|
||||
for event in params["events_of_interest"]
|
||||
],
|
||||
exclude=[
|
||||
TagInfo(key="class", value=value)
|
||||
for value in params["classes_to_ignore"]
|
||||
],
|
||||
),
|
||||
heatmaps=HeatmapsConfig(
|
||||
position="bottom-left",
|
||||
time_scale=1 / time_bin_width,
|
||||
frequency_scale=1 / freq_bin_width,
|
||||
sigma=params["target_sigma"],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# 'standardize_classs_names_ip',
|
||||
# 'convert_to_genus',
|
||||
# 'genus_mapping',
|
||||
# 'standardize_classs_names',
|
||||
# 'genus_names',
|
||||
|
||||
# ['data_dir',
|
||||
# 'ann_dir',
|
||||
# 'train_split',
|
||||
# 'model_name',
|
||||
# 'num_filters',
|
||||
# 'experiment',
|
||||
# 'model_file_name',
|
||||
# 'op_im_dir',
|
||||
# 'op_im_dir_test',
|
||||
# 'notes',
|
||||
# 'spec_divide_factor',
|
||||
# 'detection_overlap',
|
||||
# 'ignore_start_end',
|
||||
# 'detection_threshold',
|
||||
# 'nms_kernel_size',
|
||||
# 'nms_top_k_per_sec',
|
||||
# 'aug_prob',
|
||||
# 'augment_at_train',
|
||||
# 'augment_at_train_combine',
|
||||
# 'echo_max_delay',
|
||||
# 'stretch_squeeze_delta',
|
||||
# 'mask_max_time_perc',
|
||||
# 'mask_max_freq_perc',
|
||||
# 'spec_amp_scaling',
|
||||
# 'aug_sampling_rates',
|
||||
# 'train_loss',
|
||||
# 'det_loss_weight',
|
||||
# 'size_loss_weight',
|
||||
# 'class_loss_weight',
|
||||
# 'individual_loss_weight',
|
||||
# 'emb_dim',
|
||||
# 'lr',
|
||||
# 'batch_size',
|
||||
# 'num_workers',
|
||||
# 'num_epochs',
|
||||
# 'num_eval_epochs',
|
||||
# 'device',
|
||||
# 'save_test_image_during_train',
|
||||
# 'save_test_image_after_train',
|
||||
# 'train_sets',
|
||||
# 'test_sets',
|
||||
# 'class_inv_freq',
|
||||
# 'ip_height']
|
batdetect2/configs.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from typing import Any, Optional, Type, TypeVar

import yaml
from pydantic import BaseModel, ConfigDict
from soundevent.data import PathLike


class BaseConfig(BaseModel):
    model_config = ConfigDict(extra="forbid")


T = TypeVar("T", bound=BaseModel)


def get_object_field(obj: dict, field: str) -> Any:
    if "." not in field:
        return obj[field]

    field, rest = field.split(".", 1)
    subobj = obj[field]
    return get_object_field(subobj, rest)


def load_config(
    path: PathLike,
    schema: Type[T],
    field: Optional[str] = None,
) -> T:
    with open(path, "r") as file:
        config = yaml.safe_load(file)

    if field:
        config = get_object_field(config, field)

    return schema.model_validate(config)
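load_config combines YAML loading, an optional dotted-field lookup via get_object_field, and pydantic validation against the given schema. A small self-contained sketch of the intended usage; the ExampleConfig schema and the YAML file are made up for illustration.

# Sketch: using batdetect2.configs.load_config with a toy schema.
# The YAML file and the ExampleConfig schema are illustrative assumptions.
from pathlib import Path

from batdetect2.configs import BaseConfig, load_config

class ExampleConfig(BaseConfig):
    samplerate: int = 256_000
    denoise: bool = True

Path("example.yaml").write_text(
    "training:\n  preprocessing:\n    samplerate: 192000\n    denoise: false\n"
)

# `field` uses dot notation to reach a nested section of the YAML document.
config = load_config("example.yaml", ExampleConfig, field="training.preprocessing")
print(config.samplerate, config.denoise)  # 192000 False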
batdetect2/data/__init__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
from batdetect2.data.annotations import (
    AnnotatedDataset,
    load_annotated_dataset,
)
from batdetect2.data.data import load_dataset, load_dataset_from_config
from batdetect2.data.types import Dataset

__all__ = [
    "AnnotatedDataset",
    "Dataset",
    "load_annotated_dataset",
    "load_dataset",
    "load_dataset_from_config",
]
batdetect2/data/annotations.py (new file, 36 lines)
@@ -0,0 +1,36 @@
import json
from pathlib import Path
from typing import Literal, Union

from batdetect2.configs import BaseConfig

__all__ = [
    "AOEFAnnotationFile",
    "AnnotationFormats",
    "BatDetect2AnnotationFile",
    "BatDetect2AnnotationFiles",
]


class BatDetect2AnnotationFiles(BaseConfig):
    format: Literal["batdetect2"] = "batdetect2"
    path: Path


class BatDetect2AnnotationFile(BaseConfig):
    format: Literal["batdetect2_file"] = "batdetect2_file"
    path: Path


class AOEFAnnotationFile(BaseConfig):
    format: Literal["aoef"] = "aoef"
    path: Path


AnnotationFormats = Union[
    BatDetect2AnnotationFiles,
    BatDetect2AnnotationFile,
    AOEFAnnotationFile,
]
batdetect2/data/annotations/__init__.py (new file, 55 lines)
@@ -0,0 +1,55 @@
from pathlib import Path
from typing import Optional, Union

from soundevent import data

from batdetect2.data.annotations.aeof import (
    AOEFAnnotations,
    load_aoef_annotated_dataset,
)
from batdetect2.data.annotations.batdetect2_files import (
    BatDetect2FilesAnnotations,
    load_batdetect2_files_annotated_dataset,
)
from batdetect2.data.annotations.batdetect2_merged import (
    BatDetect2MergedAnnotations,
    load_batdetect2_merged_annotated_dataset,
)
from batdetect2.data.annotations.types import AnnotatedDataset

__all__ = [
    "load_annotated_dataset",
    "AnnotatedDataset",
    "AOEFAnnotations",
    "BatDetect2FilesAnnotations",
    "BatDetect2MergedAnnotations",
    "AnnotationFormats",
]


AnnotationFormats = Union[
    BatDetect2MergedAnnotations,
    BatDetect2FilesAnnotations,
    AOEFAnnotations,
]


def load_annotated_dataset(
    dataset: AnnotatedDataset,
    base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
    if isinstance(dataset, AOEFAnnotations):
        return load_aoef_annotated_dataset(dataset, base_dir=base_dir)

    if isinstance(dataset, BatDetect2MergedAnnotations):
        return load_batdetect2_merged_annotated_dataset(
            dataset, base_dir=base_dir
        )

    if isinstance(dataset, BatDetect2FilesAnnotations):
        return load_batdetect2_files_annotated_dataset(
            dataset,
            base_dir=base_dir,
        )

    raise NotImplementedError(f"Unknown annotation format: {dataset.name}")
batdetect2/data/annotations/aeof.py (new file, 37 lines)
@@ -0,0 +1,37 @@
from pathlib import Path
from typing import Literal, Optional

from soundevent import data, io

from batdetect2.data.annotations.types import AnnotatedDataset

__all__ = [
    "AOEFAnnotations",
    "load_aoef_annotated_dataset",
]


class AOEFAnnotations(AnnotatedDataset):
    format: Literal["aoef"] = "aoef"
    annotations_path: Path


def load_aoef_annotated_dataset(
    dataset: AOEFAnnotations,
    base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
    audio_dir = dataset.audio_dir
    path = dataset.annotations_path

    if base_dir:
        audio_dir = base_dir / audio_dir
        path = base_dir / path

    loaded = io.load(path, audio_dir=audio_dir)

    if not isinstance(loaded, (data.AnnotationSet, data.AnnotationProject)):
        raise ValueError(
            f"The AOEF file at {path} does not contain a set of annotations"
        )

    return loaded
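A sketch of describing and loading a single AOEF-annotated source by hand; the paths are hypothetical, while the field names come from the AOEFAnnotations model above.

# Sketch: describe and load one AOEF-annotated source. Paths are hypothetical.
from pathlib import Path

from batdetect2.data.annotations.aeof import (
    AOEFAnnotations,
    load_aoef_annotated_dataset,
)

source = AOEFAnnotations(
    name="example-site",
    description="Recordings from one deployment.",
    audio_dir=Path("audio"),
    annotations_path=Path("annotations.aoef.json"),
)

# base_dir keeps the relative audio_dir / annotations_path portable.
annotation_set = load_aoef_annotated_dataset(source, base_dir=Path("/data/bats"))
print(len(annotation_set.clip_annotations))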
batdetect2/data/annotations/batdetect2_files.py (new file, 80 lines)
@@ -0,0 +1,80 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.data.annotations.legacy import (
|
||||
file_annotation_to_annotation_task,
|
||||
file_annotation_to_clip,
|
||||
file_annotation_to_clip_annotation,
|
||||
list_file_annotations,
|
||||
load_file_annotation,
|
||||
)
|
||||
from batdetect2.data.annotations.types import AnnotatedDataset
|
||||
|
||||
PathLike = Union[Path, str, os.PathLike]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"load_batdetect2_files_annotated_dataset",
|
||||
"BatDetect2FilesAnnotations",
|
||||
]
|
||||
|
||||
|
||||
class BatDetect2FilesAnnotations(AnnotatedDataset):
|
||||
format: Literal["batdetect2"] = "batdetect2"
|
||||
annotations_dir: Path
|
||||
|
||||
|
||||
def load_batdetect2_files_annotated_dataset(
|
||||
dataset: BatDetect2FilesAnnotations,
|
||||
base_dir: Optional[PathLike] = None,
|
||||
) -> data.AnnotationProject:
|
||||
"""Convert annotations to annotation project."""
|
||||
audio_dir = dataset.audio_dir
|
||||
path = dataset.annotations_dir
|
||||
|
||||
if base_dir:
|
||||
audio_dir = base_dir / audio_dir
|
||||
path = base_dir / path
|
||||
|
||||
paths = list_file_annotations(path)
|
||||
|
||||
annotations = []
|
||||
tasks = []
|
||||
|
||||
for p in paths:
|
||||
try:
|
||||
file_annotation = load_file_annotation(p)
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
|
||||
try:
|
||||
clip = file_annotation_to_clip(
|
||||
file_annotation,
|
||||
audio_dir=audio_dir,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
|
||||
annotations.append(
|
||||
file_annotation_to_clip_annotation(
|
||||
file_annotation,
|
||||
clip,
|
||||
)
|
||||
)
|
||||
|
||||
tasks.append(
|
||||
file_annotation_to_annotation_task(
|
||||
file_annotation,
|
||||
clip,
|
||||
)
|
||||
)
|
||||
|
||||
return data.AnnotationProject(
|
||||
name=dataset.name,
|
||||
description=dataset.description,
|
||||
clip_annotations=annotations,
|
||||
tasks=tasks,
|
||||
)
|
batdetect2/data/annotations/batdetect2_merged.py (new file, 64 lines)
@@ -0,0 +1,64 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.data.annotations.legacy import (
|
||||
FileAnnotation,
|
||||
file_annotation_to_annotation_task,
|
||||
file_annotation_to_clip,
|
||||
file_annotation_to_clip_annotation,
|
||||
)
|
||||
from batdetect2.data.annotations.types import AnnotatedDataset
|
||||
|
||||
PathLike = Union[Path, str, os.PathLike]
|
||||
|
||||
__all__ = [
|
||||
"BatDetect2MergedAnnotations",
|
||||
"load_batdetect2_merged_annotated_dataset",
|
||||
]
|
||||
|
||||
|
||||
class BatDetect2MergedAnnotations(AnnotatedDataset):
|
||||
format: Literal["batdetect2_file"] = "batdetect2_file"
|
||||
annotations_path: Path
|
||||
|
||||
|
||||
def load_batdetect2_merged_annotated_dataset(
|
||||
dataset: BatDetect2MergedAnnotations,
|
||||
base_dir: Optional[PathLike] = None,
|
||||
) -> data.AnnotationProject:
|
||||
audio_dir = dataset.audio_dir
|
||||
path = dataset.annotations_path
|
||||
|
||||
if base_dir:
|
||||
audio_dir = base_dir / audio_dir
|
||||
path = base_dir / path
|
||||
|
||||
content = json.loads(Path(path).read_text())
|
||||
|
||||
annotations = []
|
||||
tasks = []
|
||||
|
||||
for ann in content:
|
||||
try:
|
||||
ann = FileAnnotation.model_validate(ann)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
try:
|
||||
clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
|
||||
annotations.append(file_annotation_to_clip_annotation(ann, clip))
|
||||
tasks.append(file_annotation_to_annotation_task(ann, clip))
|
||||
|
||||
return data.AnnotationProject(
|
||||
name=dataset.name,
|
||||
description=dataset.description,
|
||||
clip_annotations=annotations,
|
||||
tasks=tasks,
|
||||
)
|
batdetect2/data/annotations/legacy.py (new file, 304 lines)
@@ -0,0 +1,304 @@
|
||||
"""Compatibility functions between old and new data structures."""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, Field
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
from soundevent.types import ClassMapper
|
||||
|
||||
from batdetect2 import types
|
||||
|
||||
PathLike = Union[Path, str, os.PathLike]
|
||||
|
||||
__all__ = [
|
||||
"convert_to_annotation_group",
|
||||
]
|
||||
|
||||
SPECIES_TAG_KEY = "species"
|
||||
ECHOLOCATION_EVENT = "Echolocation"
|
||||
UNKNOWN_CLASS = "__UNKNOWN__"
|
||||
|
||||
NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
|
||||
|
||||
|
||||
EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]
|
||||
|
||||
ClassFn = Callable[[data.Recording], int]
|
||||
|
||||
IndividualFn = Callable[[data.SoundEventAnnotation], int]
|
||||
|
||||
|
||||
def get_recording_class_name(recording: data.Recording) -> str:
|
||||
"""Get the class name for a recording."""
|
||||
tag = data.find_tag(recording.tags, SPECIES_TAG_KEY)
|
||||
if tag is None:
|
||||
return UNKNOWN_CLASS
|
||||
return tag.value
|
||||
|
||||
|
||||
def get_annotation_notes(annotation: data.ClipAnnotation) -> str:
|
||||
"""Get the notes for a ClipAnnotation."""
|
||||
all_notes = [
|
||||
*annotation.notes,
|
||||
*annotation.clip.recording.notes,
|
||||
]
|
||||
messages = [note.message for note in all_notes if note.message is not None]
|
||||
return "\n".join(messages)
|
||||
|
||||
|
||||
def convert_to_annotation_group(
|
||||
annotation: data.ClipAnnotation,
|
||||
class_mapper: ClassMapper,
|
||||
event_fn: EventFn = lambda _: ECHOLOCATION_EVENT,
|
||||
class_fn: ClassFn = lambda _: 0,
|
||||
individual_fn: IndividualFn = lambda _: 0,
|
||||
) -> types.AudioLoaderAnnotationGroup:
|
||||
"""Convert a ClipAnnotation to an AudioLoaderAnnotationGroup."""
|
||||
recording = annotation.clip.recording
|
||||
|
||||
start_times = []
|
||||
end_times = []
|
||||
low_freqs = []
|
||||
high_freqs = []
|
||||
class_ids = []
|
||||
x_inds = []
|
||||
y_inds = []
|
||||
individual_ids = []
|
||||
annotations: List[types.Annotation] = []
|
||||
class_id_file = class_fn(recording)
|
||||
|
||||
for sound_event in annotation.sound_events:
|
||||
geometry = sound_event.sound_event.geometry
|
||||
|
||||
if geometry is None:
|
||||
continue
|
||||
|
||||
start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
|
||||
class_id = class_mapper.transform(sound_event) or -1
|
||||
event = event_fn(sound_event) or ""
|
||||
individual_id = individual_fn(sound_event) or -1
|
||||
|
||||
start_times.append(start_time)
|
||||
end_times.append(end_time)
|
||||
low_freqs.append(low_freq)
|
||||
high_freqs.append(high_freq)
|
||||
class_ids.append(class_id)
|
||||
individual_ids.append(individual_id)
|
||||
|
||||
# NOTE: This will be computed later so we just put a placeholder
|
||||
# here for now.
|
||||
x_inds.append(0)
|
||||
y_inds.append(0)
|
||||
|
||||
annotations.append(
|
||||
{
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"low_freq": low_freq,
|
||||
"high_freq": high_freq,
|
||||
"class_prob": 1.0,
|
||||
"det_prob": 1.0,
|
||||
"individual": "0",
|
||||
"event": event,
|
||||
"class_id": class_id, # type: ignore
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"id": str(recording.path),
|
||||
"duration": recording.duration,
|
||||
"issues": False,
|
||||
"file_path": str(recording.path),
|
||||
"time_exp": recording.time_expansion,
|
||||
"class_name": get_recording_class_name(recording),
|
||||
"notes": get_annotation_notes(annotation),
|
||||
"annotated": True,
|
||||
"start_times": np.array(start_times),
|
||||
"end_times": np.array(end_times),
|
||||
"low_freqs": np.array(low_freqs),
|
||||
"high_freqs": np.array(high_freqs),
|
||||
"class_ids": np.array(class_ids),
|
||||
"x_inds": np.array(x_inds),
|
||||
"y_inds": np.array(y_inds),
|
||||
"individual_ids": np.array(individual_ids),
|
||||
"annotation": annotations,
|
||||
"class_id_file": class_id_file,
|
||||
}
|
||||
|
||||
|
||||
class Annotation(BaseModel):
|
||||
"""Annotation class to hold batdetect annotations."""
|
||||
|
||||
label: str = Field(alias="class")
|
||||
event: str
|
||||
individual: int = 0
|
||||
|
||||
start_time: float
|
||||
end_time: float
|
||||
low_freq: float
|
||||
high_freq: float
|
||||
|
||||
|
||||
class FileAnnotation(BaseModel):
|
||||
"""FileAnnotation class to hold batdetect annotations for a file."""
|
||||
|
||||
id: str
|
||||
duration: float
|
||||
time_exp: float = 1
|
||||
|
||||
label: str = Field(alias="class_name")
|
||||
|
||||
annotation: List[Annotation]
|
||||
|
||||
annotated: bool = False
|
||||
issues: bool = False
|
||||
notes: str = ""
|
||||
|
||||
|
||||
def load_file_annotation(path: PathLike) -> FileAnnotation:
|
||||
"""Load annotation from batdetect format."""
|
||||
path = Path(path)
|
||||
return FileAnnotation.model_validate_json(path.read_text())
|
||||
|
||||
|
||||
def annotation_to_sound_event(
|
||||
annotation: Annotation,
|
||||
recording: data.Recording,
|
||||
label_key: str = "class",
|
||||
event_key: str = "event",
|
||||
individual_key: str = "individual",
|
||||
) -> data.SoundEventAnnotation:
|
||||
"""Convert annotation to sound event annotation."""
|
||||
sound_event = data.SoundEvent(
|
||||
uuid=uuid.uuid5(
|
||||
NAMESPACE,
|
||||
f"{recording.hash}_{annotation.start_time}_{annotation.end_time}",
|
||||
),
|
||||
recording=recording,
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=[
|
||||
annotation.start_time,
|
||||
annotation.low_freq,
|
||||
annotation.end_time,
|
||||
annotation.high_freq,
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
return data.SoundEventAnnotation(
|
||||
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=data.term_from_key(label_key),
|
||||
value=annotation.label,
|
||||
),
|
||||
data.Tag(
|
||||
term=data.term_from_key(event_key),
|
||||
value=annotation.event,
|
||||
),
|
||||
data.Tag(
|
||||
term=data.term_from_key(individual_key),
|
||||
value=str(annotation.individual),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def file_annotation_to_clip(
|
||||
file_annotation: FileAnnotation,
|
||||
audio_dir: Optional[PathLike] = None,
|
||||
label_key: str = "class",
|
||||
) -> data.Clip:
|
||||
"""Convert file annotation to recording."""
|
||||
audio_dir = audio_dir or Path.cwd()
|
||||
|
||||
full_path = Path(audio_dir) / file_annotation.id
|
||||
|
||||
if not full_path.exists():
|
||||
raise FileNotFoundError(f"File {full_path} not found.")
|
||||
|
||||
recording = data.Recording.from_file(
|
||||
full_path,
|
||||
time_expansion=file_annotation.time_exp,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=data.term_from_key(label_key),
|
||||
value=file_annotation.label,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
return data.Clip(
|
||||
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip"),
|
||||
recording=recording,
|
||||
start_time=0,
|
||||
end_time=recording.duration,
|
||||
)
|
||||
|
||||
|
||||
def file_annotation_to_clip_annotation(
|
||||
file_annotation: FileAnnotation,
|
||||
clip: data.Clip,
|
||||
label_key: str = "class",
|
||||
event_key: str = "event",
|
||||
individual_key: str = "individual",
|
||||
) -> data.ClipAnnotation:
|
||||
"""Convert file annotation to clip annotation."""
|
||||
notes = []
|
||||
if file_annotation.notes:
|
||||
notes.append(data.Note(message=file_annotation.notes))
|
||||
|
||||
return data.ClipAnnotation(
|
||||
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
|
||||
clip=clip,
|
||||
notes=notes,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=data.term_from_key(label_key), value=file_annotation.label
|
||||
)
|
||||
],
|
||||
sound_events=[
|
||||
annotation_to_sound_event(
|
||||
annotation,
|
||||
clip.recording,
|
||||
label_key=label_key,
|
||||
event_key=event_key,
|
||||
individual_key=individual_key,
|
||||
)
|
||||
for annotation in file_annotation.annotation
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def file_annotation_to_annotation_task(
|
||||
file_annotation: FileAnnotation,
|
||||
clip: data.Clip,
|
||||
) -> data.AnnotationTask:
|
||||
status_badges = []
|
||||
|
||||
if file_annotation.issues:
|
||||
status_badges.append(
|
||||
data.StatusBadge(state=data.AnnotationState.rejected)
|
||||
)
|
||||
elif file_annotation.annotated:
|
||||
status_badges.append(
|
||||
data.StatusBadge(state=data.AnnotationState.completed)
|
||||
)
|
||||
|
||||
return data.AnnotationTask(
|
||||
uuid=uuid.uuid5(uuid.NAMESPACE_URL, f"{file_annotation.id}_task"),
|
||||
clip=clip,
|
||||
status_badges=status_badges,
|
||||
)
|
||||
|
||||
|
||||
def list_file_annotations(path: PathLike) -> List[Path]:
|
||||
"""List all annotations in a directory."""
|
||||
path = Path(path)
|
||||
return [file for file in path.glob("*.json")]
|
batdetect2/data/annotations/types.py (new file, 41 lines)
@@ -0,0 +1,41 @@
from pathlib import Path
from typing import Literal, Union

from batdetect2.configs import BaseConfig

__all__ = [
    "AnnotatedDataset",
    "BatDetect2MergedAnnotations",
]


class AnnotatedDataset(BaseConfig):
    """Represents a single, cohesive source of audio recordings and annotations.

    A source typically groups recordings originating from a specific context,
    such as a single project, site, deployment, or recordist. All audio files
    belonging to a source should be located within a single directory,
    specified by `audio_dir`.

    Annotations associated with these recordings are defined by the
    `annotations` field, which supports various formats (e.g., AOEF files,
    specific CSV structures).
    Crucially, file paths referenced within the annotation data *must* be
    relative to the `audio_dir`. This ensures that the dataset definition
    remains portable across different systems and base directories.

    Attributes:
        name: A unique identifier for this data source.
        description: Detailed information about the source, including recording
            methods, annotation procedures, equipment used, potential biases,
            or any important caveats for users.
        audio_dir: The file system path to the directory containing the audio
            recordings for this source.
    """

    name: str
    audio_dir: Path
    description: str = ""
batdetect2/data/data.py (new file, 37 lines)
@@ -0,0 +1,37 @@
from pathlib import Path
from typing import Optional

from soundevent import data

from batdetect2.configs import load_config
from batdetect2.data.annotations import load_annotated_dataset
from batdetect2.data.types import Dataset

__all__ = [
    "load_dataset",
    "load_dataset_from_config",
]


def load_dataset(
    dataset: Dataset,
    base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
    clip_annotations = []
    for source in dataset.sources:
        annotated_source = load_annotated_dataset(source, base_dir=base_dir)
        clip_annotations.extend(annotated_source.clip_annotations)
    return data.AnnotationSet(clip_annotations=clip_annotations)


def load_dataset_from_config(
    path: data.PathLike,
    field: Optional[str] = None,
    base_dir: Optional[Path] = None,
):
    config = load_config(
        path=path,
        schema=Dataset,
        field=field,
    )
    return load_dataset(config, base_dir=base_dir)
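Putting the pieces together, a dataset definition can live in a YAML file and be loaded in one call. The file content and paths below are a hypothetical illustration that only uses field names defined in this changeset.

# Sketch: a minimal dataset config and loading it. Content and paths are
# illustrative assumptions, not part of this changeset.
from pathlib import Path

from batdetect2.data import load_dataset_from_config

Path("dataset.yaml").write_text(
    """
name: Example dataset
description: One AOEF-annotated source.
sources:
  - format: aoef
    name: site-a
    audio_dir: site_a/audio
    annotations_path: site_a/annotations.json
"""
)

dataset = load_dataset_from_config("dataset.yaml", base_dir=Path("/data/bats"))
print(f"Number of annotated clips: {len(dataset.clip_annotations)}")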
(deleted file)
@@ -1,33 +0,0 @@
from typing import Callable, Generic, Iterable, List, TypeVar

from soundevent import data
from torch.utils.data import Dataset

__all__ = [
    "ClipDataset",
]


E = TypeVar("E")


class ClipDataset(Dataset, Generic[E]):
    clips: List[data.Clip]

    transform: Callable[[data.Clip], E]

    def __init__(
        self,
        clips: Iterable[data.Clip],
        transform: Callable[[data.Clip], E],
        name: str = "ClipDataset",
    ):
        self.clips = list(clips)
        self.transform = transform
        self.name = name

    def __len__(self) -> int:
        return len(self.clips)

    def __getitem__(self, idx: int) -> E:
        return self.transform(self.clips[idx])
(deleted file)
@@ -1,392 +0,0 @@
|
||||
"""Module containing functions for preprocessing audio clips."""
|
||||
|
||||
from typing import Optional, Union
|
||||
from pathlib import Path
|
||||
|
||||
import librosa
|
||||
import librosa.core.spectrum
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from numpy.typing import DTypeLike
|
||||
from pydantic import BaseModel, Field
|
||||
from scipy.signal import resample_poly
|
||||
from soundevent import audio, data, arrays
|
||||
from soundevent.arrays import operations as ops
|
||||
|
||||
__all__ = [
|
||||
"PreprocessingConfig",
|
||||
"preprocess_audio_clip",
|
||||
]
|
||||
|
||||
|
||||
TARGET_SAMPLERATE_HZ = 256000
|
||||
SCALE_RAW_AUDIO = False
|
||||
FFT_WIN_LENGTH_S = 512 / 256000.0
|
||||
FFT_OVERLAP = 0.75
|
||||
MAX_FREQ_HZ = 120000
|
||||
MIN_FREQ_HZ = 10000
|
||||
DEFAULT_DURATION = 1
|
||||
SPEC_HEIGHT = 128
|
||||
SPEC_WIDTH = 256
|
||||
SPEC_SCALE = "pcen"
|
||||
SPEC_TIME_PERIOD = DEFAULT_DURATION / SPEC_WIDTH
|
||||
DENOISE_SPEC_AVG = True
|
||||
MAX_SCALE_SPEC = False
|
||||
|
||||
|
||||
class PreprocessingConfig(BaseModel):
|
||||
"""Configuration for preprocessing data."""
|
||||
|
||||
target_samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
|
||||
|
||||
scale_audio: bool = Field(default=SCALE_RAW_AUDIO)
|
||||
|
||||
fft_win_length: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
|
||||
|
||||
fft_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
|
||||
|
||||
max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
|
||||
|
||||
min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
|
||||
|
||||
spec_scale: str = Field(default=SPEC_SCALE)
|
||||
|
||||
denoise_spec_avg: bool = DENOISE_SPEC_AVG
|
||||
|
||||
max_scale_spec: bool = MAX_SCALE_SPEC
|
||||
|
||||
duration: Optional[float] = DEFAULT_DURATION
|
||||
|
||||
spec_height: int = SPEC_HEIGHT
|
||||
|
||||
spec_time_period: float = SPEC_TIME_PERIOD
|
||||
|
||||
@classmethod
|
||||
def from_file(
|
||||
cls,
|
||||
path: Union[str, Path],
|
||||
) -> "PreprocessingConfig":
|
||||
"""Load configuration from a file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path
|
||||
Path to the configuration file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
PreprocessingConfig
|
||||
The configuration loaded from the file.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the configuration file does not exist.
|
||||
pydantic.ValidationError
|
||||
If the configuration file is invalid.
|
||||
"""
|
||||
path = Path(path)
|
||||
|
||||
if not path.is_file():
|
||||
raise FileNotFoundError(f"Config file not found: {path}")
|
||||
|
||||
return cls.model_validate_json(path.read_text())
|
||||
|
||||
def to_file(self, path: Union[str, Path]) -> None:
|
||||
"""Save configuration to a file."""
|
||||
path = Path(path)
|
||||
|
||||
if not path.parent.exists():
|
||||
path.parent.mkdir(parents=True)
|
||||
|
||||
path.write_text(self.model_dump_json())
|
||||
|
||||
|
||||
def preprocess_audio_clip(
|
||||
clip: data.Clip,
|
||||
config: PreprocessingConfig = PreprocessingConfig(),
|
||||
) -> xr.DataArray:
|
||||
"""Preprocesses audio clip to generate spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
clip
|
||||
The audio clip to preprocess.
|
||||
config
|
||||
Configuration for preprocessing.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
Preprocessed spectrogram.
|
||||
|
||||
"""
|
||||
wav = load_clip_audio(
|
||||
clip,
|
||||
target_sampling_rate=config.target_samplerate,
|
||||
scale=config.scale_audio,
|
||||
)
|
||||
|
||||
spec = compute_spectrogram(
|
||||
wav,
|
||||
fft_win_length=config.fft_win_length,
|
||||
fft_overlap=config.fft_overlap,
|
||||
max_freq=config.max_freq,
|
||||
min_freq=config.min_freq,
|
||||
spec_scale=config.spec_scale,
|
||||
denoise_spec_avg=config.denoise_spec_avg,
|
||||
max_scale_spec=config.max_scale_spec,
|
||||
)
|
||||
|
||||
if config.duration is not None:
|
||||
spec = adjust_spec_duration(clip, spec, config.duration)
|
||||
|
||||
duration = arrays.get_dim_width(spec, dim="time")
|
||||
return ops.resize(
|
||||
spec,
|
||||
time=int(np.ceil(duration / config.spec_time_period)),
|
||||
frequency=config.spec_height,
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
|
||||
def adjust_spec_duration(
|
||||
clip: data.Clip,
|
||||
spec: xr.DataArray,
|
||||
duration: float,
|
||||
) -> xr.DataArray:
|
||||
current_duration = clip.end_time - clip.start_time
|
||||
|
||||
if current_duration == duration:
|
||||
return spec
|
||||
|
||||
if current_duration > duration:
|
||||
return arrays.crop_dim(
|
||||
spec,
|
||||
dim="time",
|
||||
start=clip.start_time,
|
||||
stop=clip.start_time + duration,
|
||||
)
|
||||
|
||||
return arrays.extend_dim(
|
||||
spec,
|
||||
dim="time",
|
||||
start=clip.start_time,
|
||||
stop=clip.start_time + duration,
|
||||
)
|
||||
|
||||
|
||||
def load_clip_audio(
|
||||
clip: data.Clip,
|
||||
target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
|
||||
scale: bool = SCALE_RAW_AUDIO,
|
||||
dtype: DTypeLike = np.float32,
|
||||
) -> xr.DataArray:
|
||||
wav = audio.load_clip(clip).sel(channel=0).astype(dtype)
|
||||
|
||||
wav = resample_audio(wav, target_sampling_rate, dtype=dtype)
|
||||
|
||||
if scale:
|
||||
wav = ops.center(wav)
|
||||
wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))
|
||||
|
||||
return wav.astype(dtype)
|
||||
|
||||
|
||||
def resample_audio(
|
||||
wav: xr.DataArray,
|
||||
target_samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||
dtype: DTypeLike = np.float32,
|
||||
) -> xr.DataArray:
|
||||
if "time" not in wav.dims:
|
||||
raise ValueError("Audio must have a time dimension")
|
||||
|
||||
time_axis: int = wav.get_axis_num("time") # type: ignore
|
||||
|
||||
start, stop = arrays.get_dim_range(wav, dim="time")
|
||||
step = arrays.get_dim_step(wav, dim="time")
|
||||
original_samplerate = int(1 / step)
|
||||
|
||||
if original_samplerate == target_samplerate:
|
||||
return wav.astype(dtype)
|
||||
|
||||
gcd = np.gcd(original_samplerate, target_samplerate)
|
||||
resampled = resample_poly(
|
||||
wav.values,
|
||||
target_samplerate // gcd,
|
||||
original_samplerate // gcd,
|
||||
axis=time_axis,
|
||||
)
|
||||
|
||||
resampled_times = np.linspace(
|
||||
start,
|
||||
stop + step,
|
||||
len(resampled),
|
||||
endpoint=False,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return xr.DataArray(
|
||||
data=resampled.astype(dtype),
|
||||
dims=wav.dims,
|
||||
coords={
|
||||
**wav.coords,
|
||||
"time": arrays.create_time_dim_from_array(
|
||||
resampled_times,
|
||||
samplerate=target_samplerate,
|
||||
),
|
||||
},
|
||||
attrs=wav.attrs,
|
||||
)
|
||||
|
||||
def compute_spectrogram(
    wav: xr.DataArray,
    fft_win_length: float = FFT_WIN_LENGTH_S,
    fft_overlap: float = FFT_OVERLAP,
    max_freq: int = MAX_FREQ_HZ,
    min_freq: int = MIN_FREQ_HZ,
    spec_scale: str = SPEC_SCALE,
    denoise_spec_avg: bool = True,
    max_scale_spec: bool = False,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    spec = gen_mag_spectrogram(
        wav,
        window_len=fft_win_length,
        overlap_perc=fft_overlap,
        dtype=dtype,
    )

    spec = arrays.crop_dim(
        spec,
        dim="frequency",
        start=min_freq,
        stop=max_freq,
    ).astype(dtype)

    spec = scale_spectrogram(spec, scale=spec_scale)

    if denoise_spec_avg:
        spec = denoise_spectrogram(spec)

    if max_scale_spec:
        spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))

    return spec.astype(dtype)


def gen_mag_spectrogram(
    wave: xr.DataArray,
    window_len: float,
    overlap_perc: float,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    start_time, end_time = arrays.get_dim_range(wave, dim="time")
    step = arrays.get_dim_step(wave, dim="time")
    sampling_rate = 1 / step

    hop_len = window_len * (1 - overlap_perc)
    nfft = int(window_len * sampling_rate)
    noverlap = int(overlap_perc * nfft)

    # compute spec
    spec, _ = librosa.core.spectrum._spectrogram(
        y=wave.data,
        power=1,
        n_fft=nfft,
        hop_length=nfft - noverlap,
        center=False,
    )

    return xr.DataArray(
        data=spec.astype(dtype),
        dims=["frequency", "time"],
        coords={
            "frequency": arrays.create_frequency_dim_from_array(
                np.linspace(
                    0,
                    sampling_rate / 2,
                    spec.shape[0],
                    endpoint=False,
                    dtype=dtype,
                ),
                step=sampling_rate / nfft,
            ),
            "time": arrays.create_time_dim_from_array(
                np.linspace(
                    start_time,
                    end_time - (window_len - hop_len),
                    spec.shape[1],
                    endpoint=False,
                    dtype=dtype,
                ),
                step=hop_len,
            ),
        },
        attrs={
            **wave.attrs,
            "original_samplerate": sampling_rate,
            "nfft": nfft,
            "noverlap": noverlap,
        },
    )


def denoise_spectrogram(
    spec: xr.DataArray,
) -> xr.DataArray:
    return xr.DataArray(
        data=(spec - spec.mean("time")).clip(0),
        dims=spec.dims,
        coords=spec.coords,
        attrs=spec.attrs,
    )


def scale_spectrogram(
    spec: xr.DataArray,
    scale: str = SPEC_SCALE,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    samplerate = spec.attrs["original_samplerate"]

    if scale == "pcen":
        smoothing_constant = get_pcen_smoothing_constant(samplerate / 10)
        return audio.pcen(
            spec * (2**31),
            smooth=smoothing_constant,
        ).astype(dtype)

    if scale == "log":
        return log_scale(spec, dtype=dtype)

    return spec


def log_scale(
    spec: xr.DataArray,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    samplerate = spec.attrs["original_samplerate"]
    nfft = spec.attrs["nfft"]
    log_scaling = (
        2.0
        * (1.0 / samplerate)
        * (1.0 / (np.abs(np.hanning(nfft)) ** 2).sum())
    )
    return xr.DataArray(
        data=np.log1p(log_scaling * spec).astype(dtype),
        dims=spec.dims,
        coords=spec.coords,
        attrs=spec.attrs,
    )


def get_pcen_smoothing_constant(
    sr: int,
    time_constant: float = 0.4,
    hop_length: int = 512,
) -> float:
    t_frames = time_constant * sr / float(hop_length)
    return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
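A minimal usage sketch of the spectrogram functions above (illustrative values; assumes `wav` is already an `xr.DataArray` waveform with a `time` coordinate and that the module-level constants are in scope). The PCEN smoothing constant follows the usual (sqrt(1 + 4*T**2) - 1) / (2*T**2) form, with T the time constant expressed in frames.

# Hypothetical example, not part of the diff: band-limited, PCEN-scaled
# spectrogram of a 256 kHz recording loaded as an xr.DataArray.
spec = compute_spectrogram(
    wav,
    fft_win_length=512 / 256_000,  # ~2 ms window, expressed in seconds
    fft_overlap=0.75,
    min_freq=10_000,
    max_freq=120_000,
    spec_scale="pcen",
)
print(spec.dims, spec.shape)  # ("frequency", "time"), cropped to 10-120 kHz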
batdetect2/data/types.py (new file, 29 lines)
@@ -0,0 +1,29 @@
from typing import Annotated, List

from pydantic import Field

from batdetect2.configs import BaseConfig
from batdetect2.data.annotations import AnnotationFormats


class Dataset(BaseConfig):
    """Represents a collection of one or more DatasetSources.

    In the context of batdetect2, a Dataset aggregates multiple `DatasetSource`
    instances. It serves as the primary unit for defining data splits,
    typically used for model training, validation, or testing phases.

    Attributes:
        name: A descriptive name for the overall dataset
            (e.g., "UK Training Set").
        description: A detailed explanation of the dataset's purpose,
            composition, how it was assembled, or any specific characteristics.
        sources: A list containing the `DatasetSource` objects included in this
            dataset.
    """

    name: str
    description: str
    sources: List[
        Annotated[AnnotationFormats, Field(..., discriminator="format")]
    ]
@@ -1,4 +1,5 @@
"""Functions to compute features from predictions."""

from typing import Dict, List, Optional

import numpy as np
@@ -219,7 +220,6 @@ def compute_call_interval(
    return round(prediction["start_time"] - previous["end_time"], 5)



# NOTE: The order of the features in this dictionary is important. The
# features are extracted in this order and the order of the columns in the
# output csv file is determined by this order. In order to avoid breaking
@@ -206,7 +206,10 @@ class Net2DFastNoAttn(nn.Module):
            num_filts // 4, 2, kernel_size=1, padding=0
        )
        self.conv_classes_op = nn.Conv2d(
            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0,
            num_filts // 4,
            self.num_classes + 1,
            kernel_size=1,
            padding=0,
        )

        if self.emb_dim > 0:
@@ -5,7 +5,10 @@ from typing import List, Optional, Union

from pydantic import BaseModel, Field, computed_field

from batdetect2.train.train_utils import get_genus_mapping, get_short_class_names
from batdetect2.train.legacy.train_utils import (
    get_genus_mapping,
    get_short_class_names,
)
from batdetect2.types import ProcessingConfiguration, SpectrogramParameters

TARGET_SAMPLERATE_HZ = 256000
@@ -1,4 +1,5 @@
"""Post-processing of the output of the model."""

from typing import List, Tuple, Union

import numpy as np
@@ -1,16 +1,66 @@
from typing import List

import numpy as np
from sklearn.metrics import (
    accuracy_score,
    auc,
    balanced_accuracy_score,
    roc_curve,
)
import pandas as pd
from sklearn.metrics import auc, roc_curve
from soundevent import data
from soundevent.evaluation import match_geometries

from batdetect2.train.targets import build_encoder, get_class_names


def match_predictions_and_annotations(
    clip_annotation: data.ClipAnnotation,
    clip_prediction: data.ClipPrediction,
) -> List[data.Match]:
    annotated_sound_events = [
        sound_event_annotation
        for sound_event_annotation in clip_annotation.sound_events
        if sound_event_annotation.sound_event.geometry is not None
    ]

    predicted_sound_events = [
        sound_event_prediction
        for sound_event_prediction in clip_prediction.sound_events
        if sound_event_prediction.sound_event.geometry is not None
    ]

    annotated_geometries: List[data.Geometry] = [
        sound_event.sound_event.geometry
        for sound_event in annotated_sound_events
        if sound_event.sound_event.geometry is not None
    ]

    predicted_geometries: List[data.Geometry] = [
        sound_event.sound_event.geometry
        for sound_event in predicted_sound_events
        if sound_event.sound_event.geometry is not None
    ]

    matches = []
    for id1, id2, affinity in match_geometries(
        annotated_geometries,
        predicted_geometries,
    ):
        target = annotated_sound_events[id1] if id1 is not None else None
        source = predicted_sound_events[id2] if id2 is not None else None
        matches.append(
            data.Match(source=source, target=target, affinity=affinity)
        )

    return matches


def build_evaluation_dataframe(matches: List[data.Match]) -> pd.DataFrame:
    ret = []

    for match in matches:
        pass


def compute_error_auc(op_str, gt, pred, prob):

    # classification error
    pred_int = (pred > prob).astype(np.int)
    pred_int = (pred > prob).astype(np.int32)
    class_acc = (pred_int == gt).mean() * 100.0

    # ROC - area under curve
@@ -25,7 +75,6 @@ def compute_error_auc(op_str, gt, pred, prob):


def calc_average_precision(recall, precision):

    precision[np.isnan(precision)] = 0
    recall[np.isnan(recall)] = 0

@@ -91,7 +140,6 @@ def compute_pre_rec(
    pred_class = []
    file_ids = []
    for pid, pp in enumerate(preds):

        # filter predicted calls that are too near the start or end of the file
        file_dur = gts[pid]["duration"]
        valid_inds = (pp["start_times"] >= ignore_start_end) & (
@@ -141,7 +189,6 @@ def compute_pre_rec(
    gt_generic_class = []
    num_positives = 0
    for gg in gts:

        # filter ground truth calls that are too near the start or end of the file
        file_dur = gg["duration"]
        valid_inds = (gg["start_times"] >= ignore_start_end) & (
@@ -205,7 +252,6 @@ def compute_pre_rec(

        # valid detection that has not already been assigned
        if valid_det and (gt_assigned[gt_id][det_ind] == 0):

            count_as_true_pos = True
            if eval_mode == "top_class" and (
                gt_class[gt_id][det_ind] != pred_class[ind]
@@ -12,15 +12,14 @@ import pandas as pd
import torch
from sklearn.ensemble import RandomForestClassifier

import batdetect2.train.evaluate as evl
import batdetect2.train.train_utils as tu
import batdetect2.evaluate.legacy.evaluate_models as evl
import batdetect2.train.legacy.train_utils as tu
import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as pu
from batdetect2.detector import parameters


def get_blank_annotation(ip_str):

    res = {}
    res["class_name"] = ""
    res["duration"] = -1
@@ -77,7 +76,6 @@ def create_genus_mapping(gt_test, preds, class_names):


def load_tadarida_pred(ip_dir, dataset, file_of_interest):

    res, ann = get_blank_annotation("Generated by Tadarida")

    # create the annotations in the correct format
@@ -120,7 +118,6 @@ def load_sonobat_meta(
    class_names,
    only_accepted_species=True,
):

    sp_dict = {}
    for ss in class_names:
        sp_key = ss.split(" ")[0][:3] + ss.split(" ")[1][:3]
@@ -182,7 +179,6 @@ def load_sonobat_meta(


def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):

    # create the annotations in the correct format
    res, ann = get_blank_annotation("Generated by Sonobat")
    res_c = copy.deepcopy(res)
@@ -221,7 +217,6 @@ def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):


def bb_overlap(bb_g_in, bb_p_in):

    freq_scale = 10000000.0  # ensure that both axis are roughly the same range
    bb_g = [
        bb_g_in["start_time"],
@@ -465,7 +460,6 @@ def check_classes_in_train(gt_list, class_names):


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "op_dir",
@@ -8,10 +8,10 @@ import torch.utils.data
from torch.optim.lr_scheduler import CosineAnnealingLR

import batdetect2.detector.parameters as parameters
import batdetect2.train.audio_dataloader as adl
import batdetect2.train.legacy.audio_dataloader as adl
import batdetect2.train.legacy.train_model as tm
import batdetect2.train.legacy.train_utils as tu
import batdetect2.train.losses as losses
import batdetect2.train.train_model as tm
import batdetect2.train.train_utils as tu
import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as pu
from batdetect2 import types
@@ -1,11 +1,92 @@
from batdetect2.models.feature_extractors import (
from enum import Enum
from typing import Optional, Tuple

from soundevent.data import PathLike

from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.backbones import (
    Net2DFast,
    Net2DFastNoAttn,
    Net2DFastNoCoordConv,
    Net2DPlain,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.models.typing import BackboneModel

__all__ = [
    "BBoxHead",
    "ClassifierHead",
    "ModelConfig",
    "ModelType",
    "Net2DFast",
    "Net2DFastNoAttn",
    "Net2DFastNoCoordConv",
    "build_architecture",
    "load_model_config",
]


class ModelType(str, Enum):
    Net2DFast = "Net2DFast"
    Net2DFastNoAttn = "Net2DFastNoAttn"
    Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
    Net2DPlain = "Net2DPlain"


class ModelConfig(BaseConfig):
    name: ModelType = ModelType.Net2DFast
    input_height: int = 128
    encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
    bottleneck_channels: int = 256
    decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
    out_channels: int = 32


def load_model_config(
    path: PathLike, field: Optional[str] = None
) -> ModelConfig:
    return load_config(path, schema=ModelConfig, field=field)


def build_architecture(
    config: Optional[ModelConfig] = None,
) -> BackboneModel:
    config = config or ModelConfig()

    if config.name == ModelType.Net2DFast:
        return Net2DFast(
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )

    if config.name == ModelType.Net2DFastNoAttn:
        return Net2DFastNoAttn(
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )

    if config.name == ModelType.Net2DFastNoCoordConv:
        return Net2DFastNoCoordConv(
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )

    if config.name == ModelType.Net2DPlain:
        return Net2DPlain(
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )

    raise ValueError(f"Unknown model type: {config.name}")
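A brief sketch of the configuration-driven factory introduced above (the YAML path and field name are illustrative):

# Hypothetical example, not part of the diff.
config = ModelConfig(name=ModelType.Net2DFastNoAttn, input_height=128)
backbone = build_architecture(config)

# or, reading the same settings from a file:
# config = load_model_config("config.yaml", field="architecture")
# backbone = build_architecture(config)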
batdetect2/models/backbones.py (new file, 185 lines)
@@ -0,0 +1,185 @@
from typing import Sequence, Tuple

import torch
import torch.nn.functional as F

from batdetect2.models.blocks import (
    ConvBlock,
    Decoder,
    DownscalingLayer,
    Encoder,
    SelfAttention,
    UpscalingLayer,
    VerticalConv,
)
from batdetect2.models.typing import BackboneModel

__all__ = [
    "Net2DFast",
    "Net2DFastNoAttn",
    "Net2DFastNoCoordConv",
]


class Net2DPlain(BackboneModel):
    downscaling_layer_type: DownscalingLayer = "ConvBlockDownStandard"
    upscaling_layer_type: UpscalingLayer = "ConvBlockUpStandard"

    def __init__(
        self,
        input_height: int = 128,
        encoder_channels: Sequence[int] = (1, 32, 64, 128),
        bottleneck_channels: int = 256,
        decoder_channels: Sequence[int] = (256, 64, 32, 32),
        out_channels: int = 32,
    ):
        super().__init__()

        self.input_height = input_height
        self.encoder_channels = tuple(encoder_channels)
        self.decoder_channels = tuple(decoder_channels)
        self.out_channels = out_channels

        if len(encoder_channels) != len(decoder_channels):
            raise ValueError(
                f"Mismatched encoder and decoder channel lists. "
                f"The encoder has {len(encoder_channels)} channels "
                f"(implying {len(encoder_channels) - 1} layers), "
                f"while the decoder has {len(decoder_channels)} channels "
                f"(implying {len(decoder_channels) - 1} layers). "
                f"These lengths must be equal."
            )

        self.divide_factor = 2 ** (len(encoder_channels) - 1)
        if self.input_height % self.divide_factor != 0:
            raise ValueError(
                f"Input height ({self.input_height}) must be divisible by "
                f"the divide factor ({self.divide_factor}). "
                f"This ensures proper upscaling after downscaling to recover "
                f"the original input height."
            )

        self.encoder = Encoder(
            channels=encoder_channels,
            input_height=self.input_height,
            layer_type=self.downscaling_layer_type,
        )

        self.conv_same_1 = ConvBlock(
            in_channels=encoder_channels[-1],
            out_channels=bottleneck_channels,
        )

        # bottleneck
        self.conv_vert = VerticalConv(
            in_channels=bottleneck_channels,
            out_channels=bottleneck_channels,
            input_height=self.input_height // (2**self.encoder.depth),
        )

        self.decoder = Decoder(
            channels=decoder_channels,
            input_height=self.input_height,
            layer_type=self.upscaling_layer_type,
        )

        self.conv_same_2 = ConvBlock(
            in_channels=decoder_channels[-1],
            out_channels=out_channels,
        )

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)

        # encoder
        residuals = self.encoder(spec)
        residuals[-1] = self.conv_same_1(residuals[-1])

        # bottleneck
        x = self.conv_vert(residuals[-1])
        x = x.repeat([1, 1, residuals[-1].shape[-2], 1])

        # decoder
        x = self.decoder(x, residuals=residuals)

        # Restore original size
        x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)

        return self.conv_same_2(x)


class Net2DFast(Net2DPlain):
    downscaling_layer_type = "ConvBlockDownCoordF"
    upscaling_layer_type = "ConvBlockUpF"

    def __init__(
        self,
        input_height: int = 128,
        encoder_channels: Sequence[int] = (1, 32, 64, 128),
        bottleneck_channels: int = 256,
        decoder_channels: Sequence[int] = (256, 64, 32, 32),
        out_channels: int = 32,
    ):
        super().__init__(
            input_height=input_height,
            encoder_channels=encoder_channels,
            bottleneck_channels=bottleneck_channels,
            decoder_channels=decoder_channels,
            out_channels=out_channels,
        )

        self.att = SelfAttention(bottleneck_channels, bottleneck_channels)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)

        # encoder
        residuals = self.encoder(spec)
        residuals[-1] = self.conv_same_1(residuals[-1])

        # bottleneck
        x = self.conv_vert(residuals[-1])
        x = self.att(x)
        x = x.repeat([1, 1, residuals[-1].shape[-2], 1])

        # decoder
        x = self.decoder(x, residuals=residuals)

        # Restore original size
        x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)

        return self.conv_same_2(x)


class Net2DFastNoAttn(Net2DPlain):
    downscaling_layer_type = "ConvBlockDownCoordF"
    upscaling_layer_type = "ConvBlockUpF"


class Net2DFastNoCoordConv(Net2DFast):
    downscaling_layer_type = "ConvBlockDownStandard"
    upscaling_layer_type = "ConvBlockUpStandard"


def pad_adjust(
    spec: torch.Tensor,
    factor: int = 32,
) -> Tuple[torch.Tensor, int, int]:
    print(spec.shape)
    h, w = spec.shape[2:]
    h_pad = -h % factor
    w_pad = -w % factor
    return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad


def restore_pad(
    x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
) -> torch.Tensor:
    # Restore original size
    if h_pad > 0:
        x = x[:, :, :-h_pad, :]

    if w_pad > 0:
        x = x[:, :, :, :-w_pad]

    return x
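A quick shape check for the new backbone (illustrative sizes): inputs are padded up to a multiple of `divide_factor`, run through the encoder, bottleneck, and decoder, then cropped back, so the output keeps the input height and width with `out_channels` feature maps.

# Hypothetical example, not part of the diff.
import torch

model = Net2DFast(input_height=128)
features = model(torch.randn(1, 1, 128, 512))
print(features.shape)  # torch.Size([1, 32, 128, 512])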
@@ -4,18 +4,32 @@ All these classes are subclasses of `torch.nn.Module` and can be used to build
complex neural network architectures.
"""

from typing import Tuple
import sys
from typing import Iterable, List, Literal, Sequence, Tuple

import torch
import torch.nn.functional as F
from torch import nn

if sys.version_info >= (3, 10):
    from itertools import pairwise
else:

    def pairwise(iterable: Sequence) -> Iterable:
        for x, y in zip(iterable[:-1], iterable[1:]):
            yield x, y


__all__ = [
    "SelfAttention",
    "ConvBlock",
    "ConvBlockDownCoordF",
    "ConvBlockDownStandard",
    "ConvBlockUpF",
    "ConvBlockUpStandard",
    "SelfAttention",
    "VerticalConv",
    "DownscalingLayer",
    "UpscalingLayer",
]


@@ -25,16 +39,21 @@ class SelfAttention(nn.Module):
    This module implements self-attention mechanism.
    """

    def __init__(self, ip_dim: int, att_dim: int):
    def __init__(
        self,
        in_channels: int,
        attention_channels: int,
        temperature: float = 1.0,
    ):
        super().__init__()

        # Note, does not encode position information (absolute or realtive)
        self.temperature = 1.0
        self.att_dim = att_dim
        self.key_fun = nn.Linear(ip_dim, att_dim)
        self.val_fun = nn.Linear(ip_dim, att_dim)
        self.que_fun = nn.Linear(ip_dim, att_dim)
        self.pro_fun = nn.Linear(att_dim, ip_dim)
        # Note, does not encode position information (absolute or relative)
        self.temperature = temperature
        self.att_dim = attention_channels
        self.key_fun = nn.Linear(in_channels, attention_channels)
        self.value_fun = nn.Linear(in_channels, attention_channels)
        self.query_fun = nn.Linear(in_channels, attention_channels)
        self.pro_fun = nn.Linear(attention_channels, in_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.squeeze(2).permute(0, 2, 1)
@@ -43,11 +62,11 @@ class SelfAttention(nn.Module):
            x, self.key_fun.weight.T
        ) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
        query = torch.matmul(
            x, self.que_fun.weight.T
        ) + self.que_fun.bias.unsqueeze(0).unsqueeze(0)
            x, self.query_fun.weight.T
        ) + self.query_fun.bias.unsqueeze(0).unsqueeze(0)
        value = torch.matmul(
            x, self.val_fun.weight.T
        ) + self.val_fun.bias.unsqueeze(0).unsqueeze(0)
            x, self.value_fun.weight.T
        ) + self.value_fun.bias.unsqueeze(0).unsqueeze(0)

        kk_qq = torch.bmm(key, query.permute(0, 2, 1)) / (
            self.temperature * self.att_dim
@@ -63,6 +82,66 @@ class SelfAttention(nn.Module):
        return op


class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        pad_size: int = 1,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=pad_size,
        )
        self.conv_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu_(self.conv_bn(self.conv(x)))


class VerticalConv(nn.Module):
    """Convolutional layer over full height.

    This layer applies a convolution that captures information across the
    entire height of the input image. It uses a kernel with the same height as
    the input, effectively condensing the vertical information into a single
    output row.

    More specifically:

    * **Input:** (B, C, H, W) where B is the batch size, C is the number of
        input channels, H is the image height, and W is the image width.
    * **Kernel:** (C', H, 1) where C' is the number of output channels.
    * **Output:** (B, C', 1, W) - The height dimension is 1 because the
        convolution integrates information from all rows of the input.

    This process effectively extracts features that span the full height of
    the input image.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        input_height: int,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(input_height, 1),
            padding=0,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu_(self.bn(self.conv(x)))


class ConvBlockDownCoordF(nn.Module):
    """Convolutional Block with Downsampling and Coord Feature.

@@ -72,27 +151,27 @@ class ConvBlockDownCoordF(nn.Module):

    def __init__(
        self,
        in_chn: int,
        out_chn: int,
        ip_height: int,
        k_size: int = 3,
        in_channels: int,
        out_channels: int,
        input_height: int,
        kernel_size: int = 3,
        pad_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()

        self.coords = nn.Parameter(
            torch.linspace(-1, 1, ip_height)[None, None, ..., None],
            torch.linspace(-1, 1, input_height)[None, None, ..., None],
            requires_grad=False,
        )
        self.conv = nn.Conv2d(
            in_chn + 1,
            out_chn,
            kernel_size=k_size,
            in_channels + 1,
            out_channels,
            kernel_size=kernel_size,
            padding=pad_size,
            stride=stride,
        )
        self.conv_bn = nn.BatchNorm2d(out_chn)
        self.conv_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
@@ -110,26 +189,28 @@ class ConvBlockDownStandard(nn.Module):

    def __init__(
        self,
        in_chn,
        out_chn,
        k_size=3,
        pad_size=1,
        stride=1,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        pad_size: int = 1,
        stride: int = 1,
    ):
        super(ConvBlockDownStandard, self).__init__()
        self.conv = nn.Conv2d(
            in_chn,
            out_chn,
            kernel_size=k_size,
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=pad_size,
            stride=stride,
        )
        self.conv_bn = nn.BatchNorm2d(out_chn)
        self.conv_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = F.max_pool2d(self.conv(x), 2, 2)
        x = F.relu(self.conv_bn(x), inplace=True)
        return x
        return F.relu(self.conv_bn(x), inplace=True)


DownscalingLayer = Literal["ConvBlockDownStandard", "ConvBlockDownCoordF"]


class ConvBlockUpF(nn.Module):
@@ -141,10 +222,10 @@ class ConvBlockUpF(nn.Module):

    def __init__(
        self,
        in_chn: int,
        out_chn: int,
        ip_height: int,
        k_size: int = 3,
        in_channels: int,
        out_channels: int,
        input_height: int,
        kernel_size: int = 3,
        pad_size: int = 1,
        up_mode: str = "bilinear",
        up_scale: Tuple[int, int] = (2, 2),
@@ -154,15 +235,18 @@ class ConvBlockUpF(nn.Module):
        self.up_scale = up_scale
        self.up_mode = up_mode
        self.coords = nn.Parameter(
            torch.linspace(-1, 1, ip_height * up_scale[0])[
            torch.linspace(-1, 1, input_height * up_scale[0])[
                None, None, ..., None
            ],
            requires_grad=False,
        )
        self.conv = nn.Conv2d(
            in_chn + 1, out_chn, kernel_size=k_size, padding=pad_size
            in_channels + 1,
            out_channels,
            kernel_size=kernel_size,
            padding=pad_size,
        )
        self.conv_bn = nn.BatchNorm2d(out_chn)
        self.conv_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        op = F.interpolate(
@@ -189,9 +273,9 @@ class ConvBlockUpStandard(nn.Module):

    def __init__(
        self,
        in_chn: int,
        out_chn: int,
        k_size: int = 3,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        pad_size: int = 1,
        up_mode: str = "bilinear",
        up_scale: Tuple[int, int] = (2, 2),
@@ -200,9 +284,12 @@ class ConvBlockUpStandard(nn.Module):
        self.up_scale = up_scale
        self.up_mode = up_mode
        self.conv = nn.Conv2d(
            in_chn, out_chn, kernel_size=k_size, padding=pad_size
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=pad_size,
        )
        self.conv_bn = nn.BatchNorm2d(out_chn)
        self.conv_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        op = F.interpolate(
@@ -217,3 +304,143 @@ class ConvBlockUpStandard(nn.Module):
        op = self.conv(op)
        op = F.relu(self.conv_bn(op), inplace=True)
        return op


UpscalingLayer = Literal["ConvBlockUpStandard", "ConvBlockUpF"]


def build_downscaling_layer(
    in_channels: int,
    out_channels: int,
    input_height: int,
    layer_type: DownscalingLayer,
) -> nn.Module:
    if layer_type == "ConvBlockDownStandard":
        return ConvBlockDownStandard(
            in_channels=in_channels,
            out_channels=out_channels,
        )

    if layer_type == "ConvBlockDownCoordF":
        return ConvBlockDownCoordF(
            in_channels=in_channels,
            out_channels=out_channels,
            input_height=input_height,
        )

    raise ValueError(
        f"Invalid downscaling layer type {layer_type}. "
        f"Valid values: ConvBlockDownCoordF, ConvBlockDownStandard"
    )


class Encoder(nn.Module):
    def __init__(
        self,
        channels: Sequence[int] = (1, 32, 62, 128),
        input_height: int = 128,
        layer_type: Literal[
            "ConvBlockDownStandard", "ConvBlockDownCoordF"
        ] = "ConvBlockDownStandard",
    ):
        super().__init__()

        self.channels = channels
        self.input_height = input_height

        self.layers = nn.ModuleList(
            [
                build_downscaling_layer(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    input_height=input_height // (2**layer_num),
                    layer_type=layer_type,
                )
                for layer_num, (in_channels, out_channels) in enumerate(
                    pairwise(channels)
                )
            ]
        )
        self.depth = len(self.layers)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        outputs = []

        for layer in self.layers:
            x = layer(x)
            outputs.append(x)

        return outputs


def build_upscaling_layer(
    in_channels: int,
    out_channels: int,
    input_height: int,
    layer_type: UpscalingLayer,
) -> nn.Module:
    if layer_type == "ConvBlockUpStandard":
        return ConvBlockUpStandard(
            in_channels=in_channels,
            out_channels=out_channels,
        )

    if layer_type == "ConvBlockUpF":
        return ConvBlockUpF(
            in_channels=in_channels,
            out_channels=out_channels,
            input_height=input_height,
        )

    raise ValueError(
        f"Invalid upscaling layer type {layer_type}. "
        f"Valid values: ConvBlockUpStandard, ConvBlockUpF"
    )


class Decoder(nn.Module):
    def __init__(
        self,
        channels: Sequence[int] = (256, 62, 32, 32),
        input_height: int = 128,
        layer_type: Literal[
            "ConvBlockUpStandard", "ConvBlockUpF"
        ] = "ConvBlockUpStandard",
    ):
        super().__init__()

        self.channels = channels
        self.input_height = input_height
        self.depth = len(self.channels) - 1

        self.layers = nn.ModuleList(
            [
                build_upscaling_layer(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    input_height=input_height
                    // (2 ** (self.depth - layer_num)),
                    layer_type=layer_type,
                )
                for layer_num, (in_channels, out_channels) in enumerate(
                    pairwise(channels)
                )
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        residuals: List[torch.Tensor],
    ) -> torch.Tensor:
        if len(residuals) != len(self.layers):
            raise ValueError(
                f"Incorrect number of residuals provided. "
                f"Expected {len(self.layers)} (matching the number of layers), "
                f"but got {len(residuals)}."
            )

        for layer, res in zip(self.layers, residuals[::-1]):
            x = layer(x + res)

        return x
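A small sketch of the `Encoder` introduced above (illustrative sizes): each downscaling layer halves the spatial resolution, and the per-level outputs are returned so the backbone can reuse them as residuals for the `Decoder`.

# Hypothetical example, not part of the diff.
import torch

encoder = Encoder(channels=(1, 32, 64, 128), input_height=128)
levels = encoder(torch.randn(1, 1, 128, 512))
print([tuple(level.shape) for level in levels])
# [(1, 32, 64, 256), (1, 64, 32, 128), (1, 128, 16, 64)]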
batdetect2/models/decoder.py (new file, 15 lines)
@@ -0,0 +1,15 @@
import sys
from typing import Iterable, List, Literal, Sequence

import torch
from torch import nn

from batdetect2.models.blocks import ConvBlockUpF, ConvBlockUpStandard

if sys.version_info >= (3, 10):
    from itertools import pairwise
else:

    def pairwise(iterable: Sequence) -> Iterable:
        for x, y in zip(iterable[:-1], iterable[1:]):
            yield x, y
@@ -1,139 +0,0 @@
from typing import Type

import pytorch_lightning as L
import torch
import xarray as xr
from soundevent import data
from torch import nn, optim

from batdetect2.data.preprocessing import (
    preprocess_audio_clip,
    PreprocessingConfig,
)
from batdetect2.data.labels import ClassMapper
from batdetect2.models.feature_extractors import Net2DFast
from batdetect2.models.post_process import (
    PostprocessConfig,
    postprocess_model_outputs,
)
from batdetect2.models.typing import FeatureExtractorModel, ModelOutput
from batdetect2.train import losses
from batdetect2.train.dataset import TrainExample


class DetectorModel(L.LightningModule):
    def __init__(
        self,
        class_mapper: ClassMapper,
        feature_extractor_class: Type[FeatureExtractorModel] = Net2DFast,
        learning_rate: float = 1e-3,
        input_height: int = 128,
        num_features: int = 32,
        preprocessing_config: PreprocessingConfig = PreprocessingConfig(),
        postprocessing_config: PostprocessConfig = PostprocessConfig(),
    ):
        super().__init__()

        self.save_hyperparameters()

        self.preprocessing_config = preprocessing_config
        self.postprocessing_config = postprocessing_config
        self.class_mapper = class_mapper
        self.learning_rate = learning_rate
        self.input_height = input_height
        self.num_features = num_features
        self.num_classes = class_mapper.num_classes

        self.feature_extractor = feature_extractor_class(
            input_height=input_height,
            num_features=num_features,
        )

        self.classifier = nn.Conv2d(
            self.feature_extractor.num_features // 4,
            self.num_classes + 1,
            kernel_size=1,
            padding=0,
        )

        self.bbox = nn.Conv2d(
            self.feature_extractor.num_features // 4,
            2,
            kernel_size=1,
            padding=0,
        )

    def forward(self, spec: torch.Tensor) -> ModelOutput:  # type: ignore
        features = self.feature_extractor(spec)
        classification_logits = self.classifier(features)
        classification_probs = torch.softmax(classification_logits, dim=1)
        detection_probs = classification_probs[:, :-1].sum(dim=1, keepdim=True)
        return ModelOutput(
            detection_probs=detection_probs,
            size_preds=self.bbox(features),
            class_probs=classification_probs[:, :-1],
            features=features,
        )

    def compute_spectrogram(self, clip: data.Clip) -> xr.DataArray:
        return preprocess_audio_clip(
            clip,
            config=self.preprocessing_config,
        )

    def compute_clip_features(self, clip: data.Clip) -> torch.Tensor:
        spectrogram = self.compute_spectrogram(clip)
        return self.feature_extractor(
            torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
        )

    def compute_clip_predictions(self, clip: data.Clip) -> data.ClipPrediction:
        spectrogram = self.compute_spectrogram(clip)
        spec_tensor = (
            torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
        )
        outputs = self(spec_tensor)
        return postprocess_model_outputs(
            outputs,
            [clip],
            class_mapper=self.class_mapper,
            config=self.postprocessing_config,
        )[0]

    def compute_loss(
        self,
        outputs: ModelOutput,
        batch: TrainExample,
    ) -> torch.Tensor:
        detection_loss = losses.focal_loss(
            outputs.detection_probs,
            batch.detection_heatmap,
        )

        size_loss = losses.bbox_size_loss(
            outputs.size_preds,
            batch.size_heatmap,
        )

        valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
        classification_loss = losses.focal_loss(
            outputs.class_probs,
            batch.class_heatmap,
            valid_mask=valid_mask,
        )

        return detection_loss + size_loss + classification_loss

    def training_step(  # type: ignore
        self,
        batch: TrainExample,
    ):
        outputs = self.forward(batch.spec)
        loss = self.compute_loss(outputs, batch)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
        return [optimizer], [scheduler]
batdetect2/models/encoder.py (new file, 15 lines)
@@ -0,0 +1,15 @@
import sys
from typing import Iterable, List, Literal, Sequence

import torch
from torch import nn

from batdetect2.models.blocks import ConvBlockDownCoordF, ConvBlockDownStandard

if sys.version_info >= (3, 10):
    from itertools import pairwise
else:

    def pairwise(iterable: Sequence) -> Iterable:
        for x, y in zip(iterable[:-1], iterable[1:]):
            yield x, y
@@ -1,319 +0,0 @@
from typing import Tuple

import torch
import torch.fft
import torch.nn.functional as F
from torch import nn

from batdetect2.models.blocks import (
    ConvBlockDownCoordF,
    ConvBlockDownStandard,
    ConvBlockUpF,
    ConvBlockUpStandard,
    SelfAttention,
)
from batdetect2.models.typing import FeatureExtractorModel

__all__ = [
    "Net2DFast",
    "Net2DFastNoAttn",
    "Net2DFastNoCoordConv",
]


class Net2DFast(FeatureExtractorModel):
    def __init__(
        self,
        num_features: int,
        input_height: int = 128,
    ):
        super().__init__()
        self.num_features = num_features
        self.input_height = input_height
        self.bottleneck_height = self.input_height // 32

        # encoder
        self.conv_dn_0 = ConvBlockDownCoordF(
            1,
            self.num_features // 4,
            self.input_height,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownCoordF(
            self.num_features // 4,
            self.num_features // 2,
            self.input_height // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownCoordF(
            self.num_features // 2,
            self.num_features,
            self.input_height // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(
            self.num_features,
            self.num_features * 2,
            3,
            padding=1,
        )
        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)

        # bottleneck
        self.conv_1d = nn.Conv2d(
            self.num_features * 2,
            self.num_features * 2,
            (self.input_height // 8, 1),
            padding=0,
        )
        self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
        self.att = SelfAttention(self.num_features * 2, self.num_features * 2)

        # decoder
        self.conv_up_2 = ConvBlockUpF(
            self.num_features * 2,
            self.num_features // 2,
            self.input_height // 8,
        )
        self.conv_up_3 = ConvBlockUpF(
            self.num_features // 2,
            self.num_features // 4,
            self.input_height // 4,
        )
        self.conv_up_4 = ConvBlockUpF(
            self.num_features // 4,
            self.num_features // 4,
            self.input_height // 2,
        )

        self.conv_op = nn.Conv2d(
            self.num_features // 4,
            self.num_features // 4,
            kernel_size=3,
            padding=1,
        )
        self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)

    def pad_adjust(self, spec: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        h, w = spec.shape[2:]
        h_pad = (32 - h % 32) % 32
        w_pad = (32 - w % 32) % 32
        return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # encoder
        spec, h_pad, w_pad = self.pad_adjust(spec)

        x1 = self.conv_dn_0(spec)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))

        # bottleneck
        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
        x = self.att(x)
        x = x.repeat([1, 1, self.bottleneck_height * 4, 1])

        # decoder
        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        # Restore original size
        if h_pad > 0:
            x = x[:, :, :-h_pad, :]

        if w_pad > 0:
            x = x[:, :, :, :-w_pad]

        return F.relu_(self.conv_op_bn(self.conv_op(x)))


class Net2DFastNoAttn(FeatureExtractorModel):
    def __init__(
        self,
        num_features: int,
        input_height: int = 128,
    ):
        super().__init__()

        self.num_features = num_features
        self.input_height = input_height
        self.bottleneck_height = self.input_height // 32

        self.conv_dn_0 = ConvBlockDownCoordF(
            1,
            self.num_features // 4,
            self.input_height,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownCoordF(
            self.num_features // 4,
            self.num_features // 2,
            self.input_height // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownCoordF(
            self.num_features // 2,
            self.num_features,
            self.input_height // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(
            self.num_features,
            self.num_features * 2,
            3,
            padding=1,
        )
        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)

        self.conv_1d = nn.Conv2d(
            self.num_features * 2,
            self.num_features * 2,
            (self.input_height // 8, 1),
            padding=0,
        )
        self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)

        self.conv_up_2 = ConvBlockUpF(
            self.num_features * 2,
            self.num_features // 2,
            self.input_height // 8,
        )
        self.conv_up_3 = ConvBlockUpF(
            self.num_features // 2,
            self.num_features // 4,
            self.input_height // 4,
        )
        self.conv_up_4 = ConvBlockUpF(
            self.num_features // 4,
            self.num_features // 4,
            self.input_height // 2,
        )

        self.conv_op = nn.Conv2d(
            self.num_features // 4,
            self.num_features // 4,
            kernel_size=3,
            padding=1,
        )
        self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        x1 = self.conv_dn_0(spec)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))

        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
        x = x.repeat([1, 1, self.bottleneck_height * 4, 1])

        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        return F.relu_(self.conv_op_bn(self.conv_op(x)))


class Net2DFastNoCoordConv(FeatureExtractorModel):
    def __init__(
        self,
        num_features: int,
        input_height: int = 128,
    ):
        super().__init__()

        self.num_features = num_features
        self.input_height = input_height
        self.bottleneck_height = self.input_height // 32

        self.conv_dn_0 = ConvBlockDownStandard(
            1,
            self.num_features // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownStandard(
            self.num_features // 4,
            self.num_features // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownStandard(
            self.num_features // 2,
            self.num_features,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(
            self.num_features,
            self.num_features * 2,
            3,
            padding=1,
        )
        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)

        self.conv_1d = nn.Conv2d(
            self.num_features * 2,
            self.num_features * 2,
            (self.input_height // 8, 1),
            padding=0,
        )
        self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)

        self.att = SelfAttention(self.num_features * 2, self.num_features * 2)

        self.conv_up_2 = ConvBlockUpStandard(
            self.num_features * 2,
            self.num_features // 2,
            self.input_height // 8,
        )
        self.conv_up_3 = ConvBlockUpStandard(
            self.num_features // 2,
            self.num_features // 4,
            self.input_height // 4,
        )
        self.conv_up_4 = ConvBlockUpStandard(
            self.num_features // 4,
            self.num_features // 4,
            self.input_height // 2,
        )

        self.conv_op = nn.Conv2d(
            self.num_features // 4,
            self.num_features // 4,
            kernel_size=3,
            padding=1,
        )
        self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        x1 = self.conv_dn_0(spec)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))

        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
        x = self.att(x)
        x = x.repeat([1, 1, self.bottleneck_height * 4, 1])

        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        return F.relu_(self.conv_op_bn(self.conv_op(x)))
batdetect2/models/heads.py (new file, 51 lines)
@@ -0,0 +1,51 @@
from typing import NamedTuple

import torch
from torch import nn

__all__ = ["ClassifierHead"]


class Output(NamedTuple):
    detection: torch.Tensor
    classification: torch.Tensor


class ClassifierHead(nn.Module):
    def __init__(self, num_classes: int, in_channels: int):
        super().__init__()

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.classifier = nn.Conv2d(
            self.in_channels,
            # Add one to account for the background class
            self.num_classes + 1,
            kernel_size=1,
            padding=0,
        )

    def forward(self, features: torch.Tensor) -> Output:
        logits = self.classifier(features)
        probs = torch.softmax(logits, dim=1)
        detection_probs = probs[:, :-1].sum(dim=1, keepdim=True)
        return Output(
            detection=detection_probs,
            classification=probs[:, :-1],
        )


class BBoxHead(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.bbox = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=2,
            kernel_size=1,
            padding=0,
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.bbox(features)
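A short sketch of how the two heads sit on top of the backbone features (class count and tensor sizes are illustrative):

# Hypothetical example, not part of the diff.
import torch

features = torch.randn(1, 32, 128, 512)  # backbone output
classifier = ClassifierHead(num_classes=17, in_channels=32)
bbox = BBoxHead(in_channels=32)

detection, classification = classifier(features)
sizes = bbox(features)
print(detection.shape, classification.shape, sizes.shape)
# torch.Size([1, 1, 128, 512]) torch.Size([1, 17, 128, 512]) torch.Size([1, 2, 128, 512])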
@@ -1,12 +1,12 @@
from abc import ABC, abstractmethod
from typing import NamedTuple
from typing import NamedTuple, Tuple

import torch
import torch.nn as nn

__all__ = [
    "ModelOutput",
    "FeatureExtractorModel",
    "BackboneModel",
]


@@ -41,16 +41,31 @@ class ModelOutput(NamedTuple):
    """Tensor with intermediate features."""


class FeatureExtractorModel(ABC, nn.Module):
class BackboneModel(ABC, nn.Module):
    input_height: int
    """Height of the input spectrogram."""

    num_features: int
    """Dimension of the feature tensor."""
    encoder_channels: Tuple[int, ...]
    """Tuple specifying the number of channels for each convolutional layer
    in the encoder. The length of this tuple determines the number of
    encoder layers."""

    decoder_channels: Tuple[int, ...]
    """Tuple specifying the number of channels for each convolutional layer
    in the decoder. The length of this tuple determines the number of
    decoder layers."""

    bottleneck_channels: int
    """Number of channels in the bottleneck layer, which connects the
    encoder and decoder."""

    out_channels: int
    """Number of channels in the final output feature map produced by the
    backbone model."""

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Forward pass of the encoder model."""
        """Forward pass of the model."""


class DetectionModel(ABC, nn.Module):
batdetect2/modules.py (new file, 181 lines)
@@ -0,0 +1,181 @@
from pathlib import Path
from typing import Optional

import lightning as L
import torch
from pydantic import Field
from soundevent import data
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from batdetect2.configs import BaseConfig
from batdetect2.evaluate.evaluate import match_predictions_and_annotations
from batdetect2.models import (
    BBoxHead,
    ClassifierHead,
    ModelConfig,
    build_architecture,
)
from batdetect2.models.typing import ModelOutput
from batdetect2.post_process import (
    PostprocessConfig,
    postprocess_model_outputs,
)
from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
from batdetect2.train.config import TrainingConfig
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.losses import compute_loss
from batdetect2.train.targets import (
    TargetConfig,
    build_decoder,
    build_encoder,
    get_class_names,
)

__all__ = [
    "DetectorModel",
]


class ModuleConfig(BaseConfig):
    train: TrainingConfig = Field(default_factory=TrainingConfig)
    targets: TargetConfig = Field(default_factory=TargetConfig)
    architecture: ModelConfig = Field(default_factory=ModelConfig)
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
    postprocessing: PostprocessConfig = Field(
        default_factory=PostprocessConfig
    )


class DetectorModel(L.LightningModule):
    config: ModuleConfig

    def __init__(
        self,
        config: Optional[ModuleConfig] = None,
    ):
        super().__init__()

        self.config = config or ModuleConfig()
        self.save_hyperparameters()

        self.backbone = build_architecture(self.config.architecture)

        self.classifier = ClassifierHead(
            num_classes=len(self.config.targets.classes),
            in_channels=self.backbone.out_channels,
        )

        self.bbox = BBoxHead(in_channels=self.backbone.out_channels)

        conf = self.config.train.loss.classification
        self.class_weights = (
            torch.tensor(conf.class_weights) if conf.class_weights else None
        )

        # Training targets
        self.class_names = get_class_names(self.config.targets.classes)
        self.encoder = build_encoder(
            self.config.targets.classes,
            replacement_rules=self.config.targets.replace,
        )
        self.decoder = build_decoder(self.config.targets.classes)

        self.validation_predictions = []

        self.example_input_array = torch.randn([1, 1, 128, 512])

    def forward(self, spec: torch.Tensor) -> ModelOutput:  # type: ignore
        features = self.backbone(spec)
        detection_probs, classification_probs = self.classifier(features)
        size_preds = self.bbox(features)
        return ModelOutput(
            detection_probs=detection_probs,
            size_preds=size_preds,
            class_probs=classification_probs,
            features=features,
        )

    def training_step(self, batch: TrainExample):
        outputs = self.forward(batch.spec)
        losses = compute_loss(
            batch,
            outputs,
            conf=self.config.train.loss,
            class_weights=self.class_weights,
        )

        self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
        self.log("train/loss/detection", losses.total, logger=True)
        self.log("train/loss/size", losses.total, logger=True)
        self.log("train/loss/classification", losses.total, logger=True)

        return losses.total

    def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
        outputs = self.forward(batch.spec)

        losses = compute_loss(
            batch,
            outputs,
            conf=self.config.train.loss,
            class_weights=self.class_weights,
        )

        self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
        self.log("val/loss/detection", losses.total, logger=True)
        self.log("val/loss/size", losses.total, logger=True)
        self.log("val/loss/classification", losses.total, logger=True)

        dataloaders = self.trainer.val_dataloaders
        assert isinstance(dataloaders, DataLoader)
        dataset = dataloaders.dataset
        assert isinstance(dataset, LabeledDataset)
        clip_annotation = dataset.get_clip_annotation(batch_idx)

        clip_prediction = postprocess_model_outputs(
            outputs,
            clips=[clip_annotation.clip],
            classes=self.class_names,
            decoder=self.decoder,
            config=self.config.postprocessing,
        )[0]

        matches = match_predictions_and_annotations(
            clip_annotation,
            clip_prediction,
        )

        self.validation_predictions.extend(matches)

    def on_validation_epoch_end(self) -> None:
        self.validation_predictions.clear()

    def configure_optimizers(self):
        conf = self.config.train.optimizer
        optimizer = Adam(self.parameters(), lr=conf.learning_rate)
        scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max)
        return [optimizer], [scheduler]

    def process_clip(
        self,
        clip: data.Clip,
        audio_dir: Optional[Path] = None,
    ) -> data.ClipPrediction:
        spec = preprocess_audio_clip(
            clip,
            config=self.config.preprocessing,
            audio_dir=audio_dir,
        )
        tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)
        outputs = self.forward(tensor)
        return postprocess_model_outputs(
            outputs,
            clips=[clip],
            classes=self.class_names,
            decoder=self.decoder,
            config=self.config.postprocessing,
        )[0]
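A minimal training sketch for the new Lightning module (hypothetical dataset and trainer settings; `train_dataset` stands in for a `LabeledDataset` built by the training pipeline):

# Hypothetical example, not part of the diff.
import lightning as L
from torch.utils.data import DataLoader

model = DetectorModel()  # uses the default ModuleConfig
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

trainer = L.Trainer(max_epochs=100)
trainer.fit(model, train_dataloaders=train_loader)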
@@ -2,10 +2,10 @@

from typing import List, Optional, Tuple, Union, cast

import matplotlib.ticker as tick
import numpy as np
import torch
from matplotlib import axes, patches
import matplotlib.ticker as tick
from matplotlib import pyplot as plt

from batdetect2.detector.parameters import DEFAULT_PROCESSING_CONFIGURATIONS
@@ -102,7 +102,6 @@ def spectrogram(
    return ax



def spectrogram_with_detections(
    spec: Union[torch.Tensor, np.ndarray],
    dets: List[Annotation],

@@ -17,6 +6,6 @@ def create_ax(
) -> axes.Axes:
    """Create a new axis if none is provided"""
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, **kwargs) # type: ignore
        _, ax = plt.subplots(figsize=figsize, **kwargs)  # type: ignore

    return ax # type: ignore
    return ax  # type: ignore
@@ -1,19 +1,20 @@
"""Module for postprocessing model outputs."""

from typing import Callable, List, Tuple, Union
from pydantic import BaseModel, Field
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import numpy as np
import torch
from pydantic import Field
from soundevent import data
from torch import nn

from batdetect2.data.labels import ClassMapper
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.typing import ModelOutput

__all__ = [
    "postprocess_model_outputs",
    "PostprocessConfig",
    "load_postprocess_config",
    "postprocess_model_outputs",
]

NMS_KERNEL_SIZE = 9
@@ -21,7 +22,7 @@ DETECTION_THRESHOLD = 0.01
TOP_K_PER_SEC = 200


class PostprocessConfig(BaseModel):
class PostprocessConfig(BaseConfig):
    """Configuration for postprocessing model outputs."""

    nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
@@ -31,14 +32,29 @@ class PostprocessConfig(BaseModel):
    top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)


TagFunction = Callable[[int], List[data.Tag]]
def load_postprocess_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> PostprocessConfig:
    return load_config(path, schema=PostprocessConfig, field=field)


class RawPrediction(NamedTuple):
    start_time: float
    end_time: float
    low_freq: float
    high_freq: float
    detection_score: float
    class_scores: Dict[str, float]
    features: np.ndarray


def postprocess_model_outputs(
    outputs: ModelOutput,
    clips: List[data.Clip],
    class_mapper: ClassMapper,
    config: PostprocessConfig,
    classes: List[str],
    decoder: Callable[[str], List[data.Tag]],
    config: Optional[PostprocessConfig] = None,
) -> List[data.ClipPrediction]:
    """Postprocesses model outputs to generate clip predictions.

@@ -68,6 +84,9 @@ def postprocess_model_outputs(
    ValueError
        If the number of predictions does not match the number of clips.
    """

    config = config or PostprocessConfig()

    num_predictions = len(outputs.detection_probs)

    if num_predictions == 0:
@@ -108,7 +127,8 @@ def postprocess_model_outputs(
            size_preds,
            class_probs,
            features,
            class_mapper=class_mapper,
            classes=classes,
            decoder=decoder,
            min_freq=config.min_freq,
            max_freq=config.max_freq,
            detection_threshold=config.detection_threshold,
@@ -124,6 +144,82 @@ def postprocess_model_outputs(
    return predictions


def compute_predictions_from_outputs(
    start: float,
    end: float,
    scores: torch.Tensor,
    y_pos: torch.Tensor,
    x_pos: torch.Tensor,
    size_preds: torch.Tensor,
    class_probs: torch.Tensor,
    features: torch.Tensor,
    classes: List[str],
    min_freq: int = 10000,
    max_freq: int = 120000,
    detection_threshold: float = DETECTION_THRESHOLD,
) -> List[RawPrediction]:
    _, freq_bins, time_bins = size_preds.shape

    sorted_indices = torch.argsort(x_pos)
    valid_indices = sorted_indices[
        scores[sorted_indices] > detection_threshold
    ]

    scores = scores[valid_indices]
    x_pos = x_pos[valid_indices]
    y_pos = y_pos[valid_indices]

    predictions: List[RawPrediction] = []
    for score, x, y in zip(scores, x_pos, y_pos):
        width, height = size_preds[:, y, x]
        class_prob = class_probs[:, y, x].detach().numpy()
        feats = features[:, y, x].detach().numpy()

        start_time = np.interp(
            x.item(),
            [0, time_bins],
            [start, end],
        )

        end_time = np.interp(
            x.item() + width.item(),
            [0, time_bins],
            [start, end],
        )

        low_freq = np.interp(
            y.item(),
            [0, freq_bins],
            [max_freq, min_freq],
        )

        high_freq = np.interp(
            y.item() - height.item(),
            [0, freq_bins],
            [max_freq, min_freq],
        )

        start_time, end_time = sorted([float(start_time), float(end_time)])
        low_freq, high_freq = sorted([float(low_freq), float(high_freq)])

        predictions.append(
            RawPrediction(
                start_time=start_time,
                end_time=end_time,
                low_freq=low_freq,
                high_freq=high_freq,
                detection_score=score.item(),
|
||||
features=feats,
|
||||
class_scores={
|
||||
class_name: prob
|
||||
for class_name, prob in zip(classes, class_prob)
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return predictions
|
||||
|
||||
|
||||
def compute_sound_events_from_outputs(
|
||||
clip: data.Clip,
|
||||
scores: torch.Tensor,
|
||||
@ -132,7 +228,8 @@ def compute_sound_events_from_outputs(
|
||||
size_preds: torch.Tensor,
|
||||
class_probs: torch.Tensor,
|
||||
features: torch.Tensor,
|
||||
class_mapper: ClassMapper,
|
||||
classes: List[str],
|
||||
decoder: Callable[[str], List[data.Tag]],
|
||||
min_freq: int = 10000,
|
||||
max_freq: int = 120000,
|
||||
detection_threshold: float = DETECTION_THRESHOLD,
|
||||
@ -181,12 +278,13 @@ def compute_sound_events_from_outputs(
|
||||
predicted_tags: List[data.PredictedTag] = []
|
||||
|
||||
for label_id, class_score in enumerate(class_prob):
|
||||
corresponding_tags = class_mapper.inverse_transform(label_id)
|
||||
class_name = classes[label_id]
|
||||
corresponding_tags = decoder(class_name)
|
||||
predicted_tags.extend(
|
||||
[
|
||||
data.PredictedTag(
|
||||
tag=tag,
|
||||
score=class_score.item(),
|
||||
score=max(min(class_score.item(), 1), 0),
|
||||
)
|
||||
for tag in corresponding_tags
|
||||
]
|
||||
@ -207,7 +305,7 @@ def compute_sound_events_from_outputs(
|
||||
),
|
||||
features=[
|
||||
data.Feature(
|
||||
name=f"batdetect2_{i}",
|
||||
term=data.term_from_key(f"batdetect2_{i}"),
|
||||
value=value.item(),
|
||||
)
|
||||
for i, value in enumerate(feature)
|
||||
@ -217,7 +315,7 @@ def compute_sound_events_from_outputs(
|
||||
predictions.append(
|
||||
data.SoundEventPrediction(
|
||||
sound_event=sound_event,
|
||||
score=score.item(),
|
||||
score=max(min(score.item(), 1), 0),
|
||||
tags=predicted_tags,
|
||||
)
|
||||
)
|
batdetect2/preprocess/__init__.py (new file, 68 lines)
@@ -0,0 +1,68 @@
|
||||
"""Module containing functions for preprocessing audio clips."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.preprocess.audio import (
|
||||
AudioConfig,
|
||||
ResampleConfig,
|
||||
load_clip_audio,
|
||||
)
|
||||
from batdetect2.preprocess.config import (
|
||||
PreprocessingConfig,
|
||||
load_preprocessing_config,
|
||||
)
|
||||
from batdetect2.preprocess.spectrogram import (
|
||||
AmplitudeScaleConfig,
|
||||
FrequencyConfig,
|
||||
LogScaleConfig,
|
||||
PcenScaleConfig,
|
||||
Scales,
|
||||
SpecSizeConfig,
|
||||
SpectrogramConfig,
|
||||
STFTConfig,
|
||||
compute_spectrogram,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AmplitudeScaleConfig",
|
||||
"AudioConfig",
|
||||
"FrequencyConfig",
|
||||
"LogScaleConfig",
|
||||
"PcenScaleConfig",
|
||||
"PreprocessingConfig",
|
||||
"ResampleConfig",
|
||||
"STFTConfig",
|
||||
"Scales",
|
||||
"SpecSizeConfig",
|
||||
"SpectrogramConfig",
|
||||
"load_preprocessing_config",
|
||||
"preprocess_audio_clip",
|
||||
]
|
||||
|
||||
|
||||
def preprocess_audio_clip(
|
||||
clip: data.Clip,
|
||||
config: Optional[PreprocessingConfig] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> xr.DataArray:
|
||||
"""Preprocesses audio clip to generate spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
clip
|
||||
The audio clip to preprocess.
|
||||
config
|
||||
Configuration for preprocessing.
audio_dir
Directory relative to which the clip's audio file is resolved, if given.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
Preprocessed spectrogram.
|
||||
|
||||
"""
|
||||
config = config or PreprocessingConfig()
|
||||
wav = load_clip_audio(clip, config=config.audio, audio_dir=audio_dir)
|
||||
return compute_spectrogram(wav, config=config.spectrogram)
|
batdetect2/preprocess/arrays.py (new file, 61 lines)
@@ -0,0 +1,61 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def extend_width(
|
||||
array: np.ndarray,
|
||||
extra: int,
|
||||
axis: int = -1,
|
||||
value: float = 0,
|
||||
) -> np.ndarray:
|
||||
dims = len(array.shape)
|
||||
axis = axis % dims
|
||||
pad = [[0, 0] if index != axis else [0, extra] for index in range(dims)]
|
||||
return np.pad(
|
||||
array,
|
||||
pad,
|
||||
mode="constant",
|
||||
constant_values=value,
|
||||
)
|
||||
|
||||
|
||||
def make_width_divisible(
|
||||
array: np.ndarray,
|
||||
factor: int,
|
||||
axis: int = -1,
|
||||
value: float = 0,
|
||||
) -> np.ndarray:
|
||||
width = array.shape[axis]
|
||||
|
||||
if width % factor == 0:
|
||||
return array
|
||||
|
||||
extra = (-width) % factor
|
||||
return extend_width(array, extra, axis=axis, value=value)
|
||||
|
||||
|
||||
def adjust_width(
|
||||
array: np.ndarray,
|
||||
width: int,
|
||||
axis: int = -1,
|
||||
value: float = 0,
|
||||
) -> np.ndarray:
|
||||
dims = len(array.shape)
|
||||
axis = axis % dims
|
||||
current_width = array.shape[axis]
|
||||
|
||||
if current_width == width:
|
||||
return array
|
||||
|
||||
if current_width < width:
|
||||
return extend_width(
|
||||
array,
|
||||
extra=width - current_width,
|
||||
axis=axis,
|
||||
value=value,
|
||||
)
|
||||
|
||||
slices = [
|
||||
slice(None, None) if index != axis else slice(None, width)
|
||||
for index in range(dims)
|
||||
]
|
||||
return array[tuple(slices)]
|
batdetect2/preprocess/audio.py (new file, 199 lines)
@@ -0,0 +1,199 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from numpy.typing import DTypeLike
|
||||
from pydantic import Field
|
||||
from scipy.signal import resample, resample_poly
|
||||
from soundevent import arrays, audio, data
|
||||
from soundevent.arrays import operations as ops
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
|
||||
TARGET_SAMPLERATE_HZ = 256_000
|
||||
SCALE_RAW_AUDIO = False
|
||||
DEFAULT_DURATION = None
|
||||
|
||||
|
||||
class ResampleConfig(BaseConfig):
|
||||
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
|
||||
mode: str = "poly"
|
||||
|
||||
|
||||
class AudioConfig(BaseConfig):
|
||||
resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
|
||||
scale: bool = SCALE_RAW_AUDIO
|
||||
center: bool = True
|
||||
duration: Optional[float] = DEFAULT_DURATION
|
||||
|
||||
|
||||
def load_file_audio(
|
||||
path: data.PathLike,
|
||||
config: Optional[AudioConfig] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
dtype: DTypeLike = np.float32,
|
||||
) -> xr.DataArray:
|
||||
recording = data.Recording.from_file(path)
|
||||
return load_recording_audio(
|
||||
recording,
|
||||
config=config,
|
||||
dtype=dtype,
|
||||
audio_dir=audio_dir,
|
||||
)
|
||||
|
||||
|
||||
def load_recording_audio(
|
||||
recording: data.Recording,
|
||||
config: Optional[AudioConfig] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
dtype: DTypeLike = np.float32,
|
||||
) -> xr.DataArray:
|
||||
clip = data.Clip(
|
||||
recording=recording,
|
||||
start_time=0,
|
||||
end_time=recording.duration,
|
||||
)
|
||||
return load_clip_audio(
|
||||
clip,
|
||||
config=config,
|
||||
dtype=dtype,
|
||||
audio_dir=audio_dir,
|
||||
)
|
||||
|
||||
|
||||
def load_clip_audio(
|
||||
clip: data.Clip,
|
||||
config: Optional[AudioConfig] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
dtype: DTypeLike = np.float32,
|
||||
) -> xr.DataArray:
|
||||
config = config or AudioConfig()
|
||||
|
||||
wav = (
|
||||
audio.load_clip(clip, audio_dir=audio_dir).sel(channel=0).astype(dtype)
|
||||
)
|
||||
|
||||
if config.duration is not None:
|
||||
wav = adjust_audio_duration(wav, duration=config.duration)
|
||||
|
||||
if config.resample:
|
||||
wav = resample_audio(
|
||||
wav,
|
||||
samplerate=config.resample.samplerate,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if config.center:
|
||||
wav = ops.center(wav)
|
||||
|
||||
if config.scale:
|
||||
wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))
|
||||
|
||||
return wav.astype(dtype)
|
||||
|
||||
|
||||
def adjust_audio_duration(
|
||||
wave: xr.DataArray,
|
||||
duration: float,
|
||||
) -> xr.DataArray:
|
||||
start_time, end_time = arrays.get_dim_range(wave, dim="time")
|
||||
current_duration = end_time - start_time
|
||||
|
||||
if current_duration == duration:
|
||||
return wave
|
||||
|
||||
if current_duration > duration:
|
||||
return arrays.crop_dim(
|
||||
wave,
|
||||
dim="time",
|
||||
start=start_time,
|
||||
stop=start_time + duration,
|
||||
)
|
||||
|
||||
return arrays.extend_dim(
|
||||
wave,
|
||||
dim="time",
|
||||
start=start_time,
|
||||
stop=start_time + duration,
|
||||
)
|
||||
|
||||
|
||||
def resample_audio(
|
||||
wav: xr.DataArray,
|
||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||
mode: str = "poly",
|
||||
dtype: DTypeLike = np.float32,
|
||||
) -> xr.DataArray:
|
||||
if "time" not in wav.dims:
|
||||
raise ValueError("Audio must have a time dimension")
|
||||
|
||||
time_axis: int = wav.get_axis_num("time") # type: ignore
|
||||
step = arrays.get_dim_step(wav, dim="time")
|
||||
original_samplerate = int(1 / step)
|
||||
|
||||
if original_samplerate == samplerate:
|
||||
return wav.astype(dtype)
|
||||
|
||||
if mode == "poly":
|
||||
resampled = resample_audio_poly(
|
||||
wav,
|
||||
sr_orig=original_samplerate,
|
||||
sr_new=samplerate,
|
||||
axis=time_axis,
|
||||
)
|
||||
elif mode == "fourier":
|
||||
resampled = resample_audio_fourier(
|
||||
wav,
|
||||
sr_orig=original_samplerate,
|
||||
sr_new=samplerate,
|
||||
axis=time_axis,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Resampling mode '{mode}' not implemented")
|
||||
|
||||
start, stop = arrays.get_dim_range(wav, dim="time")
|
||||
times = np.linspace(
|
||||
start,
|
||||
stop + step,
|
||||
len(resampled),
|
||||
endpoint=False,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return xr.DataArray(
|
||||
data=resampled.astype(dtype),
|
||||
dims=wav.dims,
|
||||
coords={
|
||||
**wav.coords,
|
||||
"time": arrays.create_time_dim_from_array(
|
||||
times,
|
||||
samplerate=samplerate,
|
||||
),
|
||||
},
|
||||
attrs=wav.attrs,
|
||||
)
|
||||
|
||||
|
||||
def resample_audio_poly(
|
||||
array: xr.DataArray,
|
||||
sr_orig: int,
|
||||
sr_new: int,
|
||||
axis: int = -1,
|
||||
) -> np.ndarray:
|
||||
gcd = np.gcd(sr_orig, sr_new)
|
||||
return resample_poly(
|
||||
array.values,
|
||||
sr_new // gcd,
|
||||
sr_orig // gcd,
|
||||
axis=axis,
|
||||
)
|
||||
|
||||
|
||||
def resample_audio_fourier(
|
||||
array: xr.DataArray,
|
||||
sr_orig: int,
|
||||
sr_new: int,
|
||||
axis: int = -1,
|
||||
) -> np.ndarray:
|
||||
ratio = sr_new / sr_orig
|
||||
return resample(array, int(array.shape[axis] * ratio), axis=axis) # type: ignore
|
batdetect2/preprocess/config.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from typing import Optional

from pydantic import Field
from soundevent.data import PathLike

from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import (
    AudioConfig,
)
from batdetect2.preprocess.spectrogram import (
    SpectrogramConfig,
)

__all__ = [
    "PreprocessingConfig",
    "load_preprocessing_config",
]


class PreprocessingConfig(BaseConfig):
    """Configuration for preprocessing data."""

    audio: AudioConfig = Field(default_factory=AudioConfig)
    spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)


def load_preprocessing_config(
    path: PathLike,
    field: Optional[str] = None,
) -> PreprocessingConfig:
    return load_config(path, schema=PreprocessingConfig, field=field)
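A sketch of reading the preprocessing settings from a YAML file through the loader above; both the file name and the "preprocessing" field are hypothetical.

from batdetect2.preprocess import load_preprocessing_config

config = load_preprocessing_config("config.yaml", field="preprocessing")
print(config.audio, config.spectrogram)   # AudioConfig and SpectrogramConfig blocks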
batdetect2/preprocess/spectrogram.py (new file, 323 lines)
@@ -0,0 +1,323 @@
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import librosa
|
||||
import librosa.core.spectrum
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from numpy.typing import DTypeLike
|
||||
from pydantic import Field
|
||||
from soundevent import arrays, audio
|
||||
from soundevent.arrays import operations as ops
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
|
||||
|
||||
class STFTConfig(BaseConfig):
|
||||
window_duration: float = Field(default=0.002, gt=0)
|
||||
window_overlap: float = Field(default=0.75, ge=0, lt=1)
|
||||
window_fn: str = "hann"
|
||||
|
||||
|
||||
class FrequencyConfig(BaseConfig):
|
||||
max_freq: int = Field(default=120_000, gt=0)
|
||||
min_freq: int = Field(default=10_000, gt=0)
|
||||
|
||||
|
||||
class SpecSizeConfig(BaseConfig):
|
||||
height: int = 128
|
||||
"""Height of the spectrogram in pixels. This value determines the
|
||||
number of frequency bands and corresponds to the vertical dimension
|
||||
of the spectrogram."""
|
||||
|
||||
resize_factor: Optional[float] = 0.5
|
||||
"""Factor by which to resize the spectrogram along the time axis.
|
||||
A value of 0.5 reduces the temporal dimension by half, while a
|
||||
value of 2.0 doubles it. If None, no resizing is performed."""
|
||||
|
||||
|
||||
class LogScaleConfig(BaseConfig):
|
||||
name: Literal["log"] = "log"
|
||||
|
||||
|
||||
class PcenScaleConfig(BaseConfig):
|
||||
name: Literal["pcen"] = "pcen"
|
||||
time_constant: float = 0.4
|
||||
hop_length: int = 512
|
||||
gain: float = 0.98
|
||||
bias: float = 2
|
||||
power: float = 0.5
|
||||
|
||||
|
||||
class AmplitudeScaleConfig(BaseConfig):
|
||||
name: Literal["amplitude"] = "amplitude"
|
||||
|
||||
|
||||
Scales = Union[LogScaleConfig, PcenScaleConfig, AmplitudeScaleConfig]
|
||||
|
||||
|
||||
class SpectrogramConfig(BaseConfig):
|
||||
stft: STFTConfig = Field(default_factory=STFTConfig)
|
||||
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
|
||||
scale: Scales = Field(
|
||||
default_factory=PcenScaleConfig,
|
||||
discriminator="name",
|
||||
)
|
||||
size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
|
||||
denoise: bool = True
|
||||
max_scale: bool = False
|
||||
|
||||
|
||||
def compute_spectrogram(
|
||||
wav: xr.DataArray,
|
||||
config: Optional[SpectrogramConfig] = None,
|
||||
dtype: DTypeLike = np.float32,
|
||||
) -> xr.DataArray:
|
||||
config = config or SpectrogramConfig()
|
||||
|
||||
spec = stft(
|
||||
wav,
|
||||
window_duration=config.stft.window_duration,
|
||||
window_overlap=config.stft.window_overlap,
|
||||
window_fn=config.stft.window_fn,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
spec = crop_spectrogram_frequencies(
|
||||
spec,
|
||||
min_freq=config.frequencies.min_freq,
|
||||
max_freq=config.frequencies.max_freq,
|
||||
)
|
||||
|
||||
spec = scale_spectrogram(spec, scale=config.scale)
|
||||
|
||||
if config.denoise:
|
||||
spec = denoise_spectrogram(spec)
|
||||
|
||||
if config.size:
|
||||
spec = resize_spectrogram(
|
||||
spec,
|
||||
height=config.size.height,
|
||||
resize_factor=config.size.resize_factor,
|
||||
)
|
||||
|
||||
if config.max_scale:
|
||||
spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
|
||||
|
||||
return spec.astype(dtype)
|
||||
|
||||
|
||||
def crop_spectrogram_frequencies(
|
||||
spec: xr.DataArray,
|
||||
min_freq: int = 10_000,
|
||||
max_freq: int = 120_000,
|
||||
) -> xr.DataArray:
|
||||
return arrays.crop_dim(
|
||||
spec,
|
||||
dim="frequency",
|
||||
start=min_freq,
|
||||
stop=max_freq,
|
||||
).astype(spec.dtype)
|
||||
|
||||
|
||||
def stft(
|
||||
wave: xr.DataArray,
|
||||
window_duration: float,
|
||||
window_overlap: float,
|
||||
window_fn: str = "hann",
|
||||
dtype: DTypeLike = np.float32,
|
||||
) -> xr.DataArray:
|
||||
start_time, end_time = arrays.get_dim_range(wave, dim="time")
|
||||
step = arrays.get_dim_step(wave, dim="time")
|
||||
sampling_rate = 1 / step
|
||||
|
||||
nfft = int(window_duration * sampling_rate)
|
||||
noverlap = int(window_overlap * nfft)
|
||||
hop_len = nfft - noverlap
|
||||
hop_duration = hop_len / sampling_rate
|
||||
|
||||
spec, _ = librosa.core.spectrum._spectrogram(
|
||||
y=wave.data.astype(dtype),
|
||||
power=1,
|
||||
n_fft=nfft,
|
||||
hop_length=nfft - noverlap,
|
||||
center=False,
|
||||
window=window_fn,
|
||||
)
|
||||
|
||||
return xr.DataArray(
|
||||
data=spec.astype(dtype),
|
||||
dims=["frequency", "time"],
|
||||
coords={
|
||||
"frequency": arrays.create_frequency_dim_from_array(
|
||||
np.linspace(
|
||||
0,
|
||||
sampling_rate / 2,
|
||||
spec.shape[0],
|
||||
endpoint=False,
|
||||
dtype=dtype,
|
||||
),
|
||||
step=sampling_rate / nfft,
|
||||
),
|
||||
"time": arrays.create_time_dim_from_array(
|
||||
np.linspace(
|
||||
start_time,
|
||||
end_time - (window_duration - hop_duration),
|
||||
spec.shape[1],
|
||||
endpoint=False,
|
||||
dtype=dtype,
|
||||
),
|
||||
step=hop_duration,
|
||||
),
|
||||
},
|
||||
attrs={
|
||||
**wave.attrs,
|
||||
"original_samplerate": sampling_rate,
|
||||
"nfft": nfft,
|
||||
"noverlap": noverlap,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
|
||||
return xr.DataArray(
|
||||
data=(spec - spec.mean("time")).clip(0),
|
||||
dims=spec.dims,
|
||||
coords=spec.coords,
|
||||
attrs=spec.attrs,
|
||||
)
|
||||
|
||||
|
||||
def scale_spectrogram(
|
||||
spec: xr.DataArray,
|
||||
scale: Scales,
|
||||
dtype: DTypeLike = np.float32,
|
||||
) -> xr.DataArray:
|
||||
if scale.name == "log":
|
||||
return scale_log(spec, dtype=dtype)
|
||||
|
||||
if scale.name == "pcen":
|
||||
return scale_pcen(
|
||||
spec,
|
||||
time_constant=scale.time_constant,
|
||||
hop_length=scale.hop_length,
|
||||
gain=scale.gain,
|
||||
power=scale.power,
|
||||
bias=scale.bias,
|
||||
)
|
||||
|
||||
return spec
|
||||
|
||||
|
||||
def scale_pcen(
|
||||
spec: xr.DataArray,
|
||||
time_constant: float = 0.4,
|
||||
hop_length: int = 512,
|
||||
gain: float = 0.98,
|
||||
bias: float = 2,
|
||||
power: float = 0.5,
|
||||
) -> xr.DataArray:
|
||||
samplerate = spec.attrs["original_samplerate"]
|
||||
t_frames = time_constant * samplerate / (float(hop_length) * 10)
|
||||
smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
|
||||
return audio.pcen(
|
||||
spec * (2**31),
|
||||
smooth=smoothing_constant,
|
||||
gain=gain,
|
||||
bias=bias,
|
||||
power=power,
|
||||
).astype(spec.dtype)
|
||||
|
||||
|
||||
def scale_log(
|
||||
spec: xr.DataArray,
|
||||
dtype: DTypeLike = np.float32,
|
||||
) -> xr.DataArray:
|
||||
samplerate = spec.attrs["original_samplerate"]
|
||||
nfft = spec.attrs["nfft"]
|
||||
log_scaling = 2 / (samplerate * (np.abs(np.hanning(nfft)) ** 2).sum())
|
||||
return xr.DataArray(
|
||||
data=np.log1p(log_scaling * spec).astype(dtype),
|
||||
dims=spec.dims,
|
||||
coords=spec.coords,
|
||||
attrs=spec.attrs,
|
||||
)
|
||||
|
||||
|
||||
def resize_spectrogram(
|
||||
spec: xr.DataArray,
|
||||
height: int = 128,
|
||||
resize_factor: Optional[float] = 0.5,
|
||||
) -> xr.DataArray:
|
||||
resize_factor = resize_factor or 1
|
||||
current_width = spec.sizes["time"]
|
||||
return ops.resize(
|
||||
spec,
|
||||
time=int(resize_factor * current_width),
|
||||
frequency=height,
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
|
||||
def adjust_spectrogram_width(
|
||||
spec: xr.DataArray,
|
||||
divide_factor: int = 32,
|
||||
time_period: float = 0.001,
|
||||
) -> xr.DataArray:
|
||||
time_width = spec.sizes["time"]
|
||||
|
||||
if time_width % divide_factor == 0:
|
||||
return spec
|
||||
|
||||
target_size = int(
|
||||
np.ceil(spec.sizes["time"] / divide_factor) * divide_factor
|
||||
)
|
||||
extra_duration = (target_size - time_width) * time_period
|
||||
_, stop = arrays.get_dim_range(spec, dim="time")
|
||||
resized = ops.extend_dim(
|
||||
spec,
|
||||
dim="time",
|
||||
stop=stop + extra_duration,
|
||||
)
|
||||
return resized
|
||||
|
||||
|
||||
def duration_to_spec_width(
|
||||
duration: float,
|
||||
samplerate: int,
|
||||
window_duration: float,
|
||||
window_overlap: float,
|
||||
) -> int:
|
||||
samples = int(duration * samplerate)
|
||||
fft_len = int(window_duration * samplerate)
|
||||
fft_overlap = int(window_overlap * fft_len)
|
||||
hop_len = fft_len - fft_overlap
|
||||
width = (samples - fft_len + hop_len) / hop_len
|
||||
return int(np.floor(width))
|
||||
|
||||
|
||||
def spec_width_to_samples(
|
||||
width: int,
|
||||
samplerate: int,
|
||||
window_duration: float,
|
||||
window_overlap: float,
|
||||
) -> int:
|
||||
fft_len = int(window_duration * samplerate)
|
||||
fft_overlap = int(window_overlap * fft_len)
|
||||
hop_len = fft_len - fft_overlap
|
||||
return width * hop_len + fft_len - hop_len
|
||||
|
||||
|
||||
def get_spectrogram_resolution(
|
||||
config: SpectrogramConfig,
|
||||
) -> tuple[float, float]:
|
||||
max_freq = config.frequencies.max_freq
|
||||
min_freq = config.frequencies.min_freq
|
||||
assert config.size is not None
|
||||
|
||||
spec_height = config.size.height
|
||||
resize_factor = config.size.resize_factor or 1
|
||||
freq_bin_width = (max_freq - min_freq) / spec_height
|
||||
hop_duration = config.stft.window_duration * (
|
||||
1 - config.stft.window_overlap
|
||||
)
|
||||
return freq_bin_width, hop_duration / resize_factor
|
batdetect2/preprocess/tensors.py (new file, 76 lines)
@@ -0,0 +1,76 @@
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def extend_width(
|
||||
array: Union[np.ndarray, torch.Tensor],
|
||||
extra: int,
|
||||
axis: int = -1,
|
||||
value: float = 0,
|
||||
) -> torch.Tensor:
|
||||
if not isinstance(array, torch.Tensor):
|
||||
array = torch.Tensor(array)
|
||||
|
||||
dims = len(array.shape)
|
||||
axis = axis % dims
|
||||
pad = [
|
||||
[0, 0] if index != axis else [0, extra]
|
||||
for index in range(axis, dims)[::-1]
|
||||
]
|
||||
return F.pad(
|
||||
array,
|
||||
[x for y in pad for x in y],
|
||||
value=value,
|
||||
)
|
||||
|
||||
|
||||
def make_width_divisible(
|
||||
array: Union[np.ndarray, torch.Tensor],
|
||||
factor: int,
|
||||
axis: int = -1,
|
||||
value: float = 0,
|
||||
) -> torch.Tensor:
|
||||
if not isinstance(array, torch.Tensor):
|
||||
array = torch.Tensor(array)
|
||||
|
||||
width = array.shape[axis]
|
||||
|
||||
if width % factor == 0:
|
||||
return array
|
||||
|
||||
extra = (-width) % factor
|
||||
return extend_width(array, extra, axis=axis, value=value)
|
||||
|
||||
|
||||
def adjust_width(
|
||||
array: Union[np.ndarray, torch.Tensor],
|
||||
width: int,
|
||||
axis: int = -1,
|
||||
value: float = 0,
|
||||
) -> torch.Tensor:
|
||||
if not isinstance(array, torch.Tensor):
|
||||
array = torch.Tensor(array)
|
||||
|
||||
dims = len(array.shape)
|
||||
axis = axis % dims
|
||||
current_width = array.shape[axis]
|
||||
|
||||
if current_width == width:
|
||||
return array
|
||||
|
||||
if current_width < width:
|
||||
return extend_width(
|
||||
array,
|
||||
extra=width - current_width,
|
||||
axis=axis,
|
||||
value=value,
|
||||
)
|
||||
|
||||
slices = [
|
||||
slice(None, None) if index != axis else slice(None, width)
|
||||
for index in range(dims)
|
||||
]
|
||||
return array[tuple(slices)]
|
batdetect2/terms.py (new file, 88 lines)
@@ -0,0 +1,88 @@
|
||||
from inspect import getmembers
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from soundevent import data, terms
|
||||
|
||||
__all__ = [
|
||||
"call_type",
|
||||
"individual",
|
||||
"get_term_from_info",
|
||||
"get_tag_from_info",
|
||||
"TermInfo",
|
||||
"TagInfo",
|
||||
]
|
||||
|
||||
|
||||
class TermInfo(BaseModel):
|
||||
label: Optional[str]
|
||||
name: Optional[str]
|
||||
uri: Optional[str]
|
||||
|
||||
|
||||
class TagInfo(BaseModel):
|
||||
value: str
|
||||
term: Optional[TermInfo] = None
|
||||
key: Optional[str] = None
|
||||
label: Optional[str] = None
|
||||
|
||||
|
||||
call_type = data.Term(
|
||||
name="soundevent:call_type",
|
||||
label="Call Type",
|
||||
definition="A broad categorization of animal vocalizations based on their intended function or purpose (e.g., social, distress, mating, territorial, echolocation).",
|
||||
)
|
||||
|
||||
individual = data.Term(
|
||||
name="soundevent:individual",
|
||||
label="Individual",
|
||||
definition="An id for an individual animal. In the context of bioacoustic annotation, this term is used to label vocalizations that are attributed to a specific individual.",
|
||||
)
|
||||
|
||||
|
||||
ALL_TERMS = [
|
||||
*getmembers(terms, lambda x: isinstance(x, data.Term)),
|
||||
call_type,
|
||||
individual,
|
||||
]
|
||||
|
||||
|
||||
def get_term_from_info(term_info: TermInfo) -> data.Term:
|
||||
for term in ALL_TERMS:
|
||||
if term_info.name and term_info.name == term.name:
|
||||
return term
|
||||
|
||||
if term_info.label and term_info.label == term.label:
|
||||
return term
|
||||
|
||||
if term_info.uri and term_info.uri == term.uri:
|
||||
return term
|
||||
|
||||
if term_info.name is None:
|
||||
if term_info.label is None:
|
||||
raise ValueError("At least one of name or label must be provided.")
|
||||
|
||||
term_info.name = (
|
||||
f"soundevent:{term_info.label.lower().replace(' ', '_')}"
|
||||
)
|
||||
|
||||
if term_info.label is None:
|
||||
term_info.label = term_info.name
|
||||
|
||||
return data.Term(
|
||||
name=term_info.name,
|
||||
label=term_info.label,
|
||||
uri=term_info.uri,
|
||||
definition="Unknown",
|
||||
)
|
||||
|
||||
|
||||
def get_tag_from_info(tag_info: TagInfo) -> data.Tag:
|
||||
if tag_info.term:
|
||||
term = get_term_from_info(tag_info.term)
|
||||
elif tag_info.key:
|
||||
term = data.term_from_key(tag_info.key)
|
||||
else:
|
||||
raise ValueError("Either term or key must be provided in tag info.")
|
||||
|
||||
return data.Tag(term=term, value=tag_info.value)
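A sketch of resolving a TagInfo into a soundevent Tag with the helpers above; the "species" key and the value are hypothetical.

from batdetect2.terms import TagInfo, get_tag_from_info

tag = get_tag_from_info(TagInfo(key="species", value="Myotis myotis"))
# With only a key given, the term is resolved via data.term_from_key("species");
# a TagInfo carrying a TermInfo whose name, label, or uri matches one of
# ALL_TERMS reuses that term instead of creating a new one.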
@@ -0,0 +1,48 @@
|
||||
from batdetect2.train.augmentations import (
|
||||
AugmentationsConfig,
|
||||
add_echo,
|
||||
augment_example,
|
||||
load_agumentation_config,
|
||||
mask_frequency,
|
||||
mask_time,
|
||||
mix_examples,
|
||||
scale_volume,
|
||||
select_subclip,
|
||||
warp_spectrogram,
|
||||
)
|
||||
from batdetect2.train.config import TrainingConfig, load_train_config
|
||||
from batdetect2.train.dataset import (
|
||||
LabeledDataset,
|
||||
SubclipConfig,
|
||||
TrainExample,
|
||||
)
|
||||
from batdetect2.train.labels import LabelConfig, load_label_config
|
||||
from batdetect2.train.preprocess import preprocess_annotations
|
||||
from batdetect2.train.targets import TargetConfig, load_target_config
|
||||
from batdetect2.train.train import TrainerConfig, load_trainer_config, train
|
||||
|
||||
__all__ = [
|
||||
"AugmentationsConfig",
|
||||
"LabelConfig",
|
||||
"LabeledDataset",
|
||||
"SubclipConfig",
|
||||
"TargetConfig",
|
||||
"TrainExample",
|
||||
"TrainerConfig",
|
||||
"TrainingConfig",
|
||||
"add_echo",
|
||||
"augment_example",
|
||||
"load_agumentation_config",
|
||||
"load_label_config",
|
||||
"load_target_config",
|
||||
"load_train_config",
|
||||
"load_trainer_config",
|
||||
"mask_frequency",
|
||||
"mask_time",
|
||||
"mix_examples",
|
||||
"preprocess_annotations",
|
||||
"scale_volume",
|
||||
"select_subclip",
|
||||
"train",
|
||||
"warp_spectrogram",
|
||||
]
|
@@ -1,941 +0,0 @@
|
||||
"""Functions and dataloaders for training and testing the model."""
|
||||
|
||||
import copy
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.data
|
||||
import torchaudio
|
||||
|
||||
import batdetect2.utils.audio_utils as au
|
||||
from batdetect2.types import (
|
||||
Annotation,
|
||||
AudioLoaderAnnotationGroup,
|
||||
AudioLoaderParameters,
|
||||
FileAnnotation,
|
||||
)
|
||||
|
||||
|
||||
def generate_gt_heatmaps(
|
||||
spec_op_shape: Tuple[int, int],
|
||||
sampling_rate: float,
|
||||
ann: AudioLoaderAnnotationGroup,
|
||||
class_names: List[str],
|
||||
fft_win_length: float,
|
||||
fft_overlap: float,
|
||||
max_freq: float,
|
||||
min_freq: float,
|
||||
resize_factor: float,
|
||||
target_sigma: float,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AudioLoaderAnnotationGroup]:
|
||||
"""Generate ground truth heatmaps from annotations.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec_op_shape : Tuple[int, int]
|
||||
Shape of the input spectrogram.
|
||||
sampling_rate : int
|
||||
Sampling rate of the input audio in Hz.
|
||||
ann : AnnotationGroup
|
||||
Dictionary containing the annotation information.
|
||||
params : HeatmapParameters
|
||||
Parameters controlling the generation of the heatmaps.
|
||||
|
||||
Returns
|
||||
-------
|
||||
y_2d_det : np.ndarray
|
||||
2D heatmap of the presence of an event.
|
||||
y_2d_size : np.ndarray
|
||||
2D heatmap of the size of the bounding box associated to event.
|
||||
y_2d_classes : np.ndarray
|
||||
3D array containing the ground-truth class probabilities for each
|
||||
pixel.
|
||||
ann_aug : AnnotationGroup
|
||||
A dictionary containing the annotation information of the
|
||||
annotations that are within the input spectrogram, augmented with
|
||||
the x and y indices of their pixel location in the input spectrogram.
|
||||
"""
|
||||
# spec may be resized on input into the network
|
||||
num_classes = len(class_names)
|
||||
op_height = spec_op_shape[0]
|
||||
op_width = spec_op_shape[1]
|
||||
freq_per_bin = (max_freq - min_freq) / op_height
|
||||
|
||||
# start and end times
|
||||
x_pos_start = au.time_to_x_coords(
|
||||
ann["start_times"],
|
||||
sampling_rate,
|
||||
fft_win_length,
|
||||
fft_overlap,
|
||||
)
|
||||
x_pos_start = (resize_factor * x_pos_start).astype(np.int32)
|
||||
x_pos_end = au.time_to_x_coords(
|
||||
ann["end_times"],
|
||||
sampling_rate,
|
||||
fft_win_length,
|
||||
fft_overlap,
|
||||
)
|
||||
x_pos_end = (resize_factor * x_pos_end).astype(np.int32)
|
||||
|
||||
# location on y axis i.e. frequency
|
||||
y_pos_low = (ann["low_freqs"] - min_freq) / freq_per_bin
|
||||
y_pos_low = (op_height - y_pos_low).astype(np.int32)
|
||||
y_pos_high = (ann["high_freqs"] - min_freq) / freq_per_bin
|
||||
y_pos_high = (op_height - y_pos_high).astype(np.int32)
|
||||
bb_widths = x_pos_end - x_pos_start
|
||||
bb_heights = y_pos_low - y_pos_high
|
||||
|
||||
# Only include annotations that are within the input spectrogram
|
||||
valid_inds = np.where(
|
||||
(x_pos_start >= 0)
|
||||
& (x_pos_start < op_width)
|
||||
& (y_pos_low >= 0)
|
||||
& (y_pos_low < (op_height - 1))
|
||||
)[0]
|
||||
|
||||
ann_aug: AudioLoaderAnnotationGroup = {
|
||||
**ann,
|
||||
"start_times": ann["start_times"][valid_inds],
|
||||
"end_times": ann["end_times"][valid_inds],
|
||||
"high_freqs": ann["high_freqs"][valid_inds],
|
||||
"low_freqs": ann["low_freqs"][valid_inds],
|
||||
"class_ids": ann["class_ids"][valid_inds],
|
||||
"individual_ids": ann["individual_ids"][valid_inds],
|
||||
"x_inds": x_pos_start[valid_inds],
|
||||
"y_inds": y_pos_low[valid_inds],
|
||||
}
|
||||
|
||||
# if the number of calls is only 1, then it is unique
|
||||
# TODO would be better if we found these unique calls at the merging stage
|
||||
if len(ann_aug["individual_ids"]) == 1:
|
||||
ann_aug["individual_ids"][0] = 0
|
||||
|
||||
y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
|
||||
y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)
|
||||
|
||||
# num classes and "background" class
|
||||
y_2d_classes: np.ndarray = np.zeros(
|
||||
(num_classes + 1, op_height, op_width), dtype=np.float32
|
||||
)
|
||||
|
||||
# create 2D ground truth heatmaps
|
||||
for ii in valid_inds:
|
||||
draw_gaussian(
|
||||
y_2d_det[0, :],
|
||||
(x_pos_start[ii], y_pos_low[ii]),
|
||||
target_sigma,
|
||||
)
|
||||
y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
|
||||
y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]
|
||||
|
||||
cls_id = ann["class_ids"][ii]
|
||||
if cls_id > -1:
|
||||
draw_gaussian(
|
||||
y_2d_classes[cls_id, :],
|
||||
(x_pos_start[ii], y_pos_low[ii]),
|
||||
target_sigma,
|
||||
)
|
||||
|
||||
# be careful as this will have a 1.0 places where we have event but
|
||||
# dont know gt class this will be masked in training anyway
|
||||
y_2d_classes[num_classes, :] = 1.0 - y_2d_classes.sum(0)
|
||||
y_2d_classes = y_2d_classes / y_2d_classes.sum(0)[np.newaxis, ...]
|
||||
y_2d_classes[np.isnan(y_2d_classes)] = 0.0
|
||||
|
||||
return y_2d_det, y_2d_size, y_2d_classes, ann_aug
|
||||
|
||||
|
||||
def draw_gaussian(
|
||||
heatmap: np.ndarray,
|
||||
center: Tuple[int, int],
|
||||
sigmax: float,
|
||||
sigmay: Optional[float] = None,
|
||||
) -> bool:
|
||||
"""Draw a 2D gaussian into the heatmap.
|
||||
|
||||
If the gaussian center is outside the heatmap, then the gaussian is not
|
||||
drawn.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
heatmap : np.ndarray
|
||||
The heatmap to draw into. Should be of shape (height, width).
|
||||
center : Tuple[int, int]
|
||||
The center of the gaussian in (x, y) format.
|
||||
sigmax : float
|
||||
The standard deviation of the gaussian in the x direction.
|
||||
sigmay : Optional[float], optional
|
||||
The standard deviation of the gaussian in the y direction. If None,
|
||||
then sigmay = sigmax, by default None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the gaussian was drawn, False if it was not (because
|
||||
the center was outside the heatmap).
|
||||
|
||||
|
||||
"""
|
||||
# center is (x, y)
|
||||
# this edits the heatmap inplace
|
||||
|
||||
if sigmay is None:
|
||||
sigmay = sigmax
|
||||
tmp_size = np.maximum(sigmax, sigmay) * 3
|
||||
mu_x = int(center[0] + 0.5)
|
||||
mu_y = int(center[1] + 0.5)
|
||||
w, h = heatmap.shape[0], heatmap.shape[1]
|
||||
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
|
||||
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
|
||||
|
||||
if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0:
|
||||
return False
|
||||
|
||||
size = 2 * tmp_size + 1
|
||||
x = np.arange(0, size, 1, np.float32)
|
||||
y = x[:, np.newaxis]
|
||||
x0 = y0 = size // 2
|
||||
# g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
|
||||
g = np.exp(
|
||||
-((x - x0) ** 2) / (2 * sigmax**2) - ((y - y0) ** 2) / (2 * sigmay**2)
|
||||
)
|
||||
g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
|
||||
g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
|
||||
img_x = max(0, ul[0]), min(br[0], h)
|
||||
img_y = max(0, ul[1]), min(br[1], w)
|
||||
heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]] = np.maximum(
|
||||
heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]],
|
||||
g[g_y[0] : g_y[1], g_x[0] : g_x[1]],
|
||||
)
|
||||
return True
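A toy check of the behaviour documented above. This module is removed in this changeset, so the sketch assumes the old draw_gaussian is still importable; the heatmap size and center are made up.

import numpy as np

heatmap = np.zeros((16, 32), dtype=np.float32)      # (height, width)
drawn = draw_gaussian(heatmap, center=(10, 8), sigmax=2.0)
assert drawn and heatmap[8, 10] == 1.0              # peak lands at row 8, column 10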
|
||||
|
||||
def pad_aray(ip_array: np.ndarray, pad_size: int) -> np.ndarray:
|
||||
"""Pad array with -1s."""
|
||||
return np.hstack((ip_array, np.ones(pad_size, dtype=np.int32) * -1))
|
||||
|
||||
|
||||
def warp_spec_aug(
|
||||
spec: torch.Tensor,
|
||||
ann: AudioLoaderAnnotationGroup,
|
||||
stretch_squeeze_delta: float,
|
||||
) -> torch.Tensor:
|
||||
"""Warp spectrogram by randomly stretching and squeezing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec: torch.Tensor
|
||||
Spectrogram to warp.
|
||||
ann: AnnotationGroup
|
||||
Annotation group for the spectrogram. Must be provided to sync
|
||||
the start and stop times with the spectrogram after warping.
|
||||
stretch_squeeze_delta: float
|
||||
Maximum amount to stretch or squeeze the spectrogram.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Warped spectrogram.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function modifies the annotation group in place.
|
||||
"""
|
||||
# Augment spectrogram by randomly stretch and squeezing
|
||||
# NOTE this also changes the start and stop time in place
|
||||
|
||||
delta = stretch_squeeze_delta
|
||||
op_size = (spec.shape[1], spec.shape[2])
|
||||
resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0
|
||||
resize_amt = int(spec.shape[2] * resize_fract_r)
|
||||
|
||||
if resize_amt >= spec.shape[2]:
|
||||
spec_r = torch.cat(
|
||||
(
|
||||
spec,
|
||||
torch.zeros(
|
||||
(1, spec.shape[1], resize_amt - spec.shape[2]),
|
||||
dtype=spec.dtype,
|
||||
),
|
||||
),
|
||||
dim=2,
|
||||
)
|
||||
else:
|
||||
spec_r = spec[:, :, :resize_amt]
|
||||
|
||||
# Resize the spectrogram
|
||||
spec = F.interpolate(
|
||||
spec_r.unsqueeze(0),
|
||||
size=op_size,
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
).squeeze(0)
|
||||
|
||||
# Update the start and stop times
|
||||
ann["start_times"] *= 1.0 / resize_fract_r
|
||||
ann["end_times"] *= 1.0 / resize_fract_r
|
||||
|
||||
return spec
|
||||
|
||||
|
||||
def mask_time_aug(
|
||||
spec: torch.Tensor,
|
||||
mask_max_time_perc: float,
|
||||
) -> torch.Tensor:
|
||||
"""Mask out random blocks of time.
|
||||
|
||||
Will randomly mask out a block of time in the spectrogram. The block
|
||||
will be between 0.0 and `mask_max_time_perc` of the total time.
|
||||
A random number of blocks will be masked out between 1 and 3.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec: torch.Tensor
|
||||
Spectrogram to mask.
|
||||
mask_max_time_perc: float
|
||||
Maximum percentage of time to mask out.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Spectrogram with masked out time blocks.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function is based on the implementation in::
|
||||
|
||||
SpecAugment: A Simple Data Augmentation Method for Automatic Speech
|
||||
Recognition
|
||||
"""
|
||||
fm = torchaudio.transforms.TimeMasking(
|
||||
int(spec.shape[1] * mask_max_time_perc)
|
||||
)
|
||||
for _ in range(np.random.randint(1, 4)):
|
||||
spec = fm(spec)
|
||||
return spec
|
||||
|
||||
|
||||
def mask_freq_aug(
|
||||
spec: torch.Tensor,
|
||||
mask_max_freq_perc: float,
|
||||
) -> torch.Tensor:
|
||||
"""Mask out random blocks of frequency.
|
||||
|
||||
Will randomly mask out a block of frequency in the spectrogram. The block
|
||||
will be between 0.0 and `mask_max_freq_perc` of the total frequency.
|
||||
A random number of blocks will be masked out between 1 and 3.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec: torch.Tensor
|
||||
Spectrogram to mask.
|
||||
mask_max_freq_perc: float
|
||||
Maximum percentage of frequency to mask out.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Spectrogram with masked out frequency blocks.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function is based on the implementation in::
|
||||
|
||||
SpecAugment: A Simple Data Augmentation Method for Automatic Speech
|
||||
Recognition
|
||||
"""
|
||||
fm = torchaudio.transforms.FrequencyMasking(
|
||||
int(spec.shape[1] * mask_max_freq_perc)
|
||||
)
|
||||
for _ in range(np.random.randint(1, 4)):
|
||||
spec = fm(spec)
|
||||
return spec
|
||||
|
||||
|
||||
def scale_vol_aug(
|
||||
spec: torch.Tensor,
|
||||
spec_amp_scaling: float,
|
||||
) -> torch.Tensor:
|
||||
"""Scale the volume of the spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec: torch.Tensor
|
||||
Spectrogram to scale.
|
||||
spec_amp_scaling: float
|
||||
Maximum scaling factor.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
"""
|
||||
return spec * np.random.random() * spec_amp_scaling
|
||||
|
||||
|
||||
def echo_aug(
|
||||
audio: np.ndarray,
|
||||
sampling_rate: float,
|
||||
echo_max_delay: float,
|
||||
) -> np.ndarray:
|
||||
"""Add echo to audio.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio: np.ndarray
|
||||
Audio to add echo to.
|
||||
sampling_rate: float
|
||||
Sampling rate of the audio.
|
||||
echo_max_delay: float
|
||||
Maximum delay of the echo in seconds.
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
Audio with echo added.
|
||||
"""
|
||||
sample_offset = (
|
||||
int(echo_max_delay * np.random.random() * sampling_rate) + 1
|
||||
)
|
||||
# NOTE: This seems to be wrong, as the echo should be added to the
|
||||
# end of the audio, not the beginning.
|
||||
audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
|
||||
return audio
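A small sketch of what the NOTE above is pointing at: with offset k, `audio[:-k] += w * audio[k:]` mixes sample i + k into sample i, so the delayed copy lands earlier in time; a conventional echo would instead be `audio[k:] += w * audio[:-k]`.

import numpy as np

x = np.zeros(8, dtype=np.float32)
x[5] = 1.0                       # a single impulse
k, w = 2, 0.5
x[:-k] += w * x[k:]
# Energy now appears at index 3 (before the impulse), not at index 7.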
|
||||
|
||||
def resample_aug(
|
||||
audio: np.ndarray,
|
||||
sampling_rate: float,
|
||||
fft_win_length: float,
|
||||
fft_overlap: float,
|
||||
resize_factor: float,
|
||||
spec_divide_factor: float,
|
||||
spec_train_width: int,
|
||||
aug_sampling_rates: List[int],
|
||||
) -> Tuple[np.ndarray, float, float]:
|
||||
"""Resample audio augmentation.
|
||||
|
||||
Will resample the audio to a random sampling rate from the list of
|
||||
sampling rates in `aug_sampling_rates`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio: np.ndarray
|
||||
Audio to resample.
|
||||
sampling_rate: float
|
||||
Original sampling rate of the audio.
|
||||
fft_win_length: float
|
||||
Length of the FFT window in seconds.
|
||||
fft_overlap: float
|
||||
Amount of overlap between FFT windows.
|
||||
resize_factor: float
|
||||
Factor to resize the spectrogram by.
|
||||
spec_divide_factor: float
|
||||
Factor to divide the spectrogram by.
|
||||
spec_train_width: int
|
||||
Width of the spectrogram.
|
||||
aug_sampling_rates: List[int]
|
||||
List of sampling rates to resample to.
|
||||
|
||||
Returns
|
||||
-------
|
||||
audio : np.ndarray
|
||||
Resampled audio.
|
||||
sampling_rate : float
|
||||
New sampling rate.
|
||||
duration : float
|
||||
Duration of the audio in seconds.
|
||||
"""
|
||||
sampling_rate_old = sampling_rate
|
||||
sampling_rate = np.random.choice(aug_sampling_rates)
|
||||
audio = librosa.resample(
|
||||
audio,
|
||||
orig_sr=sampling_rate_old,
|
||||
target_sr=sampling_rate,
|
||||
res_type="polyphase",
|
||||
)
|
||||
|
||||
audio = au.pad_audio(
|
||||
audio,
|
||||
sampling_rate,
|
||||
fft_win_length,
|
||||
fft_overlap,
|
||||
resize_factor,
|
||||
spec_divide_factor,
|
||||
spec_train_width,
|
||||
)
|
||||
duration = audio.shape[0] / float(sampling_rate)
|
||||
return audio, sampling_rate, duration
|
||||
|
||||
|
||||
def resample_audio(
|
||||
num_samples: int,
|
||||
sampling_rate: float,
|
||||
audio2: np.ndarray,
|
||||
sampling_rate2: float,
|
||||
) -> Tuple[np.ndarray, float]:
|
||||
"""Resample audio.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_samples: int
|
||||
Expected number of samples for the output audio.
|
||||
sampling_rate: float
|
||||
Original sampling rate of the audio.
|
||||
audio2: np.ndarray
|
||||
Audio to resample.
|
||||
sampling_rate2: float
|
||||
Target sampling rate of the audio.
|
||||
|
||||
Returns
|
||||
-------
|
||||
audio2 : np.ndarray
|
||||
Resampled audio.
|
||||
sampling_rate2 : float
|
||||
New sampling rate.
|
||||
"""
|
||||
# resample to target sampling rate
|
||||
if sampling_rate != sampling_rate2:
|
||||
audio2 = librosa.resample(
|
||||
audio2,
|
||||
orig_sr=sampling_rate2,
|
||||
target_sr=sampling_rate,
|
||||
res_type="polyphase",
|
||||
)
|
||||
sampling_rate2 = sampling_rate
|
||||
|
||||
# pad or trim to the correct length
|
||||
if audio2.shape[0] < num_samples:
|
||||
audio2 = np.hstack(
|
||||
(
|
||||
audio2,
|
||||
np.zeros((num_samples - audio2.shape[0]), dtype=audio2.dtype),
|
||||
)
|
||||
)
|
||||
elif audio2.shape[0] > num_samples:
|
||||
audio2 = audio2[:num_samples]
|
||||
|
||||
return audio2, sampling_rate2
|
||||
|
||||
|
||||
def combine_audio_aug(
|
||||
audio: np.ndarray,
|
||||
sampling_rate: float,
|
||||
ann: AudioLoaderAnnotationGroup,
|
||||
audio2: np.ndarray,
|
||||
sampling_rate2: float,
|
||||
ann2: AudioLoaderAnnotationGroup,
|
||||
) -> Tuple[np.ndarray, AudioLoaderAnnotationGroup]:
|
||||
"""Combine two audio files.
|
||||
|
||||
Will combine two audio files by resampling them to the same sampling rate
|
||||
and then combining them with a random weight. The annotations will be
|
||||
combined by taking the union of the two sets of annotations.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio: np.ndarray
|
||||
First Audio to combine.
|
||||
sampling_rate: int
|
||||
Sampling rate of the first audio.
|
||||
ann: AnnotationGroup
|
||||
Annotations for the first audio.
|
||||
audio2: np.ndarray
|
||||
Second Audio to combine.
|
||||
sampling_rate2: int
|
||||
Sampling rate of the second audio.
|
||||
ann2: AnnotationGroup
|
||||
Annotations for the second audio.
|
||||
|
||||
Returns
|
||||
-------
|
||||
audio : np.ndarray
|
||||
Combined audio.
|
||||
ann : AnnotationGroup
|
||||
Combined annotations.
|
||||
"""
|
||||
# resample so they are the same
|
||||
audio2, sampling_rate2 = resample_audio(
|
||||
audio.shape[0],
|
||||
sampling_rate,
|
||||
audio2,
|
||||
sampling_rate2,
|
||||
)
|
||||
|
||||
# # set mean and std to be the same
|
||||
# audio2 = (audio2 - audio2.mean())
|
||||
# audio2 = (audio2/audio2.std())*audio.std()
|
||||
# audio2 = audio2 + audio.mean()
|
||||
|
||||
if (
|
||||
ann.get("annotated", False)
|
||||
and (ann2.get("annotated", False))
|
||||
and (sampling_rate2 == sampling_rate)
|
||||
and (audio.shape[0] == audio2.shape[0])
|
||||
):
|
||||
comb_weight = 0.3 + np.random.random() * 0.4
|
||||
audio = comb_weight * audio + (1 - comb_weight) * audio2
|
||||
inds = np.argsort(np.hstack((ann["start_times"], ann2["start_times"])))
|
||||
for kk in ann.keys():
|
||||
# when combining calls from different files, assume they come
|
||||
# from different individuals
|
||||
if kk == "individual_ids":
|
||||
if (ann[kk] > -1).sum() > 0:
|
||||
ann2[kk][ann2[kk] > -1] += (
|
||||
np.max(ann[kk][ann[kk] > -1]) + 1
|
||||
)
|
||||
|
||||
if (kk != "class_id_file") and (kk != "annotated"):
|
||||
ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds]
|
||||
|
||||
return audio, ann
|
||||
|
||||
|
||||
def _prepare_annotation(
|
||||
annotation: Annotation,
|
||||
class_names: List[str],
|
||||
) -> Annotation:
|
||||
try:
|
||||
class_id = class_names.index(annotation["class"])
|
||||
except ValueError:
|
||||
class_id = -1
|
||||
|
||||
ann: Annotation = {
|
||||
**annotation,
|
||||
"class_id": class_id,
|
||||
}
|
||||
|
||||
if "individual" in ann:
|
||||
ann["individual"] = int(ann["individual"]) # type: ignore
|
||||
|
||||
return ann
|
||||
|
||||
|
||||
def _prepare_file_annotation(
|
||||
annotation: FileAnnotation,
|
||||
class_names: List[str],
|
||||
classes_to_ignore: List[str],
|
||||
) -> AudioLoaderAnnotationGroup:
|
||||
annotations = [
|
||||
_prepare_annotation(ann, class_names)
|
||||
for ann in annotation["annotation"]
|
||||
if ann["class"] not in classes_to_ignore
|
||||
]
|
||||
|
||||
try:
|
||||
class_id_file = class_names.index(annotation["class_name"])
|
||||
except ValueError:
|
||||
class_id_file = -1
|
||||
|
||||
ret: AudioLoaderAnnotationGroup = {
|
||||
"id": annotation["id"],
|
||||
"annotated": annotation["annotated"],
|
||||
"duration": annotation["duration"],
|
||||
"issues": annotation["issues"],
|
||||
"time_exp": annotation["time_exp"],
|
||||
"class_name": annotation["class_name"],
|
||||
"notes": annotation["notes"],
|
||||
"annotation": annotations,
|
||||
"start_times": np.array([ann["start_time"] for ann in annotations]),
|
||||
"end_times": np.array([ann["end_time"] for ann in annotations]),
|
||||
"high_freqs": np.array([ann["high_freq"] for ann in annotations]),
|
||||
"low_freqs": np.array([ann["low_freq"] for ann in annotations]),
|
||||
"class_ids": np.array(
|
||||
[ann.get("class_id", -1) for ann in annotations]
|
||||
),
|
||||
"individual_ids": np.array([ann["individual"] for ann in annotations]),
|
||||
"class_id_file": class_id_file,
|
||||
}
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
class AudioLoader(torch.utils.data.Dataset):
|
||||
"""Main AudioLoader for training and testing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_anns_ip: List[FileAnnotation],
|
||||
params: AudioLoaderParameters,
|
||||
dataset_name: Optional[str] = None,
|
||||
is_train: bool = False,
|
||||
return_spec_for_viz: bool = False,
|
||||
):
|
||||
self.is_train = is_train
|
||||
self.params = params
|
||||
self.return_spec_for_viz = return_spec_for_viz
|
||||
self.data_anns: List[AudioLoaderAnnotationGroup] = [
|
||||
_prepare_file_annotation(
|
||||
ann,
|
||||
params["class_names"],
|
||||
params["classes_to_ignore"],
|
||||
)
|
||||
for ann in data_anns_ip
|
||||
]
|
||||
|
||||
ann_cnt = [len(aa["annotation"]) for aa in self.data_anns]
|
||||
self.max_num_anns = 2 * np.max(
|
||||
ann_cnt
|
||||
) # x2 because we may be combining files during training
|
||||
|
||||
print("\n")
|
||||
if dataset_name is not None:
|
||||
print("Dataset : " + dataset_name)
|
||||
if self.is_train:
|
||||
print("Split type : train")
|
||||
else:
|
||||
print("Split type : test")
|
||||
print("Num files : " + str(len(self.data_anns)))
|
||||
print("Num calls : " + str(np.sum(ann_cnt)))
|
||||
|
||||
def get_file_and_anns(
|
||||
self,
|
||||
index: Optional[int] = None,
|
||||
) -> Tuple[np.ndarray, float, float, AudioLoaderAnnotationGroup]:
|
||||
"""Get an audio file and its annotations.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
index : int, optional
|
||||
Index of the file to be loaded. If None, a random file is chosen.
|
||||
|
||||
Returns
|
||||
-------
|
||||
audio_raw : np.ndarray
|
||||
Loaded audio file.
|
||||
sampling_rate : float
|
||||
Sampling rate of the audio file.
|
||||
duration : float
|
||||
Duration of the audio file in seconds.
|
||||
ann : AnnotationGroup
|
||||
AnnotationGroup object containing the annotations for the audio file.
|
||||
"""
|
||||
# if no file specified, choose random one
|
||||
if index is None:
|
||||
index = np.random.randint(0, len(self.data_anns))
|
||||
|
||||
audio_file = self.data_anns[index]["file_path"]
|
||||
sampling_rate, audio_raw = au.load_audio(
|
||||
audio_file,
|
||||
self.data_anns[index]["time_exp"],
|
||||
self.params["target_samp_rate"],
|
||||
self.params["scale_raw_audio"],
|
||||
)
|
||||
|
||||
# copy annotation
|
||||
ann = copy.deepcopy(self.data_anns[index])
|
||||
# ann["annotated"] = self.data_anns[index]["annotated"]
|
||||
# ann["class_id_file"] = self.data_anns[index]["class_id_file"]
|
||||
# keys = [
|
||||
# "start_times",
|
||||
# "end_times",
|
||||
# "high_freqs",
|
||||
# "low_freqs",
|
||||
# "class_ids",
|
||||
# "individual_ids",
|
||||
# ]
|
||||
# for kk in keys:
|
||||
# ann[kk] = self.data_anns[index][kk].copy()
|
||||
|
||||
# if train then grab a random crop
|
||||
if self.is_train:
|
||||
nfft = int(self.params["fft_win_length"] * sampling_rate)
|
||||
noverlap = int(self.params["fft_overlap"] * nfft)
|
||||
length_samples = (
|
||||
self.params["spec_train_width"] * (nfft - noverlap) + noverlap
|
||||
)
|
||||
|
||||
if audio_raw.shape[0] - length_samples > 0:
|
||||
sample_crop = np.random.randint(
|
||||
audio_raw.shape[0] - length_samples
|
||||
)
|
||||
else:
|
||||
sample_crop = 0
|
||||
audio_raw = audio_raw[sample_crop : sample_crop + length_samples]
|
||||
ann["start_times"] = ann["start_times"] - sample_crop / float(
|
||||
sampling_rate
|
||||
)
|
||||
ann["end_times"] = ann["end_times"] - sample_crop / float(
|
||||
sampling_rate
|
||||
)
|
||||
|
||||
# pad audio
|
||||
if self.is_train:
|
||||
op_spec_target_size = self.params["spec_train_width"]
|
||||
else:
|
||||
op_spec_target_size = None
|
||||
audio_raw = au.pad_audio(
|
||||
audio_raw,
|
||||
sampling_rate,
|
||||
self.params["fft_win_length"],
|
||||
self.params["fft_overlap"],
|
||||
self.params["resize_factor"],
|
||||
self.params["spec_divide_factor"],
|
||||
op_spec_target_size,
|
||||
)
|
||||
duration = audio_raw.shape[0] / float(sampling_rate)
|
||||
|
||||
# sort based on time
|
||||
inds = np.argsort(ann["start_times"])
|
||||
for kk in ann.keys():
|
||||
if (kk != "class_id_file") and (kk != "annotated"):
|
||||
ann[kk] = ann[kk][inds]
|
||||
|
||||
return audio_raw, sampling_rate, duration, ann
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Get an item from the dataset."""
|
||||
# load audio file
|
||||
audio, sampling_rate, duration, ann = self.get_file_and_anns(index)
|
||||
|
||||
# augment on raw audio
|
||||
if self.is_train and self.params["augment_at_train"]:
|
||||
# augment - combine with random audio file
|
||||
if (
|
||||
self.params["augment_at_train_combine"]
|
||||
and np.random.random() < self.params["aug_prob"]
|
||||
):
|
||||
(
|
||||
audio2,
|
||||
sampling_rate2,
|
||||
_,
|
||||
ann2,
|
||||
) = self.get_file_and_anns()
|
||||
audio, ann = combine_audio_aug(
|
||||
audio, sampling_rate, ann, audio2, sampling_rate2, ann2
|
||||
)
|
||||
|
||||
# simulate echo by adding delayed copy of the file
|
||||
if np.random.random() < self.params["aug_prob"]:
|
||||
audio = echo_aug(
|
||||
audio,
|
||||
sampling_rate,
|
||||
echo_max_delay=self.params["echo_max_delay"],
|
||||
)
|
||||
|
||||
# resample the audio
|
||||
# if np.random.random() < self.params["aug_prob"]:
|
||||
# audio, sampling_rate, duration = resample_aug(
|
||||
# audio, sampling_rate, self.params
|
||||
# )
|
||||
|
||||
# create spectrogram
|
||||
spec, _ = au.generate_spectrogram(
|
||||
audio,
|
||||
sampling_rate,
|
||||
params=dict(
|
||||
fft_win_length=self.params["fft_win_length"],
|
||||
fft_overlap=self.params["fft_overlap"],
|
||||
max_freq=self.params["max_freq"],
|
||||
min_freq=self.params["min_freq"],
|
||||
spec_scale=self.params["spec_scale"],
|
||||
denoise_spec_avg=self.params["denoise_spec_avg"],
|
||||
max_scale_spec=self.params["max_scale_spec"],
|
||||
),
|
||||
)
|
||||
rsf = self.params["resize_factor"]
|
||||
spec_op_shape = (
|
||||
int(self.params["spec_height"] * rsf),
|
||||
int(spec.shape[1] * rsf),
|
||||
)
|
||||
|
||||
# resize the spec
|
||||
spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0)
|
||||
spec = F.interpolate(
|
||||
spec,
|
||||
size=spec_op_shape,
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
).squeeze(0)
|
||||
|
||||
# augment spectrogram
|
||||
if self.is_train and self.params["augment_at_train"]:
|
||||
if np.random.random() < self.params["aug_prob"]:
|
||||
spec = scale_vol_aug(
|
||||
spec,
|
||||
spec_amp_scaling=self.params["spec_amp_scaling"],
|
||||
)
|
||||
|
||||
if np.random.random() < self.params["aug_prob"]:
|
||||
spec = warp_spec_aug(
|
||||
spec,
|
||||
ann,
|
||||
stretch_squeeze_delta=self.params["stretch_squeeze_delta"],
|
||||
)
|
||||
|
||||
if np.random.random() < self.params["aug_prob"]:
|
||||
spec = mask_time_aug(
|
||||
spec,
|
||||
mask_max_time_perc=self.params["mask_max_time_perc"],
|
||||
)
|
||||
|
||||
if np.random.random() < self.params["aug_prob"]:
|
||||
spec = mask_freq_aug(
|
||||
spec,
|
||||
mask_max_freq_perc=self.params["mask_max_freq_perc"],
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
outputs["spec"] = spec
|
||||
if self.return_spec_for_viz:
|
||||
outputs["spec_for_viz"] = torch.from_numpy(spec_for_viz).unsqueeze(
|
||||
0
|
||||
)
|
||||
|
||||
# create ground truth heatmaps
|
||||
(
|
||||
outputs["y_2d_det"],
|
||||
outputs["y_2d_size"],
|
||||
outputs["y_2d_classes"],
|
||||
ann_aug,
|
||||
) = generate_gt_heatmaps(
|
||||
spec_op_shape,
|
||||
sampling_rate,
|
||||
ann,
|
||||
class_names=self.params["class_names"],
|
||||
fft_win_length=self.params["fft_win_length"],
|
||||
fft_overlap=self.params["fft_overlap"],
|
||||
max_freq=self.params["max_freq"],
|
||||
min_freq=self.params["min_freq"],
|
||||
resize_factor=self.params["resize_factor"],
|
||||
target_sigma=self.params["target_sigma"],
|
||||
)
|
||||
|
||||
# hack to get around requirement that all vectors are the same length
|
||||
# in the output batch
|
||||
pad_size = self.max_num_anns - len(ann_aug["individual_ids"])
|
||||
outputs["is_valid"] = pad_aray(
|
||||
np.ones(len(ann_aug["individual_ids"])), pad_size
|
||||
)
|
||||
keys = [
|
||||
"class_ids",
|
||||
"individual_ids",
|
||||
"x_inds",
|
||||
"y_inds",
|
||||
"start_times",
|
||||
"end_times",
|
||||
"low_freqs",
|
||||
"high_freqs",
|
||||
]
|
||||
for kk in keys:
|
||||
outputs[kk] = pad_aray(ann_aug[kk], pad_size)
|
||||
|
||||
# convert to pytorch
|
||||
for kk in outputs.keys():
|
||||
if type(outputs[kk]) != torch.Tensor:
|
||||
outputs[kk] = torch.from_numpy(outputs[kk])
|
||||
|
||||
# scalars
|
||||
outputs["class_id_file"] = ann["class_id_file"]
|
||||
outputs["annotated"] = ann["annotated"]
|
||||
outputs["duration"] = duration
|
||||
outputs["sampling_rate"] = sampling_rate
|
||||
outputs["file_id"] = index
|
||||
|
||||
return outputs
|
||||
|
||||
def __len__(self):
|
||||
"""Denotes the total number of samples."""
|
||||
return len(self.data_anns)
|
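Note: the `__getitem__` above returns a dict whose per-annotation vectors are padded to `max_num_anns`, so items can be batched directly. As a rough sketch (assuming `dataset` is an already-constructed instance of this legacy dataset class), iterating it with a standard PyTorch DataLoader looks like this:

import torch
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=8, shuffle=True)  # `dataset`: hypothetical instance
for batch in loader:
    spec = batch["spec"]             # (batch, 1, height, width) spectrogram
    y_det = batch["y_2d_det"]        # detection heatmap target
    y_size = batch["y_2d_size"]      # size regression target
    y_class = batch["y_2d_classes"]  # per-class heatmap target
    break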
@ -1,136 +1,214 @@
|
||||
from functools import wraps
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
from pydantic import Field
|
||||
from soundevent import arrays, data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram
|
||||
from batdetect2.preprocess.arrays import adjust_width
|
||||
|
||||
Augmentation = Callable[[xr.Dataset], xr.Dataset]
|
||||
|
||||
|
||||
AUGMENTATION_PROBABILITY = 0.2
|
||||
MAX_DELAY = 0.005
|
||||
STRETCH_SQUEEZE_DELTA = 0.04
|
||||
MASK_MAX_TIME_PERC: float = 0.05
|
||||
MASK_MAX_FREQ_PERC: float = 0.10
|
||||
__all__ = [
|
||||
"AugmentationsConfig",
|
||||
"load_agumentation_config",
|
||||
"select_subclip",
|
||||
"mix_examples",
|
||||
"add_echo",
|
||||
"scale_volume",
|
||||
"warp_spectrogram",
|
||||
"mask_time",
|
||||
"mask_frequency",
|
||||
"augment_example",
|
||||
]
|
||||
|
||||
|
||||
def maybe_apply(
|
||||
augmentation: Callable,
|
||||
prob: float = AUGMENTATION_PROBABILITY,
|
||||
) -> Callable:
|
||||
"""Apply an augmentation with a given probability."""
|
||||
|
||||
@wraps(augmentation)
|
||||
def _augmentation(x):
|
||||
if np.random.rand() > prob:
|
||||
return x
|
||||
return augmentation(x)
|
||||
|
||||
return _augmentation
|
||||
class BaseAugmentationConfig(BaseConfig):
|
||||
enable: bool = True
|
||||
probability: float = 0.2
|
||||
|
||||
|
||||
def select_random_subclip(
|
||||
train_example: xr.Dataset,
|
||||
def select_subclip(
|
||||
example: xr.Dataset,
|
||||
start_time: Optional[float] = None,
|
||||
duration: Optional[float] = None,
|
||||
proportion: float = 0.9,
|
||||
width: Optional[int] = None,
|
||||
random: bool = False,
|
||||
) -> xr.Dataset:
|
||||
"""Select a random subclip from a clip."""
|
||||
step = arrays.get_dim_step(example, "time") # type: ignore
|
||||
start, stop = arrays.get_dim_range(example, "time") # type: ignore
|
||||
|
||||
time_coords = train_example.coords["time"]
|
||||
if width is None:
|
||||
if duration is None:
|
||||
raise ValueError("Either duration or width must be provided")
|
||||
|
||||
start_time = time_coords.attrs.get("min", time_coords.min())
|
||||
end_time = time_coords.attrs.get("max", time_coords.max())
|
||||
width = int(np.floor(duration / step))
|
||||
|
||||
if duration is None:
|
||||
duration = (end_time - start_time) * proportion
|
||||
duration = width * step
|
||||
|
||||
start_time = np.random.uniform(start_time, end_time - duration)
|
||||
return train_example.sel(time=slice(start_time, start_time + duration))
|
||||
if start_time is None:
|
||||
if random:
|
||||
start_time = np.random.uniform(start, max(stop - duration, start))
|
||||
else:
|
||||
start_time = start
|
||||
|
||||
if start_time + duration > stop:
|
||||
return example
|
||||
|
||||
start_index = arrays.get_coord_index(
|
||||
example, # type: ignore
|
||||
"time",
|
||||
start_time,
|
||||
)
|
||||
|
||||
end_index = start_index + width - 1
|
||||
|
||||
start_time = example.time.values[start_index]
|
||||
end_time = example.time.values[end_index]
|
||||
|
||||
return example.sel(
|
||||
time=slice(start_time, end_time),
|
||||
audio_time=slice(start_time, end_time + step),
|
||||
)
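A quick usage sketch for `select_subclip` (assuming `example` is a training-example `xr.Dataset` with `time` and `audio_time` dimensions, e.g. one written by the preprocessing step; the file name is hypothetical):

import xarray as xr

example = xr.open_dataset("example.nc")
subclip = select_subclip(example, width=512, random=True)
# The cropped spectrogram is at most `width` time bins wide.
assert subclip["spectrogram"].sizes["time"] <= 512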
|
||||
|
||||
|
||||
def combine_audio(
|
||||
audio1: xr.DataArray,
|
||||
audio2: xr.DataArray,
|
||||
alpha: Optional[float] = None,
|
||||
min_alpha: float = 0.3,
|
||||
max_alpha: float = 0.7,
|
||||
) -> xr.DataArray:
|
||||
class MixAugmentationConfig(BaseAugmentationConfig):
|
||||
min_weight: float = 0.3
|
||||
max_weight: float = 0.7
|
||||
|
||||
|
||||
def mix_examples(
|
||||
example: xr.Dataset,
|
||||
other: xr.Dataset,
|
||||
weight: Optional[float] = None,
|
||||
min_weight: float = 0.3,
|
||||
max_weight: float = 0.7,
|
||||
config: Optional[PreprocessingConfig] = None,
|
||||
) -> xr.Dataset:
|
||||
"""Combine two audio clips."""
|
||||
config = config or PreprocessingConfig()
|
||||
|
||||
if alpha is None:
|
||||
alpha = np.random.uniform(min_alpha, max_alpha)
|
||||
if weight is None:
|
||||
weight = np.random.uniform(min_weight, max_weight)
|
||||
|
||||
return alpha * audio1 + (1 - alpha) * audio2.data
|
||||
audio1 = example["audio"]
|
||||
audio2 = adjust_width(other["audio"].values, len(audio1))
|
||||
|
||||
combined = weight * audio1 + (1 - weight) * audio2
|
||||
|
||||
spectrogram = compute_spectrogram(
|
||||
combined.rename({"audio_time": "time"}),
|
||||
config=config.spectrogram,
|
||||
).data
|
||||
|
||||
# NOTE: The subclip's spectrogram might be slightly longer than the
|
||||
# spectrogram computed from the subclip's audio. This is due to a
|
||||
# simplification in the subclip process: It doesn't account for the
|
||||
# spectrogram parameters to precisely determine the corresponding audio
|
||||
# samples. To work around this, we pad the computed spectrogram with zeros
|
||||
# as needed.
|
||||
previous_width = len(example["time"])
|
||||
spectrogram = adjust_width(spectrogram, previous_width)
|
||||
|
||||
detection_heatmap = xr.apply_ufunc(
|
||||
np.maximum,
|
||||
example["detection"],
|
||||
adjust_width(other["detection"].values, previous_width),
|
||||
)
|
||||
|
||||
class_heatmap = xr.apply_ufunc(
|
||||
np.maximum,
|
||||
example["class"],
|
||||
adjust_width(other["class"].values, previous_width),
|
||||
)
|
||||
|
||||
size_heatmap = example["size"] + adjust_width(
|
||||
other["size"].values, previous_width
|
||||
)
|
||||
|
||||
return xr.Dataset(
|
||||
{
|
||||
"audio": combined,
|
||||
"spectrogram": xr.DataArray(
|
||||
data=spectrogram,
|
||||
dims=example["spectrogram"].dims,
|
||||
coords=example["spectrogram"].coords,
|
||||
),
|
||||
"detection": detection_heatmap,
|
||||
"class": class_heatmap,
|
||||
"size": size_heatmap,
|
||||
},
|
||||
attrs=example.attrs,
|
||||
)
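`mix_examples` blends the raw audio of two examples, recomputes the spectrogram from the mix, takes the element-wise maximum of the detection and class heatmaps, and sums the size heatmaps. A usage sketch with two previously saved examples (file names hypothetical):

import xarray as xr

example = xr.open_dataset("example_a.nc")
other = xr.open_dataset("example_b.nc")
mixed = mix_examples(example, other, weight=0.5)
print(mixed["spectrogram"].shape, mixed["detection"].shape)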
|
||||
|
||||
|
||||
# def random_mix(
|
||||
# audio: xr.DataArray,
|
||||
# clip: data.ClipAnnotation,
|
||||
# provider: Optional[ClipProvider] = None,
|
||||
# alpha: Optional[float] = None,
|
||||
# min_alpha: float = 0.3,
|
||||
# max_alpha: float = 0.7,
|
||||
# join_annotations: bool = True,
|
||||
# ) -> Tuple[xr.DataArray, data.ClipAnnotation]:
|
||||
# """Mix two audio clips."""
|
||||
# if provider is None:
|
||||
# raise ValueError("No audio provider given.")
|
||||
#
|
||||
# try:
|
||||
# other_audio, other_clip = provider(clip)
|
||||
# except (StopIteration, ValueError):
|
||||
# raise ValueError("No more audio sources available.")
|
||||
#
|
||||
# new_audio = combine_audio(
|
||||
# audio,
|
||||
# other_audio,
|
||||
# alpha=alpha,
|
||||
# min_alpha=min_alpha,
|
||||
# max_alpha=max_alpha,
|
||||
# )
|
||||
#
|
||||
# if join_annotations:
|
||||
# clip = clip.model_copy(
|
||||
# update=dict(
|
||||
# sound_events=clip.sound_events + other_clip.sound_events,
|
||||
# )
|
||||
# )
|
||||
#
|
||||
# return new_audio, clip
|
||||
class EchoAugmentationConfig(BaseAugmentationConfig):
|
||||
max_delay: float = 0.005
|
||||
min_weight: float = 0.0
|
||||
max_weight: float = 1.0
|
||||
|
||||
|
||||
def add_echo(
|
||||
train_example: xr.Dataset,
|
||||
example: xr.Dataset,
|
||||
delay: Optional[float] = None,
|
||||
alpha: Optional[float] = None,
|
||||
min_alpha: float = 0.0,
|
||||
max_alpha: float = 1.0,
|
||||
max_delay: float = MAX_DELAY,
|
||||
weight: Optional[float] = None,
|
||||
min_weight: float = 0.1,
|
||||
max_weight: float = 1.0,
|
||||
max_delay: float = 0.005,
|
||||
config: Optional[PreprocessingConfig] = None,
|
||||
) -> xr.Dataset:
|
||||
"""Add a delay to the audio."""
|
||||
config = config or PreprocessingConfig()
|
||||
|
||||
if delay is None:
|
||||
delay = np.random.uniform(0, max_delay)
|
||||
|
||||
if alpha is None:
|
||||
alpha = np.random.uniform(min_alpha, max_alpha)
|
||||
if weight is None:
|
||||
weight = np.random.uniform(min_weight, max_weight)
|
||||
|
||||
spec = train_example["spectrogram"]
|
||||
audio = example["audio"]
|
||||
step = arrays.get_dim_step(audio, "audio_time")
|
||||
audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0)
|
||||
audio = audio + weight * audio_delay
|
||||
|
||||
time_coords = spec.coords["time"]
|
||||
start_time = time_coords.attrs["min"]
|
||||
end_time = time_coords.attrs["max"]
|
||||
step = (end_time - start_time) / time_coords.size
|
||||
spectrogram = compute_spectrogram(
|
||||
audio.rename({"audio_time": "time"}),
|
||||
config=config.spectrogram,
|
||||
).data
|
||||
|
||||
spec_delay = spec.shift(time=int(delay / step), fill_value=0)
|
||||
# NOTE: The subclip's spectrogram might be slightly longer than the
|
||||
# spectrogram computed from the subclip's audio. This is due to a
|
||||
# simplification in the subclip process: It doesn't account for the
|
||||
# spectrogram parameters to precisely determine the corresponding audio
|
||||
# samples. To work around this, we pad the computed spectrogram with zeros
|
||||
# as needed.
|
||||
spectrogram = adjust_width(
|
||||
spectrogram,
|
||||
example["spectrogram"].sizes["time"],
|
||||
)
|
||||
|
||||
return train_example.assign(spectrogram=spec + alpha * spec_delay)
|
||||
return example.assign(
|
||||
audio=audio,
|
||||
spectrogram=xr.DataArray(
|
||||
data=spectrogram,
|
||||
dims=example["spectrogram"].dims,
|
||||
coords=example["spectrogram"].coords,
|
||||
),
|
||||
)
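In other words, the echo augmentation computes y(t) = x(t) + w · x(t − delay) on the waveform and then regenerates the spectrogram so the targets stay aligned. The same idea on a plain NumPy waveform, for illustration only:

import numpy as np

def add_echo_np(audio: np.ndarray, samplerate: int, delay: float, weight: float) -> np.ndarray:
    # Shift the signal by `delay` seconds and add it back with the given weight.
    shift = int(delay * samplerate)
    delayed = np.zeros_like(audio)
    delayed[shift:] = audio[: len(audio) - shift]
    return audio + weight * delayed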
|
||||
|
||||
|
||||
class VolumeAugmentationConfig(BaseAugmentationConfig):
|
||||
min_scaling: float = 0.0
|
||||
max_scaling: float = 2.0
|
||||
|
||||
|
||||
def scale_volume(
|
||||
train_example: xr.Dataset,
|
||||
example: xr.Dataset,
|
||||
factor: Optional[float] = None,
|
||||
max_scaling: float = 2,
|
||||
min_scaling: float = 0,
|
||||
@ -139,106 +217,227 @@ def scale_volume(
|
||||
if factor is None:
|
||||
factor = np.random.uniform(min_scaling, max_scaling)
|
||||
|
||||
return train_example.assign(
|
||||
spectrogram=train_example["spectrogram"] * factor
|
||||
)
|
||||
return example.assign(spectrogram=example["spectrogram"] * factor)
|
||||
|
||||
|
||||
class WarpAugmentationConfig(BaseAugmentationConfig):
|
||||
delta: float = 0.04
|
||||
|
||||
|
||||
def warp_spectrogram(
|
||||
train_example: xr.Dataset,
|
||||
example: xr.Dataset,
|
||||
factor: Optional[float] = None,
|
||||
delta: float = STRETCH_SQUEEZE_DELTA,
|
||||
delta: float = 0.04,
|
||||
) -> xr.Dataset:
|
||||
"""Warp a spectrogram."""
|
||||
if factor is None:
|
||||
factor = np.random.uniform(1 - delta, 1 + delta)
|
||||
|
||||
time_coords = train_example.coords["time"]
|
||||
start_time = time_coords.attrs["min"]
|
||||
end_time = time_coords.attrs["max"]
|
||||
start_time, end_time = arrays.get_dim_range(example, "time") # type: ignore
|
||||
duration = end_time - start_time
|
||||
|
||||
new_time = np.linspace(
|
||||
start_time,
|
||||
start_time + duration * factor,
|
||||
train_example.time.size,
|
||||
example.time.size,
|
||||
)
|
||||
|
||||
return train_example.interp(time=new_time)
|
||||
spectrogram = (
|
||||
example["spectrogram"]
|
||||
.interp(
|
||||
coords={"time": new_time},
|
||||
method="linear",
|
||||
kwargs=dict(
|
||||
fill_value=0,
|
||||
),
|
||||
)
|
||||
.clip(min=0)
|
||||
)
|
||||
|
||||
detection = example["detection"].interp(
|
||||
time=new_time,
|
||||
method="nearest",
|
||||
kwargs=dict(
|
||||
fill_value=0,
|
||||
),
|
||||
)
|
||||
|
||||
classification = example["class"].interp(
|
||||
time=new_time,
|
||||
method="nearest",
|
||||
kwargs=dict(
|
||||
fill_value=0,
|
||||
),
|
||||
)
|
||||
|
||||
size = example["size"].interp(
|
||||
time=new_time,
|
||||
method="nearest",
|
||||
kwargs=dict(
|
||||
fill_value=0,
|
||||
),
|
||||
)
|
||||
|
||||
return example.assign(
|
||||
{
|
||||
"time": new_time,
|
||||
"spectrogram": spectrogram,
|
||||
"detection": detection,
|
||||
"class": classification,
|
||||
"size": size,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def mask_axis(
|
||||
train_example: xr.Dataset,
|
||||
array: xr.DataArray,
|
||||
dim: str,
|
||||
start: float,
|
||||
end: float,
|
||||
mask_all: bool = False,
|
||||
mask_value: float = 0,
|
||||
) -> xr.Dataset:
|
||||
if dim not in train_example.dims:
|
||||
mask_value: Union[float, Callable[[xr.DataArray], float]] = np.mean,
|
||||
) -> xr.DataArray:
|
||||
if dim not in array.dims:
|
||||
raise ValueError(f"Axis {dim} not found in array")
|
||||
|
||||
coord = train_example.coords[dim]
|
||||
coord = array.coords[dim]
|
||||
condition = (coord < start) | (coord > end)
|
||||
|
||||
if mask_all:
|
||||
return train_example.where(condition, other=mask_value)
|
||||
if callable(mask_value):
|
||||
mask_value = mask_value(array)
|
||||
|
||||
return train_example.assign(
|
||||
spectrogram=train_example.spectrogram.where(
|
||||
condition, other=mask_value
|
||||
)
|
||||
)
|
||||
return array.where(condition, other=mask_value)
|
||||
|
||||
|
||||
class TimeMaskAugmentationConfig(BaseAugmentationConfig):
|
||||
max_perc: float = 0.05
|
||||
max_masks: int = 3
|
||||
|
||||
|
||||
def mask_time(
|
||||
train_example: xr.Dataset,
|
||||
max_time_mask: float = MASK_MAX_TIME_PERC,
|
||||
max_num_masks: int = 3,
|
||||
example: xr.Dataset,
|
||||
max_perc: float = 0.05,
|
||||
max_mask: int = 3,
|
||||
) -> xr.Dataset:
|
||||
"""Mask a random section of the time axis."""
|
||||
num_masks = np.random.randint(1, max_mask + 1)
|
||||
start_time, end_time = arrays.get_dim_range(example, "time") # type: ignore
|
||||
|
||||
num_masks = np.random.randint(1, max_num_masks + 1)
|
||||
|
||||
time_coord = train_example.coords["time"]
|
||||
start_time = time_coord.attrs.get("min", time_coord.min())
|
||||
end_time = time_coord.attrs.get("max", time_coord.max())
|
||||
|
||||
spectrogram = example["spectrogram"]
|
||||
for _ in range(num_masks):
|
||||
mask_size = np.random.uniform(0, max_time_mask)
|
||||
mask_size = np.random.uniform(0, max_perc) * (end_time - start_time)
|
||||
start = np.random.uniform(start_time, end_time - mask_size)
|
||||
end = start + mask_size
|
||||
train_example = mask_axis(train_example, "time", start, end)
|
||||
spectrogram = mask_axis(spectrogram, "time", start, end)
|
||||
|
||||
return train_example
|
||||
return example.assign(spectrogram=spectrogram)
|
||||
|
||||
|
||||
class FrequencyMaskAugmentationConfig(BaseAugmentationConfig):
|
||||
max_perc: float = 0.10
|
||||
max_masks: int = 3
|
||||
|
||||
|
||||
def mask_frequency(
|
||||
train_example: xr.Dataset,
|
||||
max_freq_mask: float = MASK_MAX_FREQ_PERC,
|
||||
max_num_masks: int = 3,
|
||||
example: xr.Dataset,
|
||||
max_perc: float = 0.10,
|
||||
max_masks: int = 3,
|
||||
) -> xr.Dataset:
|
||||
"""Mask a random section of the frequency axis."""
|
||||
num_masks = np.random.randint(1, max_masks + 1)
|
||||
min_freq, max_freq = arrays.get_dim_range(example, "frequency") # type: ignore
|
||||
|
||||
num_masks = np.random.randint(1, max_num_masks + 1)
|
||||
|
||||
freq_coord = train_example.coords["frequency"]
|
||||
min_freq = freq_coord.min()
|
||||
max_freq = freq_coord.max()
|
||||
|
||||
spectrogram = example["spectrogram"]
|
||||
for _ in range(num_masks):
|
||||
mask_size = np.random.uniform(0, max_freq_mask)
|
||||
mask_size = np.random.uniform(0, max_perc) * (max_freq - min_freq)
|
||||
start = np.random.uniform(min_freq, max_freq - mask_size)
|
||||
end = start + mask_size
|
||||
train_example = mask_axis(train_example, "frequency", start, end)
|
||||
spectrogram = mask_axis(spectrogram, "frequency", start, end)
|
||||
|
||||
return train_example
|
||||
return example.assign(spectrogram=spectrogram)
|
||||
|
||||
|
||||
AUGMENTATIONS: List[Augmentation] = [
|
||||
select_random_subclip,
|
||||
add_echo,
|
||||
scale_volume,
|
||||
mask_time,
|
||||
mask_frequency,
|
||||
]
|
||||
class AugmentationsConfig(BaseConfig):
|
||||
mix: MixAugmentationConfig = Field(default_factory=MixAugmentationConfig)
|
||||
echo: EchoAugmentationConfig = Field(
|
||||
default_factory=EchoAugmentationConfig
|
||||
)
|
||||
volume: VolumeAugmentationConfig = Field(
|
||||
default_factory=VolumeAugmentationConfig
|
||||
)
|
||||
warp: WarpAugmentationConfig = Field(
|
||||
default_factory=WarpAugmentationConfig
|
||||
)
|
||||
time_mask: TimeMaskAugmentationConfig = Field(
|
||||
default_factory=TimeMaskAugmentationConfig
|
||||
)
|
||||
frequency_mask: FrequencyMaskAugmentationConfig = Field(
|
||||
default_factory=FrequencyMaskAugmentationConfig
|
||||
)
|
||||
|
||||
|
||||
def load_agumentation_config(
|
||||
path: data.PathLike, field: Optional[str] = None
|
||||
) -> AugmentationsConfig:
|
||||
return load_config(path, schema=AugmentationsConfig, field=field)
|
||||
|
||||
|
||||
def should_apply(config: BaseAugmentationConfig) -> bool:
|
||||
if not config.enable:
|
||||
return False
|
||||
|
||||
return np.random.uniform() < config.probability
|
||||
|
||||
|
||||
def augment_example(
|
||||
example: xr.Dataset,
|
||||
config: AugmentationsConfig,
|
||||
preprocessing_config: Optional[PreprocessingConfig] = None,
|
||||
others: Optional[Callable[[], xr.Dataset]] = None,
|
||||
) -> xr.Dataset:
|
||||
if should_apply(config.mix) and (others is not None):
|
||||
other = others()
|
||||
example = mix_examples(
|
||||
example,
|
||||
other,
|
||||
min_weight=config.mix.min_weight,
|
||||
max_weight=config.mix.max_weight,
|
||||
config=preprocessing_config,
|
||||
)
|
||||
|
||||
if should_apply(config.echo):
|
||||
example = add_echo(
|
||||
example,
|
||||
max_delay=config.echo.max_delay,
|
||||
min_weight=config.echo.min_weight,
|
||||
max_weight=config.echo.max_weight,
|
||||
config=preprocessing_config,
|
||||
)
|
||||
|
||||
if should_apply(config.volume):
|
||||
example = scale_volume(
|
||||
example,
|
||||
max_scaling=config.volume.max_scaling,
|
||||
min_scaling=config.volume.min_scaling,
|
||||
)
|
||||
|
||||
if should_apply(config.warp):
|
||||
example = warp_spectrogram(
|
||||
example,
|
||||
delta=config.warp.delta,
|
||||
)
|
||||
|
||||
if should_apply(config.time_mask):
|
||||
example = mask_time(
|
||||
example,
|
||||
max_perc=config.time_mask.max_perc,
|
||||
max_mask=config.time_mask.max_masks,
|
||||
)
|
||||
|
||||
if should_apply(config.frequency_mask):
|
||||
example = mask_frequency(
|
||||
example,
|
||||
max_perc=config.frequency_mask.max_perc,
|
||||
max_masks=config.frequency_mask.max_masks,
|
||||
)
|
||||
|
||||
return example
|
||||
|
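A sketch of how these pieces combine when augmenting a single example with a custom configuration (the example path is hypothetical; mixing is skipped here because no `others` provider is passed):

import xarray as xr

config = AugmentationsConfig(
    echo=EchoAugmentationConfig(probability=0.5, max_delay=0.003),
    time_mask=TimeMaskAugmentationConfig(enable=False),
)
example = xr.open_dataset("example.nc")
augmented = augment_example(example, config)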
31
batdetect2/train/config.py
Normal file
@ -0,0 +1,31 @@
from typing import Optional

from pydantic import Field
from soundevent.data import PathLike

from batdetect2.configs import BaseConfig, load_config
from batdetect2.train.losses import LossConfig

__all__ = [
    "OptimizerConfig",
    "TrainingConfig",
    "load_train_config",
]


class OptimizerConfig(BaseConfig):
    learning_rate: float = 1e-3
    t_max: int = 100


class TrainingConfig(BaseConfig):
    batch_size: int = 32
    loss: LossConfig = Field(default_factory=LossConfig)
    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)


def load_train_config(
    path: PathLike,
    field: Optional[str] = None,
) -> TrainingConfig:
    return load_config(path, schema=TrainingConfig, field=field)
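The same configuration can also be built programmatically instead of loading it from a file; a minimal sketch of what `load_train_config` ends up producing:

config = TrainingConfig(
    batch_size=16,
    optimizer=OptimizerConfig(learning_rate=5e-4, t_max=200),
)
print(config.loss.detection.weight)  # nested defaults come from LossConfig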
@ -1,12 +1,21 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, NamedTuple, Optional, Sequence, Union
|
||||
from typing import NamedTuple, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import xarray as xr
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.preprocess.tensors import adjust_width
|
||||
from batdetect2.train.augmentations import (
|
||||
AugmentationsConfig,
|
||||
augment_example,
|
||||
select_subclip,
|
||||
)
|
||||
from batdetect2.train.preprocess import PreprocessingConfig
|
||||
|
||||
__all__ = [
|
||||
@ -26,74 +35,113 @@ class TrainExample(NamedTuple):
|
||||
idx: torch.Tensor
|
||||
|
||||
|
||||
def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
|
||||
return list(Path(directory).glob(f"*{extension}"))
|
||||
class SubclipConfig(BaseConfig):
|
||||
duration: Optional[float] = None
|
||||
width: int = 512
|
||||
random: bool = False
|
||||
|
||||
|
||||
class DatasetConfig(BaseConfig):
|
||||
subclip: SubclipConfig = Field(default_factory=SubclipConfig)
|
||||
augmentation: AugmentationsConfig = Field(
|
||||
default_factory=AugmentationsConfig
|
||||
)
|
||||
|
||||
|
||||
class LabeledDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
filenames: Sequence[PathLike],
|
||||
transform: Optional[Callable[[xr.Dataset], xr.Dataset]] = None,
|
||||
subclip: Optional[SubclipConfig] = None,
|
||||
augmentation: Optional[AugmentationsConfig] = None,
|
||||
preprocessing: Optional[PreprocessingConfig] = None,
|
||||
):
|
||||
self.filenames = filenames
|
||||
self.transform = transform
|
||||
self.subclip = subclip
|
||||
self.augmentation = augmentation
|
||||
self.preprocessing = preprocessing or PreprocessingConfig()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.filenames)
|
||||
|
||||
def __getitem__(self, idx) -> TrainExample:
|
||||
data = self.load(idx)
|
||||
dataset = self.get_dataset(idx)
|
||||
|
||||
if self.subclip:
|
||||
dataset = select_subclip(
|
||||
dataset,
|
||||
duration=self.subclip.duration,
|
||||
width=self.subclip.width,
|
||||
random=self.subclip.random,
|
||||
)
|
||||
|
||||
if self.augmentation:
|
||||
dataset = augment_example(
|
||||
dataset,
|
||||
self.augmentation,
|
||||
preprocessing_config=self.preprocessing,
|
||||
others=self.get_random_example,
|
||||
)
|
||||
|
||||
return TrainExample(
|
||||
spec=data["spectrogram"],
|
||||
detection_heatmap=data["detection"],
|
||||
class_heatmap=data["class"],
|
||||
size_heatmap=data["size"],
|
||||
spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0),
|
||||
detection_heatmap=self.to_tensor(dataset["detection"]),
|
||||
class_heatmap=self.to_tensor(dataset["class"]),
|
||||
size_heatmap=self.to_tensor(dataset["size"]),
|
||||
idx=torch.tensor(idx),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_directory(cls, directory: PathLike, extension: str = ".nc"):
|
||||
return cls(get_files(directory, extension))
|
||||
def from_directory(
|
||||
cls,
|
||||
directory: PathLike,
|
||||
extension: str = ".nc",
|
||||
subclip: Optional[SubclipConfig] = None,
|
||||
augmentation: Optional[AugmentationsConfig] = None,
|
||||
preprocessing: Optional[PreprocessingConfig] = None,
|
||||
):
|
||||
return cls(
|
||||
get_files(directory, extension),
|
||||
subclip=subclip,
|
||||
augmentation=augmentation,
|
||||
preprocessing=preprocessing,
|
||||
)
|
||||
|
||||
def load(self, idx) -> Dict[str, torch.Tensor]:
|
||||
def get_random_example(self) -> xr.Dataset:
|
||||
idx = np.random.randint(0, len(self))
|
||||
dataset = self.get_dataset(idx)
|
||||
return {
|
||||
"spectrogram": torch.tensor(
|
||||
dataset["spectrogram"].values
|
||||
).unsqueeze(0),
|
||||
"detection": torch.tensor(dataset["detection"].values),
|
||||
"class": torch.tensor(dataset["class"].values),
|
||||
"size": torch.tensor(dataset["size"].values),
|
||||
}
|
||||
|
||||
def apply_augmentation(self, dataset: xr.Dataset) -> xr.Dataset:
|
||||
if self.transform is not None:
|
||||
return self.transform(dataset)
|
||||
if self.subclip:
|
||||
dataset = select_subclip(
|
||||
dataset,
|
||||
duration=self.subclip.duration,
|
||||
width=self.subclip.width,
|
||||
random=self.subclip.random,
|
||||
)
|
||||
|
||||
return dataset
|
||||
|
||||
def get_dataset(self, idx):
|
||||
def get_dataset(self, idx) -> xr.Dataset:
|
||||
return xr.open_dataset(self.filenames[idx])
|
||||
|
||||
def get_spectrogram(self, idx):
|
||||
return xr.open_dataset(self.filenames[idx])["spectrogram"]
|
||||
def get_clip_annotation(self, idx) -> data.ClipAnnotation:
|
||||
return data.ClipAnnotation.model_validate_json(
|
||||
self.get_dataset(idx).attrs["clip_annotation"]
|
||||
)
|
||||
|
||||
def get_detection_mask(self, idx):
|
||||
return xr.open_dataset(self.filenames[idx])["detection"]
|
||||
def to_tensor(
|
||||
self,
|
||||
array: xr.DataArray,
|
||||
dtype=np.float32,
|
||||
) -> torch.Tensor:
|
||||
tensor = torch.tensor(array.values.astype(dtype))
|
||||
|
||||
def get_class_mask(self, idx):
|
||||
return xr.open_dataset(self.filenames[idx])["class"]
|
||||
if not self.subclip:
|
||||
return tensor
|
||||
|
||||
def get_size_mask(self, idx):
|
||||
return xr.open_dataset(self.filenames[idx])["size"]
|
||||
width = self.subclip.width
|
||||
return adjust_width(tensor, width)
|
||||
|
||||
def get_clip_annotation(self, idx):
|
||||
filename = self.filenames[idx]
|
||||
dataset = xr.open_dataset(filename)
|
||||
clip_annotation = dataset.attrs["clip_annotation"]
|
||||
return data.ClipAnnotation.model_validate_json(clip_annotation)
|
||||
|
||||
def get_preprocessing_configuration(self, idx):
|
||||
config = xr.open_dataset(self.filenames[idx]).attrs["configuration"]
|
||||
return PreprocessingConfig.model_validate_json(config)
|
||||
def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
|
||||
return list(Path(directory).glob(f"*{extension}"))
|
||||
|
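A usage sketch for the dataset above (the directory of preprocessed `.nc` files is hypothetical); since `TrainExample` is a NamedTuple, the default PyTorch collate function batches its fields directly:

from torch.utils.data import DataLoader

dataset = LabeledDataset.from_directory(
    "preprocessed/",
    subclip=SubclipConfig(width=512, random=True),
    augmentation=AugmentationsConfig(),
)
loader = DataLoader(dataset, batch_size=8, shuffle=True)
batch = next(iter(loader))
print(batch.spec.shape)  # roughly (8, 1, spec_height, 512)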
@ -1,27 +1,43 @@
|
||||
from typing import Tuple
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from pydantic import Field
|
||||
from scipy.ndimage import gaussian_filter
|
||||
from soundevent import data, geometry, arrays
|
||||
from soundevent import arrays, data, geometry
|
||||
from soundevent.geometry.operations import Positions
|
||||
from soundevent.types import ClassMapper
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
|
||||
__all__ = [
|
||||
"ClassMapper",
|
||||
"HeatmapsConfig",
|
||||
"LabelConfig",
|
||||
"generate_heatmaps",
|
||||
"load_label_config",
|
||||
]
|
||||
|
||||
|
||||
TARGET_SIGMA = 3.0
|
||||
class HeatmapsConfig(BaseConfig):
|
||||
position: Positions = "bottom-left"
|
||||
sigma: float = 3.0
|
||||
time_scale: float = 1000.0
|
||||
frequency_scale: float = 1 / 859.375
|
||||
|
||||
|
||||
class LabelConfig(BaseConfig):
|
||||
heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)
|
||||
|
||||
|
||||
def generate_heatmaps(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
sound_events: Sequence[data.SoundEventAnnotation],
|
||||
spec: xr.DataArray,
|
||||
class_mapper: ClassMapper,
|
||||
target_sigma: float = TARGET_SIGMA,
|
||||
class_names: List[str],
|
||||
encoder: Callable[[Iterable[data.Tag]], Optional[str]],
|
||||
target_sigma: float = 3.0,
|
||||
position: Positions = "bottom-left",
|
||||
time_scale: float = 1000.0,
|
||||
frequency_scale: float = 1 / 859.375,
|
||||
dtype=np.float32,
|
||||
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
|
||||
shape = dict(zip(spec.dims, spec.shape))
|
||||
@ -31,20 +47,13 @@ def generate_heatmaps(
|
||||
"Spectrogram must have time and frequency dimensions."
|
||||
)
|
||||
|
||||
time_duration = arrays.get_dim_width(spec, dim="time")
|
||||
freq_bandwidth = arrays.get_dim_width(spec, dim="frequency")
|
||||
|
||||
# Compute the size factors
|
||||
time_scale = 1 / time_duration
|
||||
frequency_scale = 1 / freq_bandwidth
|
||||
|
||||
# Initialize heatmaps
|
||||
detection_heatmap = xr.zeros_like(spec, dtype=dtype)
|
||||
class_heatmap = xr.DataArray(
|
||||
data=np.zeros((class_mapper.num_classes, *spec.shape), dtype=dtype),
|
||||
data=np.zeros((len(class_names), *spec.shape), dtype=dtype),
|
||||
dims=["category", *spec.dims],
|
||||
coords={
|
||||
"category": class_mapper.class_labels,
|
||||
"category": [*class_names],
|
||||
**spec.coords,
|
||||
},
|
||||
)
|
||||
@ -57,9 +66,8 @@ def generate_heatmaps(
|
||||
},
|
||||
)
|
||||
|
||||
for sound_event_annotation in clip_annotation.sound_events:
|
||||
for sound_event_annotation in sound_events:
|
||||
geom = sound_event_annotation.sound_event.geometry
|
||||
|
||||
if geom is None:
|
||||
continue
|
||||
|
||||
@ -67,23 +75,29 @@ def generate_heatmaps(
|
||||
time, frequency = geometry.get_geometry_point(geom, position=position)
|
||||
|
||||
# Set 1.0 at the position of the sound event in the detection heatmap
|
||||
detection_heatmap = arrays.set_value_at_pos(
|
||||
detection_heatmap,
|
||||
1.0,
|
||||
time=time,
|
||||
frequency=frequency,
|
||||
)
|
||||
try:
|
||||
detection_heatmap = arrays.set_value_at_pos(
|
||||
detection_heatmap,
|
||||
1.0,
|
||||
time=time,
|
||||
frequency=frequency,
|
||||
)
|
||||
except KeyError:
|
||||
# Skip the sound event if the position is outside the spectrogram
|
||||
continue
|
||||
|
||||
# Set the size of the sound event at the position in the size heatmap
|
||||
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
|
||||
geom
|
||||
)
|
||||
|
||||
size = np.array(
|
||||
[
|
||||
(end_time - start_time) * time_scale,
|
||||
(high_freq - low_freq) * frequency_scale,
|
||||
]
|
||||
)
|
||||
|
||||
size_heatmap = arrays.set_value_at_pos(
|
||||
size_heatmap,
|
||||
size,
|
||||
@ -92,14 +106,12 @@ def generate_heatmaps(
|
||||
)
|
||||
|
||||
# Get the class name of the sound event
|
||||
class_name = class_mapper.transform(sound_event_annotation)
|
||||
class_name = encoder(sound_event_annotation.tags)
|
||||
|
||||
if class_name is None:
|
||||
# If the label is None skip the sound event
|
||||
continue
|
||||
|
||||
# Set 1.0 at the position and category of the sound event in the class
|
||||
# heatmap
|
||||
class_heatmap = arrays.set_value_at_pos(
|
||||
class_heatmap,
|
||||
1.0,
|
||||
@ -130,3 +142,9 @@ def generate_heatmaps(
|
||||
).fillna(0.0)
|
||||
|
||||
return detection_heatmap, class_heatmap, size_heatmap
|
||||
|
||||
|
||||
def load_label_config(
|
||||
path: data.PathLike, field: Optional[str] = None
|
||||
) -> LabelConfig:
|
||||
return load_config(path, schema=LabelConfig, field=field)
|
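For intuition on the default scales above: sizes are stored as (duration × time_scale, bandwidth × frequency_scale), so with the defaults a 5 ms call spanning 20 kHz is encoded as roughly (5.0, 23.3) in the size heatmap:

duration_s = 0.005
bandwidth_hz = 20_000
size = (duration_s * 1000.0, bandwidth_hz * (1 / 859.375))
print(size)  # (5.0, 23.27...)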
82
batdetect2/train/legacy/train.py
Normal file
@ -0,0 +1,82 @@
from typing import Callable, NamedTuple, Optional

import torch
from soundevent import data
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from batdetect2.models.typing import DetectionModel
from batdetect2.train.dataset import LabeledDataset


class TrainInputs(NamedTuple):
    spec: torch.Tensor
    detection_heatmap: torch.Tensor
    class_heatmap: torch.Tensor
    size_heatmap: torch.Tensor


def train_loop(
    model: DetectionModel,
    train_dataset: LabeledDataset[TrainInputs],
    validation_dataset: LabeledDataset[TrainInputs],
    device: Optional[torch.device] = None,
    num_epochs: int = 100,
    learning_rate: float = 1e-4,
):
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    validation_loader = DataLoader(validation_dataset, batch_size=32)

    model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = CosineAnnealingLR(
        optimizer,
        num_epochs * len(train_loader),
    )

    for epoch in range(num_epochs):
        train_loss = train_single_epoch(
            model,
            train_loader,
            optimizer,
            device,
            scheduler,
        )


def train_single_epoch(
    model: DetectionModel,
    train_loader: DataLoader,
    optimizer: Adam,
    device: torch.device,
    scheduler: CosineAnnealingLR,
):
    model.train()
    train_loss = tu.AverageMeter()

    for batch in train_loader:
        optimizer.zero_grad()

        spec = batch.spec.to(device)
        detection_heatmap = batch.detection_heatmap.to(device)
        class_heatmap = batch.class_heatmap.to(device)
        size_heatmap = batch.size_heatmap.to(device)

        outputs = model(spec)

        loss = loss_fun(
            outputs,
            gt_det,
            gt_size,
            gt_class,
            det_criterion,
            params,
            class_inv_freq,
        )

        train_loss.update(loss.item(), data.shape[0])
        loss.backward()
        optimizer.step()
        scheduler.step()
|
@ -16,7 +16,6 @@ def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):


def split_diff(ann_dir, wav_dir, load_extra=True):

    train_sets = []
    if load_extra:
        train_sets.append(
@ -144,7 +143,6 @@ def split_diff(ann_dir, wav_dir, load_extra=True):


def split_same(ann_dir, wav_dir, load_extra=True):

    train_sets = []
    if load_extra:
        train_sets.append(
@ -69,7 +69,8 @@ def get_genus_mapping(class_names: List[str]) -> Tuple[List[str], List[int]]:


def standardize_low_freq(
    data: List[types.FileAnnotation], class_of_interest: str,
    data: List[types.FileAnnotation],
    class_of_interest: str,
) -> List[types.FileAnnotation]:
    # address the issue of highly variable low frequency annotations
    # this often happens for constant frequency calls
|
@ -1,56 +0,0 @@
import pytorch_lightning as L
from torch import Tensor, optim

from batdetect2.models.typing import DetectionModel, ModelOutput
from batdetect2.train import losses
from batdetect2.train.dataset import TrainExample

__all__ = [
    "LitDetectorModel",
]


class LitDetectorModel(L.LightningModule):
    model: DetectionModel

    def __init__(self, model: DetectionModel, learning_rate: float = 1e-3):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate

    def compute_loss(
        self,
        outputs: ModelOutput,
        batch: TrainExample,
    ) -> Tensor:
        detection_loss = losses.focal_loss(
            outputs.detection_probs,
            batch.detection_heatmap,
        )

        size_loss = losses.bbox_size_loss(
            outputs.size_preds,
            batch.size_heatmap,
        )

        valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
        classification_loss = losses.focal_loss(
            outputs.class_probs,
            batch.class_heatmap,
            valid_mask=valid_mask,
        )

        return detection_loss + size_loss + classification_loss

    def training_step(self, batch: TrainExample, batch_idx: int):  # type: ignore
        outputs: ModelOutput = self.model(batch.spec)
        loss = self.compute_loss(outputs, batch)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
        return [optimizer], [scheduler]
|
@ -1,7 +1,23 @@
|
||||
from typing import Optional
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pydantic import Field
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.models.typing import ModelOutput
|
||||
from batdetect2.train.dataset import TrainExample
|
||||
|
||||
__all__ = [
|
||||
"bbox_size_loss",
|
||||
"compute_loss",
|
||||
"focal_loss",
|
||||
"mse_loss",
|
||||
]
|
||||
|
||||
|
||||
class SizeLossConfig(BaseConfig):
|
||||
weight: float = 0.1
|
||||
|
||||
|
||||
def bbox_size_loss(
|
||||
@ -17,6 +33,11 @@ def bbox_size_loss(
|
||||
)
|
||||
|
||||
|
||||
class FocalLossConfig(BaseConfig):
|
||||
beta: float = 4
|
||||
alpha: float = 2
|
||||
|
||||
|
||||
def focal_loss(
|
||||
pred: torch.Tensor,
|
||||
gt: torch.Tensor,
|
||||
@ -44,7 +65,7 @@ def focal_loss(
|
||||
)
|
||||
|
||||
if weights is not None:
|
||||
pos_loss = pos_loss * weights
|
||||
pos_loss = pos_loss * torch.tensor(weights)
|
||||
# neg_loss = neg_loss*weights
|
||||
|
||||
if valid_mask is not None:
|
||||
@ -75,3 +96,71 @@ def mse_loss(
|
||||
else:
|
||||
op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
|
||||
return op
|
||||
|
||||
|
||||
class DetectionLossConfig(BaseConfig):
|
||||
weight: float = 1.0
|
||||
focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
|
||||
|
||||
|
||||
class ClassificationLossConfig(BaseConfig):
|
||||
weight: float = 2.0
|
||||
focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
|
||||
class_weights: Optional[list[float]] = None
|
||||
|
||||
|
||||
class LossConfig(BaseConfig):
|
||||
detection: DetectionLossConfig = Field(default_factory=DetectionLossConfig)
|
||||
size: SizeLossConfig = Field(default_factory=SizeLossConfig)
|
||||
classification: ClassificationLossConfig = Field(
|
||||
default_factory=ClassificationLossConfig
|
||||
)
|
||||
|
||||
|
||||
class Losses(NamedTuple):
|
||||
detection: torch.Tensor
|
||||
size: torch.Tensor
|
||||
classification: torch.Tensor
|
||||
total: torch.Tensor
|
||||
|
||||
|
||||
def compute_loss(
|
||||
batch: TrainExample,
|
||||
outputs: ModelOutput,
|
||||
conf: LossConfig,
|
||||
class_weights: Optional[torch.Tensor] = None,
|
||||
) -> Losses:
|
||||
detection_loss = focal_loss(
|
||||
outputs.detection_probs,
|
||||
batch.detection_heatmap,
|
||||
beta=conf.detection.focal.beta,
|
||||
alpha=conf.detection.focal.alpha,
|
||||
)
|
||||
|
||||
size_loss = bbox_size_loss(
|
||||
outputs.size_preds,
|
||||
batch.size_heatmap,
|
||||
)
|
||||
|
||||
valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
|
||||
classification_loss = focal_loss(
|
||||
outputs.class_probs,
|
||||
batch.class_heatmap,
|
||||
weights=class_weights,
|
||||
valid_mask=valid_mask,
|
||||
beta=conf.classification.focal.beta,
|
||||
alpha=conf.classification.focal.alpha,
|
||||
)
|
||||
|
||||
total = (
|
||||
detection_loss * conf.detection.weight
|
||||
+ size_loss * conf.size.weight
|
||||
+ classification_loss * conf.classification.weight
|
||||
)
|
||||
|
||||
return Losses(
|
||||
detection=detection_loss,
|
||||
size=size_loss,
|
||||
classification=classification_loss,
|
||||
total=total,
|
||||
)
|
||||
|
@ -1,20 +1,28 @@
|
||||
"""Module for preprocessing data for training."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from functools import partial
|
||||
from multiprocessing import Pool
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Sequence, Union
|
||||
from tqdm.auto import tqdm
|
||||
from multiprocessing import Pool
|
||||
|
||||
import xarray as xr
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from batdetect2.data.labels import TARGET_SIGMA, ClassMapper, generate_heatmaps
|
||||
from batdetect2.data.preprocessing import (
|
||||
preprocess_audio_clip,
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.preprocess import (
|
||||
PreprocessingConfig,
|
||||
compute_spectrogram,
|
||||
load_clip_audio,
|
||||
)
|
||||
from batdetect2.train.labels import LabelConfig, generate_heatmaps
|
||||
from batdetect2.train.targets import (
|
||||
TargetConfig,
|
||||
build_encoder,
|
||||
build_sound_event_filter,
|
||||
get_class_names,
|
||||
)
|
||||
|
||||
PathLike = Union[Path, str, os.PathLike]
|
||||
@ -22,31 +30,76 @@ FilenameFn = Callable[[data.ClipAnnotation], str]
|
||||
|
||||
__all__ = [
|
||||
"preprocess_annotations",
|
||||
"preprocess_single_annotation",
|
||||
"generate_train_example",
|
||||
"TrainPreprocessingConfig",
|
||||
]
|
||||
|
||||
|
||||
class TrainPreprocessingConfig(BaseConfig):
|
||||
preprocessing: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
target: TargetConfig = Field(default_factory=TargetConfig)
|
||||
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||
|
||||
|
||||
def generate_train_example(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
class_mapper: ClassMapper,
|
||||
preprocessing_config: PreprocessingConfig = PreprocessingConfig(),
|
||||
target_sigma: float = TARGET_SIGMA,
|
||||
preprocessing_config: Optional[PreprocessingConfig] = None,
|
||||
target_config: Optional[TargetConfig] = None,
|
||||
label_config: Optional[LabelConfig] = None,
|
||||
) -> xr.Dataset:
|
||||
"""Generate a training example."""
|
||||
spectrogram = preprocess_audio_clip(
|
||||
clip_annotation.clip,
|
||||
config=preprocessing_config,
|
||||
config = TrainPreprocessingConfig(
|
||||
preprocessing=preprocessing_config or PreprocessingConfig(),
|
||||
target=target_config or TargetConfig(),
|
||||
labels=label_config or LabelConfig(),
|
||||
)
|
||||
|
||||
wave = load_clip_audio(
|
||||
clip_annotation.clip,
|
||||
config=config.preprocessing.audio,
|
||||
)
|
||||
|
||||
spectrogram = compute_spectrogram(
|
||||
wave,
|
||||
config=config.preprocessing.spectrogram,
|
||||
)
|
||||
|
||||
filter_fn = build_sound_event_filter(
|
||||
include=config.target.include,
|
||||
exclude=config.target.exclude,
|
||||
)
|
||||
|
||||
selected_events = [
|
||||
event for event in clip_annotation.sound_events if filter_fn(event)
|
||||
]
|
||||
|
||||
encoder = build_encoder(
|
||||
config.target.classes,
|
||||
replacement_rules=config.target.replace,
|
||||
)
|
||||
class_names = get_class_names(config.target.classes)
|
||||
|
||||
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
||||
clip_annotation,
|
||||
selected_events,
|
||||
spectrogram,
|
||||
class_mapper,
|
||||
target_sigma=target_sigma,
|
||||
class_names,
|
||||
encoder,
|
||||
target_sigma=config.labels.heatmaps.sigma,
|
||||
position=config.labels.heatmaps.position,
|
||||
time_scale=config.labels.heatmaps.time_scale,
|
||||
frequency_scale=config.labels.heatmaps.frequency_scale,
|
||||
)
|
||||
|
||||
dataset = xr.Dataset(
|
||||
{
|
||||
# NOTE: Need to rename the time dimension to avoid conflicts with
|
||||
# the spectrogram time dimension, otherwise xarray will interpolate
|
||||
# the spectrogram and the heatmaps to the same temporal resolution
|
||||
# as the waveform.
|
||||
"audio": wave.rename({"time": "audio_time"}),
|
||||
"spectrogram": spectrogram,
|
||||
"detection": detection_heatmap,
|
||||
"class": class_heatmap,
|
||||
@ -56,9 +109,12 @@ def generate_train_example(
|
||||
|
||||
return dataset.assign_attrs(
|
||||
title=f"Training example for {clip_annotation.uuid}",
|
||||
preprocessing_configuration=preprocessing_config.model_dump_json(),
|
||||
target_sigma=target_sigma,
|
||||
clip_annotation=clip_annotation.model_dump_json(),
|
||||
config=config.model_dump_json(),
|
||||
clip_annotation=clip_annotation.model_dump_json(
|
||||
exclude_none=True,
|
||||
exclude_defaults=True,
|
||||
exclude_unset=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -77,36 +133,54 @@ def save_to_file(
|
||||
)
|
||||
|
||||
|
||||
def load_config(path: PathLike, **kwargs) -> PreprocessingConfig:
|
||||
"""Load configuration from file."""
|
||||
|
||||
path = Path(path)
|
||||
|
||||
if not path.is_file():
|
||||
warnings.warn(f"Config file not found: {path}. Using default config.")
|
||||
return PreprocessingConfig(**kwargs)
|
||||
|
||||
try:
|
||||
return PreprocessingConfig.model_validate_json(path.read_text())
|
||||
except ValueError as e:
|
||||
warnings.warn(
|
||||
f"Failed to load config file: {e}. Using default config."
|
||||
)
|
||||
return PreprocessingConfig(**kwargs)
|
||||
|
||||
|
||||
def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
|
||||
return f"{clip_annotation.uuid}.nc"
|
||||
|
||||
|
||||
def preprocess_annotations(
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
output_dir: PathLike,
|
||||
filename_fn: FilenameFn = _get_filename,
|
||||
replace: bool = False,
|
||||
preprocessing_config: Optional[PreprocessingConfig] = None,
|
||||
target_config: Optional[TargetConfig] = None,
|
||||
label_config: Optional[LabelConfig] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Preprocess annotations and save to disk."""
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
if not output_dir.is_dir():
|
||||
output_dir.mkdir(parents=True)
|
||||
|
||||
with Pool(max_workers) as pool:
|
||||
list(
|
||||
tqdm(
|
||||
pool.imap_unordered(
|
||||
partial(
|
||||
preprocess_single_annotation,
|
||||
output_dir=output_dir,
|
||||
filename_fn=filename_fn,
|
||||
replace=replace,
|
||||
preprocessing_config=preprocessing_config,
|
||||
target_config=target_config,
|
||||
label_config=label_config,
|
||||
),
|
||||
clip_annotations,
|
||||
),
|
||||
total=len(clip_annotations),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def preprocess_single_annotation(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
output_dir: PathLike,
|
||||
config: PreprocessingConfig,
|
||||
class_mapper: ClassMapper,
|
||||
preprocessing_config: Optional[PreprocessingConfig] = None,
|
||||
target_config: Optional[TargetConfig] = None,
|
||||
label_config: Optional[LabelConfig] = None,
|
||||
filename_fn: FilenameFn = _get_filename,
|
||||
replace: bool = False,
|
||||
target_sigma: float = TARGET_SIGMA,
|
||||
) -> None:
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
@ -119,50 +193,16 @@ def preprocess_single_annotation(
|
||||
if path.is_file() and replace:
|
||||
path.unlink()
|
||||
|
||||
sample = generate_train_example(
|
||||
clip_annotation,
|
||||
class_mapper,
|
||||
preprocessing_config=config,
|
||||
target_sigma=target_sigma,
|
||||
)
|
||||
try:
|
||||
sample = generate_train_example(
|
||||
clip_annotation,
|
||||
preprocessing_config=preprocessing_config,
|
||||
target_config=target_config,
|
||||
label_config=label_config,
|
||||
)
|
||||
except Exception as error:
|
||||
raise RuntimeError(
|
||||
f"Failed to process annotation: {clip_annotation.uuid}"
|
||||
) from error
|
||||
|
||||
save_to_file(sample, path)
|
||||
|
||||
|
||||
def preprocess_annotations(
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
output_dir: PathLike,
|
||||
class_mapper: ClassMapper,
|
||||
target_sigma: float = TARGET_SIGMA,
|
||||
filename_fn: FilenameFn = _get_filename,
|
||||
replace: bool = False,
|
||||
config: Optional[PreprocessingConfig] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Preprocess annotations and save to disk."""
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
if config is None:
|
||||
config = PreprocessingConfig()
|
||||
|
||||
if not output_dir.is_dir():
|
||||
output_dir.mkdir(parents=True)
|
||||
|
||||
with Pool(max_workers) as pool:
|
||||
list(
|
||||
tqdm(
|
||||
pool.imap_unordered(
|
||||
partial(
|
||||
preprocess_single_annotation,
|
||||
output_dir=output_dir,
|
||||
config=config,
|
||||
class_mapper=class_mapper,
|
||||
filename_fn=filename_fn,
|
||||
replace=replace,
|
||||
target_sigma=target_sigma,
|
||||
),
|
||||
clip_annotations,
|
||||
),
|
||||
total=len(clip_annotations),
|
||||
)
|
||||
)
|
||||
|
181
batdetect2/train/targets.py
Normal file
181
batdetect2/train/targets.py
Normal file
@ -0,0 +1,181 @@
|
||||
from collections.abc import Iterable
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Set
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.terms import TagInfo, get_tag_from_info
|
||||
|
||||
__all__ = [
|
||||
"TargetConfig",
|
||||
"load_target_config",
|
||||
"build_encoder",
|
||||
"build_decoder",
|
||||
"filter_sound_event",
|
||||
]
|
||||
|
||||
|
||||
class ReplaceConfig(BaseConfig):
|
||||
"""Configuration for replacing tags."""
|
||||
|
||||
original: TagInfo
|
||||
replacement: TagInfo
|
||||
|
||||
|
||||
class TargetConfig(BaseConfig):
|
||||
"""Configuration for target generation."""
|
||||
|
||||
classes: List[TagInfo] = Field(
|
||||
default_factory=lambda: [
|
||||
TagInfo(key="class", value=value) for value in DEFAULT_SPECIES_LIST
|
||||
]
|
||||
)
|
||||
generic_class: Optional[TagInfo] = Field(
|
||||
default_factory=lambda: TagInfo(key="class", value="Bat")
|
||||
)
|
||||
|
||||
include: Optional[List[TagInfo]] = Field(
|
||||
default_factory=lambda: [TagInfo(key="event", value="Echolocation")]
|
||||
)
|
||||
|
||||
exclude: Optional[List[TagInfo]] = Field(
|
||||
default_factory=lambda: [
|
||||
TagInfo(key="class", value=""),
|
||||
TagInfo(key="class", value=" "),
|
||||
TagInfo(key="class", value="Unknown"),
|
||||
]
|
||||
)
|
||||
|
||||
replace: Optional[List[ReplaceConfig]] = None
|
||||
|
||||
|
||||
def build_sound_event_filter(
|
||||
include: Optional[List[TagInfo]] = None,
|
||||
exclude: Optional[List[TagInfo]] = None,
|
||||
) -> Callable[[data.SoundEventAnnotation], bool]:
|
||||
include_tags = (
|
||||
{get_tag_from_info(tag) for tag in include} if include else None
|
||||
)
|
||||
exclude_tags = (
|
||||
{get_tag_from_info(tag) for tag in exclude} if exclude else None
|
||||
)
|
||||
return partial(
|
||||
filter_sound_event,
|
||||
include=include_tags,
|
||||
exclude=exclude_tags,
|
||||
)
|
||||
|
||||
|
||||
def get_tag_label(tag_info: TagInfo) -> str:
|
||||
return tag_info.label if tag_info.label else tag_info.value
|
||||
|
||||
|
||||
def get_class_names(classes: List[TagInfo]) -> List[str]:
|
||||
return sorted({get_tag_label(tag) for tag in classes})
|
||||
|
||||
|
||||
def build_replacer(
|
||||
rules: List[ReplaceConfig],
|
||||
) -> Callable[[data.Tag], data.Tag]:
|
||||
mapping = {
|
||||
get_tag_from_info(rule.original): get_tag_from_info(rule.replacement)
|
||||
for rule in rules
|
||||
}
|
||||
|
||||
def replacer(tag: data.Tag) -> data.Tag:
|
||||
return mapping.get(tag, tag)
|
||||
|
||||
return replacer
|
||||
|
||||
|
||||
def build_encoder(
|
||||
classes: List[TagInfo],
|
||||
replacement_rules: Optional[List[ReplaceConfig]] = None,
|
||||
) -> Callable[[Iterable[data.Tag]], Optional[str]]:
|
||||
target_tags = set([get_tag_from_info(tag) for tag in classes])
|
||||
|
||||
tag_mapping = {
|
||||
tag: get_tag_label(tag_info)
|
||||
for tag, tag_info in zip(target_tags, classes)
|
||||
}
|
||||
|
||||
replacer = (
|
||||
build_replacer(replacement_rules) if replacement_rules else lambda x: x
|
||||
)
|
||||
|
||||
def encoder(
|
||||
tags: Iterable[data.Tag],
|
||||
) -> Optional[str]:
|
||||
sanitized_tags = {replacer(tag) for tag in tags}
|
||||
|
||||
intersection = sanitized_tags & target_tags
|
||||
|
||||
if not intersection:
|
||||
return None
|
||||
|
||||
first = intersection.pop()
|
||||
return tag_mapping[first]
|
||||
|
||||
return encoder
|
||||
|
||||
|
||||
def build_decoder(
|
||||
classes: List[TagInfo],
|
||||
) -> Callable[[str], List[data.Tag]]:
|
||||
target_tags = set([get_tag_from_info(tag) for tag in classes])
|
||||
tag_mapping = {
|
||||
get_tag_label(tag_info): tag
|
||||
for tag, tag_info in zip(target_tags, classes)
|
||||
}
|
||||
|
||||
def decoder(label: str) -> List[data.Tag]:
|
||||
tag = tag_mapping.get(label)
|
||||
return [tag] if tag else []
|
||||
|
||||
return decoder
|
||||
|
||||
|
||||
def filter_sound_event(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
include: Optional[Set[data.Tag]] = None,
|
||||
exclude: Optional[Set[data.Tag]] = None,
|
||||
) -> bool:
|
||||
tags = set(sound_event_annotation.tags)
|
||||
|
||||
if include is not None and not tags & include:
|
||||
return False
|
||||
|
||||
if exclude is not None and tags & exclude:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def load_target_config(
|
||||
path: Path, field: Optional[str] = None
|
||||
) -> TargetConfig:
|
||||
return load_config(path, schema=TargetConfig, field=field)
|
||||
|
||||
|
||||
DEFAULT_SPECIES_LIST = [
|
||||
"Barbastellus barbastellus",
|
||||
"Eptesicus serotinus",
|
||||
"Myotis alcathoe",
|
||||
"Myotis bechsteinii",
|
||||
"Myotis brandtii",
|
||||
"Myotis daubentonii",
|
||||
"Myotis mystacinus",
|
||||
"Myotis nattereri",
|
||||
"Nyctalus leisleri",
|
||||
"Nyctalus noctula",
|
||||
"Pipistrellus nathusii",
|
||||
"Pipistrellus pipistrellus",
|
||||
"Pipistrellus pygmaeus",
|
||||
"Plecotus auritus",
|
||||
"Plecotus austriacus",
|
||||
"Rhinolophus ferrumequinum",
|
||||
"Rhinolophus hipposideros",
|
||||
]
|
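A small sketch of the encoder/decoder round trip with two of the default classes (tag keys follow the `TagInfo` defaults above):

classes = [
    TagInfo(key="class", value="Myotis daubentonii"),
    TagInfo(key="class", value="Pipistrellus pipistrellus"),
]
encoder = build_encoder(classes)
decoder = build_decoder(classes)

tags = [get_tag_from_info(TagInfo(key="class", value="Myotis daubentonii"))]
label = encoder(tags)        # -> "Myotis daubentonii"
print(decoder(label or ""))  # -> a list containing the matching data.Tag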
@@ -1,82 +1,68 @@
from typing import Callable, NamedTuple, Optional
from typing import Optional, Union

import torch
from soundevent import data
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from lightning import LightningModule
from lightning.pytorch import Trainer
from soundevent.data import PathLike
from torch.utils.data import DataLoader

from batdetect2.data.datasets import ClipAnnotationDataset
from batdetect2.models.typing import DetectionModel
from batdetect2.configs import BaseConfig, load_config
from batdetect2.train.dataset import LabeledDataset

__all__ = [
    "train",
    "TrainerConfig",
    "load_trainer_config",
]


class TrainInputs(NamedTuple):
    spec: torch.Tensor
    detection_heatmap: torch.Tensor
    class_heatmap: torch.Tensor
    size_heatmap: torch.Tensor
class TrainerConfig(BaseConfig):
    accelerator: str = "auto"
    accumulate_grad_batches: int = 1
    deterministic: bool = True
    check_val_every_n_epoch: int = 1
    devices: Union[str, int] = "auto"
    enable_checkpointing: bool = True
    gradient_clip_val: Optional[float] = None
    limit_train_batches: Optional[Union[int, float]] = None
    limit_test_batches: Optional[Union[int, float]] = None
    limit_val_batches: Optional[Union[int, float]] = None
    log_every_n_steps: Optional[int] = None
    max_epochs: Optional[int] = None
    min_epochs: Optional[int] = 100
    max_steps: Optional[int] = None
    min_steps: Optional[int] = None
    max_time: Optional[str] = None
    precision: Optional[str] = None
    reload_dataloaders_every_n_epochs: Optional[int] = None
    val_check_interval: Optional[Union[int, float]] = None


def train_loop(
    model: DetectionModel,
    train_dataset: ClipAnnotationDataset[TrainInputs],
    validation_dataset: ClipAnnotationDataset[TrainInputs],
    device: Optional[torch.device] = None,
    num_epochs: int = 100,
    learning_rate: float = 1e-4,
def load_trainer_config(path: PathLike, field: Optional[str] = None):
    return load_config(path, schema=TrainerConfig, field=field)


def train(
    module: LightningModule,
    train_dataset: LabeledDataset,
    trainer_config: Optional[TrainerConfig] = None,
    dev_run: bool = False,
    overfit_batches: bool = False,
    profiler: Optional[str] = None,
):
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    validation_loader = DataLoader(validation_dataset, batch_size=32)

    model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = CosineAnnealingLR(
        optimizer,
        num_epochs * len(train_loader),
    trainer_config = trainer_config or TrainerConfig()
    trainer = Trainer(
        **trainer_config.model_dump(
            exclude_unset=True,
            exclude_none=True,
        ),
        fast_dev_run=dev_run,
        overfit_batches=overfit_batches,
        profiler=profiler,
    )

    for epoch in range(num_epochs):
        train_loss = train_single_epoch(
            model,
            train_loader,
            optimizer,
            device,
            scheduler,
        )


def train_single_epoch(
    model: DetectionModel,
    train_loader: DataLoader,
    optimizer: Adam,
    device: torch.device,
    scheduler: CosineAnnealingLR,
):
    model.train()
    train_loss = tu.AverageMeter()

    for batch in train_loader:
        optimizer.zero_grad()

        spec = batch.spec.to(device)
        detection_heatmap = batch.detection_heatmap.to(device)
        class_heatmap = batch.class_heatmap.to(device)
        size_heatmap = batch.size_heatmap.to(device)

        outputs = model(spec)

        loss = loss_fun(
            outputs,
            gt_det,
            gt_size,
            gt_class,
            det_criterion,
            params,
            class_inv_freq,
        )

        train_loss.update(loss.item(), data.shape[0])
        loss.backward()
        optimizer.step()
        scheduler.step()
    train_loader = DataLoader(
        train_dataset,
        batch_size=module.config.train.batch_size,
        shuffle=True,
        num_workers=7,
    )
    trainer.fit(module, train_dataloaders=train_loader)
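# --- Illustrative usage (not part of the changeset) ---
# Minimal sketch of driving the new Lightning-based entry point. The module
# and dataset construction are placeholders; only `train`, `TrainerConfig`
# and `load_trainer_config` come from the code above:
#
#     trainer_config = load_trainer_config("trainer.yaml")   # or TrainerConfig(max_epochs=10)
#     train(
#         module=detector_module,         # a LightningModule wrapping the detector
#         train_dataset=labeled_dataset,  # a LabeledDataset of preprocessed clips
#         trainer_config=trainer_config,
#         dev_run=False,
#     )
#
# `dev_run`, `overfit_batches` and `profiler` are forwarded to the Lightning
# Trainer, while the TrainerConfig fields are expanded into Trainer keyword
# arguments via model_dump(exclude_unset=True, exclude_none=True).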
@@ -1,13 +1,12 @@
"""Types used in the code base."""

from typing import Any, List, NamedTuple, Optional

import numpy as np
import torch

try:
    from typing import TypedDict
except ImportError:
    from typing_extensions import TypedDict
from typing import TypedDict


try:
@@ -15,9 +14,8 @@ try:
except ImportError:
    from typing_extensions import Protocol


try:
    from typing import NotRequired
    from typing import NotRequired  # type: ignore
except ImportError:
    from typing_extensions import NotRequired

@@ -597,8 +595,7 @@ class FeatureExtractor(Protocol):
        self,
        prediction: Prediction,
        **kwargs: Any,
    ) -> float:
        ...
    ) -> float: ...


class DatasetDict(TypedDict):
@@ -6,6 +6,8 @@ import librosa.core.spectrum
import numpy as np
import torch

from batdetect2.detector import parameters

from . import wavfile

__all__ = [
@@ -15,20 +17,44 @@ __all__ = [
]


def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
    nfft = np.floor(fft_win_length * sampling_rate)  # int() uses floor
    noverlap = np.floor(fft_overlap * nfft)
    return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)
def time_to_x_coords(
    time_in_file: float,
    samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
    window_duration: float = parameters.FFT_WIN_LENGTH_S,
    window_overlap: float = parameters.FFT_OVERLAP,
) -> float:
    nfft = np.floor(window_duration * samplerate)  # int() uses floor
    noverlap = np.floor(window_overlap * nfft)
    return (time_in_file * samplerate - noverlap) / (nfft - noverlap)


# NOTE this is also defined in post_process
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
    nfft = np.floor(fft_win_length * sampling_rate)
    noverlap = np.floor(fft_overlap * nfft)
    return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
def x_coords_to_time(
    x_pos: int,
    samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
    window_duration: float = parameters.FFT_WIN_LENGTH_S,
    window_overlap: float = parameters.FFT_OVERLAP,
) -> float:
    n_fft = np.floor(window_duration * samplerate)
    n_overlap = np.floor(window_overlap * n_fft)
    n_step = n_fft - n_overlap
    return ((x_pos * n_step) + n_overlap) / samplerate
    # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5)  # 0.5 is for center of temporal window


def x_coord_to_sample(
    x_pos: int,
    samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
    window_duration: float = parameters.FFT_WIN_LENGTH_S,
    window_overlap: float = parameters.FFT_OVERLAP,
    resize_factor: float = parameters.RESIZE_FACTOR,
) -> int:
    n_fft = np.floor(window_duration * samplerate)
    n_overlap = np.floor(window_overlap * n_fft)
    n_step = n_fft - n_overlap
    x_pos = int(x_pos / resize_factor)
    return int((x_pos * n_step) + n_overlap)
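# --- Worked example (not part of the changeset) ---
# Round trip between time and spectrogram x-coordinate, assuming the package
# defaults of a 256 kHz target sample rate, a 512-sample FFT window
# (FFT_WIN_LENGTH_S = 512 / 256000) and 75% window overlap:
#
#     nfft = 512, noverlap = 384, step = nfft - noverlap = 128
#     time_to_x_coords(0.01)  ->  (0.01 * 256000 - 384) / 128 = 17.0
#     x_coords_to_time(17)    ->  (17 * 128 + 384) / 256000   = 0.01 s
#
# so the two functions invert each other exactly on frame boundaries.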


def generate_spectrogram(
    audio,
    sampling_rate,
@@ -64,7 +90,7 @@ def generate_spectrogram(
                np.abs(
                    np.hanning(
                        int(params["fft_win_length"] * sampling_rate)
                    )
                ).astype(np.float32)
            )
            ** 2
        ).sum()
@@ -74,7 +100,7 @@ def generate_spectrogram(
        # log_scaling = (1.0 / sampling_rate)*10e4
        spec = np.log1p(log_scaling * spec)
    elif params["spec_scale"] == "pcen":
        spec = pcen(spec , sampling_rate)
        spec = pcen(spec, sampling_rate)

    elif params["spec_scale"] == "none":
        pass
@@ -194,55 +220,118 @@ def load_audio(
    return sampling_rate, audio_raw


def compute_spectrogram_width(
    length: int,
    samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
    window_duration: float = parameters.FFT_WIN_LENGTH_S,
    window_overlap: float = parameters.FFT_OVERLAP,
    resize_factor: float = parameters.RESIZE_FACTOR,
) -> int:
    n_fft = int(window_duration * samplerate)
    n_overlap = int(window_overlap * n_fft)
    n_step = n_fft - n_overlap
    width = (length - n_overlap) // n_step
    return int(width * resize_factor)
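# --- Worked example (not part of the changeset) ---
# compute_spectrogram_width with the assumed package defaults
# (256 kHz, 512-sample window, 75% overlap, resize_factor = 0.5):
#
#     length = 256000                       # one second of audio
#     n_fft = 512, n_overlap = 384, n_step = 128
#     width = (256000 - 384) // 128         # = 1997 STFT frames
#     int(1997 * 0.5)                       # -> 998 spectrogram columns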
def pad_audio(
    audio_raw,
    fs,
    ms,
    overlap_perc,
    resize_factor,
    divide_factor,
    fixed_width=None,
    audio: np.ndarray,
    samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
    window_duration: float = parameters.FFT_WIN_LENGTH_S,
    window_overlap: float = parameters.FFT_OVERLAP,
    resize_factor: float = parameters.RESIZE_FACTOR,
    divide_factor: int = parameters.SPEC_DIVIDE_FACTOR,
    fixed_width: Optional[int] = None,
):
    # Adds zeros to the end of the raw data so that the generated sepctrogram
    # will be evenly divisible by `divide_factor`
    # Also deals with very short audio clips and fixed_width during training
    """Pad audio to be evenly divisible by `divide_factor`.

    # This code could be clearer, clean up
    nfft = int(ms * fs)
    noverlap = int(overlap_perc * nfft)
    step = nfft - noverlap
    min_size = int(divide_factor * (1.0 / resize_factor))
    spec_width = (audio_raw.shape[0] - noverlap) // step
    spec_width_rs = spec_width * resize_factor
    This function pads the audio signal with zeros to ensure that the
    generated spectrogram length will be evenly divisible by `divide_factor`.
    This is important for the model to work correctly.

    if fixed_width is not None and spec_width < fixed_width:
        # too small
        # used during training to ensure all the batches are the same size
        diff = fixed_width * step + noverlap - audio_raw.shape[0]
        audio_raw = np.hstack(
            (audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
    This `divide_factor` comes from the model architecture as it downscales
    the spectrogram by this factor, so the input must be divisible by this
    integer number.

    Parameters
    ----------
    audio : np.ndarray
        The audio signal.
    samplerate : int
        The sampling rate of the audio signal.
    window_size : float
        The window size in seconds used for the spectrogram computation.
    window_overlap : float
        The overlap between windows in the spectrogram computation.
    resize_factor : float
        This factor is used to resize the spectrogram after the STFT
        computation. Default is 0.5 which means that the spectrogram will be
        reduced by half. Important to take into account for the final size of
        the spectrogram.
    divide_factor : int
        The factor by which the spectrogram will be divided.
    fixed_width : int, optional
        If provided, the audio will be padded or cut so that the resulting
        spectrogram width will be equal to this value.

    Returns
    -------
    np.ndarray
        The padded audio signal.
    """
    spec_width = compute_spectrogram_width(
        audio.shape[0],
        samplerate=samplerate,
        window_duration=window_duration,
        window_overlap=window_overlap,
        resize_factor=resize_factor,
    )

    if fixed_width:
        target_samples = x_coord_to_sample(
            fixed_width,
            samplerate=samplerate,
            window_duration=window_duration,
            window_overlap=window_overlap,
            resize_factor=resize_factor,
        )

    elif fixed_width is not None and spec_width > fixed_width:
        # too big
        # used during training to ensure all the batches are the same size
        diff = fixed_width * step + noverlap - audio_raw.shape[0]
        audio_raw = audio_raw[:diff]
        if spec_width < fixed_width:
            # need to be at least min_size
            diff = target_samples - audio.shape[0]
            return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))

    elif (
        spec_width_rs < min_size
        or (np.floor(spec_width_rs) % divide_factor) != 0
    ):
        # need to be at least min_size
        div_amt = np.ceil(spec_width_rs / float(divide_factor))
        div_amt = np.maximum(1, div_amt)
        target_size = int(div_amt * divide_factor * (1.0 / resize_factor))
        diff = target_size * step + noverlap - audio_raw.shape[0]
        audio_raw = np.hstack(
            (audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
        if spec_width > fixed_width:
            return audio[:target_samples]

        return audio

    min_width = int(divide_factor / resize_factor)

    if spec_width < min_width:
        target_samples = x_coord_to_sample(
            min_width,
            samplerate=samplerate,
            window_duration=window_duration,
            window_overlap=window_overlap,
            resize_factor=resize_factor,
        )
        diff = target_samples - audio.shape[0]
        return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))

    return audio_raw
    if (spec_width % divide_factor) == 0:
        return audio

    target_width = int(np.ceil(spec_width / divide_factor)) * divide_factor
    target_samples = x_coord_to_sample(
        target_width,
        samplerate=samplerate,
        window_duration=window_duration,
        window_overlap=window_overlap,
        resize_factor=resize_factor,
    )
    diff = target_samples - audio.shape[0]
    return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
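# --- Illustrative usage (not part of the changeset) ---
# Sketch of the intended behaviour of the rewritten pad_audio, assuming the
# package defaults (divide_factor = 32, resize_factor = 0.5):
#
#     padded = pad_audio(audio)                      # pad so width % 32 == 0
#     assert compute_spectrogram_width(padded.shape[0]) % 32 == 0
#
#     fixed = pad_audio(audio, fixed_width=512)      # pad or trim to 512 columns
#     assert compute_spectrogram_width(fixed.shape[0]) == 512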


def gen_mag_spectrogram(x, fs, ms, overlap_perc):
@@ -8,6 +8,11 @@ import pandas as pd
import torch
import torch.nn.functional as F

try:
    from numpy.exceptions import AxisError
except ImportError:
    from numpy import AxisError  # type: ignore

import batdetect2.detector.compute_features as feats
import batdetect2.detector.post_process as pp
import batdetect2.utils.audio_utils as au
@@ -80,6 +85,7 @@ def load_model(
    model_path: str = DEFAULT_MODEL_PATH,
    load_weights: bool = True,
    device: Union[torch.device, str, None] = None,
    weights_only: bool = True,
) -> Tuple[DetectionModel, ModelParameters]:
    """Load model from file.

@@ -100,7 +106,11 @@ def load_model(
    if not os.path.isfile(model_path):
        raise FileNotFoundError("Model file not found.")

    net_params = torch.load(model_path, map_location=device)
    net_params = torch.load(
        model_path,
        map_location=device,
        weights_only=weights_only,
    )

    params = net_params["params"]

@@ -242,7 +252,7 @@ def format_single_result(
        )
        class_name = class_names[np.argmax(class_overall)]
        annotations = get_annotations_from_preds(predictions, class_names)
    except (np.AxisError, ValueError):
    except (AxisError, ValueError):
        # No detections
        class_overall = np.zeros(len(class_names))
        class_name = "None"
@@ -399,7 +409,7 @@ def save_results_to_file(results, op_path: str) -> None:

def compute_spectrogram(
    audio: np.ndarray,
    sampling_rate: float,
    sampling_rate: int,
    params: SpectrogramParameters,
    device: torch.device,
) -> Tuple[float, torch.Tensor]:
@@ -617,7 +627,7 @@ def process_spectrogram(

def _process_audio_array(
    audio: np.ndarray,
    sampling_rate: float,
    sampling_rate: int,
    model: DetectionModel,
    config: ProcessingConfiguration,
    device: torch.device,
@@ -738,9 +748,7 @@ def process_file(

    # Get original sampling rate
    file_samp_rate = librosa.get_samplerate(audio_file)
    orig_samp_rate = file_samp_rate * float(
        config.get("time_expansion", 1.0) or 1.0
    )
    orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)

    # load audio file
    sampling_rate, audio_full = au.load_audio(
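# --- Illustrative usage (not part of the changeset) ---
# Sketch of loading the bundled model with the new `weights_only` argument,
# which is passed straight through to torch.load; the device is a placeholder:
#
#     from batdetect2 import api
#
#     model, params = api.load_model(device="cpu", weights_only=True)
#
# weights_only=True keeps the safer torch.load behaviour of recent PyTorch
# releases; pass False only for trusted, older checkpoints.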
@@ -417,7 +417,9 @@ def plot_confusion_matrix(
    cm_norm = cm.sum(1)

    valid_inds = np.where(cm_norm > 0)[0]
    cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
    cm[valid_inds, :] = (
        cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
    )
    cm[np.where(cm_norm == -0)[0], :] = np.nan

    if verbose:
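# --- Worked example (not part of the changeset) ---
# The reformatted block above row-normalises the confusion matrix: rows with a
# non-zero total are divided by their sum, empty rows become NaN. For example:
#
#     cm = np.array([[8.0, 2.0],
#                    [0.0, 0.0]])
#     # row sums -> [10., 0.]; after normalisation:
#     # [[0.8, 0.2],
#     #  [nan, nan]]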
@@ -155,9 +155,9 @@ class InteractivePlotter:

        # draw bounding box around call
        self.ax[1].patches[0].remove()
        spec_width_orig = self.spec_slices[self.current_id].shape[1] / (
            1.0 + 2.0 * self.spec_pad
        )
        spec_width_orig = self.spec_slices[self.current_id].shape[
            1
        ] / (1.0 + 2.0 * self.spec_pad)
        xx = w_diff + self.spec_pad * spec_width_orig
        ww = spec_width_orig
        yy = self.call_info[self.current_id]["low_freq"] / 1000
@@ -183,7 +183,9 @@ class InteractivePlotter:
                round(self.call_info[self.current_id]["start_time"], 3)
            )
            + ", prob="
            + str(round(self.call_info[self.current_id]["det_prob"], 3))
            + str(
                round(self.call_info[self.current_id]["det_prob"], 3)
            )
        )
        self.ax[0].set_xlabel(info_str)

@@ -8,6 +8,7 @@ Functions
`write`: Write a numpy array as a WAV file.

"""

from __future__ import absolute_import, division, print_function

import os
@@ -156,7 +157,6 @@ def read(filename, mmap=False):
    fid = open(filename, "rb")

    try:

        # some files seem to have the size recorded in the header greater than
        # the actual file size.
        fid.seek(0, os.SEEK_END)
743 notebooks/Augmentations.ipynb Normal file
File diff suppressed because one or more lines are too long

1166 notebooks/Migrations.ipynb Normal file
File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long
@ -6,11 +6,11 @@
|
||||
"id": "cfb0b360-a204-4c27-a18f-3902e8758879",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-07-16T00:27:20.598611Z",
|
||||
"iopub.status.busy": "2024-07-16T00:27:20.596274Z",
|
||||
"iopub.status.idle": "2024-07-16T00:27:20.670888Z",
|
||||
"shell.execute_reply": "2024-07-16T00:27:20.668193Z",
|
||||
"shell.execute_reply.started": "2024-07-16T00:27:20.598423Z"
|
||||
"iopub.execute_input": "2024-11-19T17:33:02.699871Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:02.699590Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:02.710312Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:02.709798Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:02.699839Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@ -25,11 +25,11 @@
|
||||
"id": "326c5432-94e6-4abf-a332-fe902559461b",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-07-16T00:27:20.676278Z",
|
||||
"iopub.status.busy": "2024-07-16T00:27:20.675545Z",
|
||||
"iopub.status.idle": "2024-07-16T00:27:25.872556Z",
|
||||
"shell.execute_reply": "2024-07-16T00:27:25.871725Z",
|
||||
"shell.execute_reply.started": "2024-07-16T00:27:20.676206Z"
|
||||
"iopub.execute_input": "2024-11-19T17:33:02.711324Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:02.711067Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:09.092380Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:09.091830Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:02.711304Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
@ -37,7 +37,7 @@
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/santiago/Software/bat_detectors/batdetect2/.venv/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
"/home/santiago/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
]
|
||||
}
|
||||
@@ -45,26 +45,35 @@
   "source": [
    "from pathlib import Path\n",
    "from typing import List, Optional\n",
    "import torch\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "from soundevent import data\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from batdetect2.data.labels import ClassMapper\n",
    "from batdetect2.models.detectors import DetectorModel\n",
    "from batdetect2.train.modules import DetectorModel\n",
    "from batdetect2.train.augmentations import (\n",
    "    add_echo,\n",
    "    select_random_subclip,\n",
    "    warp_spectrogram,\n",
    ")\n",
    "from batdetect2.train.dataset import LabeledDataset, get_files\n",
    "from batdetect2.train.preprocess import PreprocessingConfig"
    "from batdetect2.train.preprocess import PreprocessingConfig\n",
    "from soundevent import data\n",
    "import matplotlib.pyplot as plt\n",
    "from soundevent.types import ClassMapper\n",
    "from torch.utils.data import DataLoader"
   ]
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fa202af2-5c0d-4b5d-91a3-097ef5cd4272",
|
||||
"metadata": {},
|
||||
"id": "9402a473-0b25-4123-9fa8-ad1f71a4237a",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-11-18T22:39:12.395329Z",
|
||||
"iopub.status.busy": "2024-11-18T22:39:12.393444Z",
|
||||
"iopub.status.idle": "2024-11-18T22:39:12.405938Z",
|
||||
"shell.execute_reply": "2024-11-18T22:39:12.402980Z",
|
||||
"shell.execute_reply.started": "2024-11-18T22:39:12.395236Z"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Training Datasets"
|
||||
]
|
||||
@ -75,11 +84,11 @@
|
||||
"id": "cfd97d83-8c2b-46c8-9eae-cea59f53bc61",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-07-16T00:27:25.874255Z",
|
||||
"iopub.status.busy": "2024-07-16T00:27:25.873473Z",
|
||||
"iopub.status.idle": "2024-07-16T00:27:25.912952Z",
|
||||
"shell.execute_reply": "2024-07-16T00:27:25.911844Z",
|
||||
"shell.execute_reply.started": "2024-07-16T00:27:25.874206Z"
|
||||
"iopub.execute_input": "2024-11-19T17:33:09.093487Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:09.092990Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:09.121636Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:09.121143Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:09.093459Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@ -93,11 +102,11 @@
|
||||
"id": "d5131ae9-2efd-4758-b6e5-189a6d90789b",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-07-16T00:27:25.914456Z",
|
||||
"iopub.status.busy": "2024-07-16T00:27:25.914027Z",
|
||||
"iopub.status.idle": "2024-07-16T00:27:25.954939Z",
|
||||
"shell.execute_reply": "2024-07-16T00:27:25.953906Z",
|
||||
"shell.execute_reply.started": "2024-07-16T00:27:25.914410Z"
|
||||
"iopub.execute_input": "2024-11-19T17:33:09.122685Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:09.122270Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:09.151386Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:09.150788Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:09.122661Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@ -111,11 +120,11 @@
|
||||
"id": "bc733d3d-7829-4e90-896d-a0dc76b33288",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-07-16T00:27:25.956758Z",
|
||||
"iopub.status.busy": "2024-07-16T00:27:25.956260Z",
|
||||
"iopub.status.idle": "2024-07-16T00:27:25.997664Z",
|
||||
"shell.execute_reply": "2024-07-16T00:27:25.996074Z",
|
||||
"shell.execute_reply.started": "2024-07-16T00:27:25.956705Z"
|
||||
"iopub.execute_input": "2024-11-19T17:33:09.152327Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:09.152060Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:09.184041Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:09.183372Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:09.152305Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@ -129,11 +138,11 @@
|
||||
"id": "dfbb94ab-7b12-4689-9c15-4dc34cd17cb2",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-07-16T00:27:26.003195Z",
|
||||
"iopub.status.busy": "2024-07-16T00:27:26.002783Z",
|
||||
"iopub.status.idle": "2024-07-16T00:27:26.054400Z",
|
||||
"shell.execute_reply": "2024-07-16T00:27:26.053294Z",
|
||||
"shell.execute_reply.started": "2024-07-16T00:27:26.003158Z"
|
||||
"iopub.execute_input": "2024-11-19T17:33:09.186393Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:09.186117Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:09.220175Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:09.219322Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:09.186375Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@ -152,11 +161,11 @@
|
||||
"id": "e2eedaa9-6be3-481a-8786-7618515d98f8",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-07-16T00:27:26.056060Z",
|
||||
"iopub.status.busy": "2024-07-16T00:27:26.055706Z",
|
||||
"iopub.status.idle": "2024-07-16T00:27:26.103227Z",
|
||||
"shell.execute_reply": "2024-07-16T00:27:26.102190Z",
|
||||
"shell.execute_reply.started": "2024-07-16T00:27:26.056025Z"
|
||||
"iopub.execute_input": "2024-11-19T17:33:09.221653Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:09.221242Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:09.260977Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:09.260375Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:09.221616Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@ -168,7 +177,6 @@
|
||||
" \"Myotis mystacinus\",\n",
|
||||
" \"Pipistrellus pipistrellus\",\n",
|
||||
" \"Rhinolophus ferrumequinum\",\n",
|
||||
" \"social\",\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" def encode(self, x: data.SoundEventAnnotation) -> Optional[str]:\n",
|
||||
@ -197,11 +205,11 @@
|
||||
"id": "1ff6072c-511e-42fe-a74f-282f269b80f0",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-07-16T00:27:26.104877Z",
|
||||
"iopub.status.busy": "2024-07-16T00:27:26.104538Z",
|
||||
"iopub.status.idle": "2024-07-16T00:27:26.159676Z",
|
||||
"shell.execute_reply": "2024-07-16T00:27:26.157914Z",
|
||||
"shell.execute_reply.started": "2024-07-16T00:27:26.104843Z"
|
||||
"iopub.execute_input": "2024-11-19T17:33:09.262337Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:09.261775Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:09.309793Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:09.309216Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:09.262307Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@ -215,11 +223,11 @@
|
||||
"id": "3a763ee6-15bc-4105-a409-f06e0ad21a06",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-07-16T00:27:26.162346Z",
|
||||
"iopub.status.busy": "2024-07-16T00:27:26.161885Z",
|
||||
"iopub.status.idle": "2024-07-16T00:27:26.374668Z",
|
||||
"shell.execute_reply": "2024-07-16T00:27:26.373691Z",
|
||||
"shell.execute_reply.started": "2024-07-16T00:27:26.162305Z"
|
||||
"iopub.execute_input": "2024-11-19T17:33:09.310695Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:09.310438Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:09.366636Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:09.366059Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:09.310669Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
@ -229,7 +237,6 @@
|
||||
"text": [
|
||||
"GPU available: False, used: False\n",
|
||||
"TPU available: False, using: 0 TPU cores\n",
|
||||
"IPU available: False, using: 0 IPUs\n",
|
||||
"HPU available: False, using: 0 HPUs\n"
|
||||
]
|
||||
}
|
||||
@ -248,11 +255,11 @@
|
||||
"id": "0b86d49d-3314-4257-94f5-f964855be385",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-07-16T00:27:26.375918Z",
|
||||
"iopub.status.busy": "2024-07-16T00:27:26.375632Z",
|
||||
"iopub.status.idle": "2024-07-16T00:27:28.829650Z",
|
||||
"shell.execute_reply": "2024-07-16T00:27:28.828219Z",
|
||||
"shell.execute_reply.started": "2024-07-16T00:27:26.375889Z"
|
||||
"iopub.execute_input": "2024-11-19T17:33:09.367499Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:09.367242Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:10.811300Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:10.809823Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:09.367473Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
@ -261,37 +268,67 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
" | Name | Type | Params\n",
|
||||
"------------------------------------------------\n",
|
||||
"0 | feature_extractor | Net2DFast | 119 K \n",
|
||||
"1 | classifier | Conv2d | 54 \n",
|
||||
"2 | bbox | Conv2d | 18 \n",
|
||||
"------------------------------------------------\n",
|
||||
" | Name | Type | Params | Mode \n",
|
||||
"--------------------------------------------------------\n",
|
||||
"0 | feature_extractor | Net2DFast | 119 K | train\n",
|
||||
"1 | classifier | Conv2d | 54 | train\n",
|
||||
"2 | bbox | Conv2d | 18 | train\n",
|
||||
"--------------------------------------------------------\n",
|
||||
"119 K Trainable params\n",
|
||||
"448 Non-trainable params\n",
|
||||
"119 K Total params\n",
|
||||
"0.480 Total estimated model params size (MB)\n"
|
||||
"0.480 Total estimated model params size (MB)\n",
|
||||
"32 Modules in train mode\n",
|
||||
"0 Modules in eval mode\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.59it/s, v_num=13]"
|
||||
"Epoch 0: 0%| | 0/1 [00:00<?, ?it/s]class heatmap shape torch.Size([3, 4, 128, 512])\n",
|
||||
"class props shape torch.Size([3, 5, 128, 512])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"`Trainer.fit` stopped: `max_epochs=2` reached.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.54it/s, v_num=13]\n"
|
||||
"ename": "RuntimeError",
|
||||
"evalue": "The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdetector\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:538\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 536\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m=\u001b[39m TrainerStatus\u001b[38;5;241m.\u001b[39mRUNNING\n\u001b[1;32m 537\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 538\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 539\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 540\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:47\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 47\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 50\u001b[0m _call_teardown_hook(trainer)\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:574\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 567\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 568\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_select_ckpt_path(\n\u001b[1;32m 569\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 570\u001b[0m ckpt_path,\n\u001b[1;32m 571\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 572\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 573\u001b[0m )\n\u001b[0;32m--> 574\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 576\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 577\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:981\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 976\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_signal_connector\u001b[38;5;241m.\u001b[39mregister_signal_handlers()\n\u001b[1;32m 978\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 979\u001b[0m \u001b[38;5;66;03m# RUN THE TRAINER\u001b[39;00m\n\u001b[1;32m 980\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[0;32m--> 981\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 983\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 984\u001b[0m \u001b[38;5;66;03m# POST-Training CLEAN UP\u001b[39;00m\n\u001b[1;32m 985\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 986\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:1025\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1023\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_sanity_check()\n\u001b[1;32m 1024\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mset_detect_anomaly(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_detect_anomaly):\n\u001b[0;32m-> 1025\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1026\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1027\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected state \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:205\u001b[0m, in \u001b[0;36m_FitLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start()\n\u001b[0;32m--> 205\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:363\u001b[0m, in \u001b[0;36m_FitLoop.advance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_epoch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_fetcher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 363\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py:140\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.run\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone:\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 140\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_fetcher\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end(data_fetcher)\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py:250\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.advance\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_batch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 248\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mlightning_module\u001b[38;5;241m.\u001b[39mautomatic_optimization:\n\u001b[1;32m 249\u001b[0m \u001b[38;5;66;03m# in automatic optimization, there can only be one optimizer\u001b[39;00m\n\u001b[0;32m--> 250\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautomatic_optimization\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 252\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmanual_optimization\u001b[38;5;241m.\u001b[39mrun(kwargs)\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:190\u001b[0m, in \u001b[0;36m_AutomaticOptimization.run\u001b[0;34m(self, optimizer, batch_idx, kwargs)\u001b[0m\n\u001b[1;32m 183\u001b[0m closure()\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;66;03m# BACKWARD PASS\u001b[39;00m\n\u001b[1;32m 187\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;66;03m# gradient update with accumulated gradients\u001b[39;00m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 190\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 192\u001b[0m result \u001b[38;5;241m=\u001b[39m closure\u001b[38;5;241m.\u001b[39mconsume_result()\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:268\u001b[0m, in \u001b[0;36m_AutomaticOptimization._optimizer_step\u001b[0;34m(self, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_ready()\n\u001b[1;32m 267\u001b[0m \u001b[38;5;66;03m# model hook\u001b[39;00m\n\u001b[0;32m--> 268\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_lightning_module_hook\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moptimizer_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_step_and_backward_closure\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m should_accumulate:\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_completed()\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:167\u001b[0m, in \u001b[0;36m_call_lightning_module_hook\u001b[0;34m(trainer, hook_name, pl_module, *args, **kwargs)\u001b[0m\n\u001b[1;32m 164\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m hook_name\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[LightningModule]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpl_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 167\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 170\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/core/module.py:1306\u001b[0m, in \u001b[0;36mLightningModule.optimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_closure)\u001b[0m\n\u001b[1;32m 1275\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21moptimizer_step\u001b[39m(\n\u001b[1;32m 1276\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 1277\u001b[0m epoch: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1280\u001b[0m optimizer_closure: Optional[Callable[[], Any]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1281\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1282\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls\u001b[39;00m\n\u001b[1;32m 1283\u001b[0m \u001b[38;5;124;03m the optimizer.\u001b[39;00m\n\u001b[1;32m 1284\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1304\u001b[0m \n\u001b[1;32m 1305\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1306\u001b[0m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimizer_closure\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/core/optimizer.py:153\u001b[0m, in \u001b[0;36mLightningOptimizer.step\u001b[0;34m(self, closure, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MisconfigurationException(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWhen `optimizer.step(closure)` is called, the closure should be callable\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 152\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_strategy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 153\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_strategy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_on_after_step()\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m step_output\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py:238\u001b[0m, in \u001b[0;36mStrategy.optimizer_step\u001b[0;34m(self, optimizer, closure, model, **kwargs)\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;66;03m# TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed\u001b[39;00m\n\u001b[1;32m 237\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, pl\u001b[38;5;241m.\u001b[39mLightningModule)\n\u001b[0;32m--> 238\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprecision_plugin\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py:122\u001b[0m, in \u001b[0;36mPrecision.optimizer_step\u001b[0;34m(self, optimizer, model, closure, **kwargs)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Hook to run the optimizer step.\"\"\"\u001b[39;00m\n\u001b[1;32m 121\u001b[0m closure \u001b[38;5;241m=\u001b[39m partial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wrap_closure, model, optimizer, closure)\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:130\u001b[0m, in \u001b[0;36mLRScheduler.__init__.<locals>.patch_track_step_called.<locals>.wrap_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 128\u001b[0m opt \u001b[38;5;241m=\u001b[39m opt_ref()\n\u001b[1;32m 129\u001b[0m opt\u001b[38;5;241m.\u001b[39m_opt_called \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m \u001b[38;5;66;03m# type: ignore[union-attr]\u001b[39;00m\n\u001b[0;32m--> 130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__get__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mopt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;18;43m__class__\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:484\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 480\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 481\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 482\u001b[0m )\n\u001b[0;32m--> 484\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 485\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m 487\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:89\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.<locals>._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 87\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefaults[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdifferentiable\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 88\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n\u001b[0;32m---> 89\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 91\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/adam.py:205\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m closure \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39menable_grad():\n\u001b[0;32m--> 205\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m group \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparam_groups:\n\u001b[1;32m 208\u001b[0m params_with_grad: List[Tensor] \u001b[38;5;241m=\u001b[39m []\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py:108\u001b[0m, in \u001b[0;36mPrecision._wrap_closure\u001b[0;34m(self, model, optimizer, closure)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrap_closure\u001b[39m(\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 97\u001b[0m model: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpl.LightningModule\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 98\u001b[0m optimizer: Steppable,\n\u001b[1;32m 99\u001b[0m closure: Callable[[], Any],\n\u001b[1;32m 100\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m 101\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``\u001b[39;00m\n\u001b[1;32m 102\u001b[0m \u001b[38;5;124;03m hook is called.\u001b[39;00m\n\u001b[1;32m 103\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 106\u001b[0m \n\u001b[1;32m 107\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 108\u001b[0m closure_result \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_after_closure(model, optimizer)\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m closure_result\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:144\u001b[0m, in \u001b[0;36mClosure.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Optional[Tensor]:\n\u001b[0;32m--> 144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result\u001b[38;5;241m.\u001b[39mloss\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:129\u001b[0m, in \u001b[0;36mClosure.closure\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241m.\u001b[39menable_grad()\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mclosure\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ClosureResult:\n\u001b[0;32m--> 129\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_step_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m step_output\u001b[38;5;241m.\u001b[39mclosure_loss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwarning_cache\u001b[38;5;241m.\u001b[39mwarn(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`training_step` returned `None`. If this was on purpose, ignore this warning...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:317\u001b[0m, in \u001b[0;36m_AutomaticOptimization._training_step\u001b[0;34m(self, kwargs)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Performs the actual train step with the tied hooks.\u001b[39;00m\n\u001b[1;32m 307\u001b[0m \n\u001b[1;32m 308\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 313\u001b[0m \n\u001b[1;32m 314\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 315\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\n\u001b[0;32m--> 317\u001b[0m training_step_output \u001b[38;5;241m=\u001b[39m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtraining_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mpost_training_step() \u001b[38;5;66;03m# unused hook - call anyway for backward compatibility\u001b[39;00m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m training_step_output \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mworld_size \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:319\u001b[0m, in \u001b[0;36m_call_strategy_hook\u001b[0;34m(trainer, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 319\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 322\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.training_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module:\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_redirection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining_step\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/modules.py:167\u001b[0m, in \u001b[0;36mDetectorModel.training_step\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtraining_step\u001b[39m(\u001b[38;5;28mself\u001b[39m, batch: TrainExample):\n\u001b[1;32m 166\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward(batch\u001b[38;5;241m.\u001b[39mspec)\n\u001b[0;32m--> 167\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/modules.py:150\u001b[0m, in \u001b[0;36mDetectorModel.compute_loss\u001b[0;34m(self, outputs, batch)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mclass props shape\u001b[39m\u001b[38;5;124m\"\u001b[39m, outputs\u001b[38;5;241m.\u001b[39mclass_probs\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 149\u001b[0m valid_mask \u001b[38;5;241m=\u001b[39m batch\u001b[38;5;241m.\u001b[39mclass_heatmap\u001b[38;5;241m.\u001b[39many(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, keepdim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[0;32m--> 150\u001b[0m classification_loss \u001b[38;5;241m=\u001b[39m \u001b[43mlosses\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal_loss\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 151\u001b[0m \u001b[43m \u001b[49m\u001b[43moutputs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_probs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 152\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_heatmap\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 153\u001b[0m \u001b[43m \u001b[49m\u001b[43mweights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalid_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalid_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclassification\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbeta\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 156\u001b[0m \u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclassification\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43malpha\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 157\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 159\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\n\u001b[1;32m 160\u001b[0m detection_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39mdetection\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 161\u001b[0m \u001b[38;5;241m+\u001b[39m size_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39msize\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 162\u001b[0m \u001b[38;5;241m+\u001b[39m classification_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39mclassification\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 163\u001b[0m )\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/losses.py:38\u001b[0m, in \u001b[0;36mfocal_loss\u001b[0;34m(pred, gt, weights, valid_mask, eps, beta, alpha)\u001b[0m\n\u001b[1;32m 35\u001b[0m pos_inds \u001b[38;5;241m=\u001b[39m gt\u001b[38;5;241m.\u001b[39meq(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[1;32m 36\u001b[0m neg_inds \u001b[38;5;241m=\u001b[39m gt\u001b[38;5;241m.\u001b[39mlt(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[0;32m---> 38\u001b[0m pos_loss \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpred\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43meps\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpow\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpred\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpos_inds\u001b[49m\n\u001b[1;32m 39\u001b[0m neg_loss \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 40\u001b[0m torch\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m pred \u001b[38;5;241m+\u001b[39m eps)\n\u001b[1;32m 41\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpow(pred, alpha)\n\u001b[1;32m 42\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpow(\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m gt, beta)\n\u001b[1;32m 43\u001b[0m \u001b[38;5;241m*\u001b[39m neg_inds\n\u001b[1;32m 44\u001b[0m )\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m weights \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
||||
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1"
]
}
],
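The traceback above ends in a broadcast error: the focal loss multiplies the predicted class probabilities by the target class heatmap element-wise, so the two tensors must have the same number of class channels. A minimal, self-contained sketch (hypothetical shapes chosen only to mirror the error message, not taken from the batdetect2 code) that reproduces the same failure and adds the explicit check that makes it obvious:

import torch

# Hypothetical shapes: 5 predicted class channels vs. 4 channels in the target heatmap.
pred = torch.rand(1, 5, 128, 512)   # model output: (batch, n_classes, freq_bins, time_bins)
gt = torch.zeros(1, 4, 128, 512)    # class heatmap built for a different class count

try:
    # Same element-wise pattern as the positive term of a focal loss.
    pos_loss = torch.log(pred + 1e-5) * torch.pow(1 - pred, 2) * gt.eq(1).float()
except RuntimeError as err:
    print(err)  # "The size of tensor a (5) must match the size of tensor b (4) ..."

# Checking the channel counts up front turns the broadcast failure into a clear message.
if pred.shape[1] != gt.shape[1]:
    print(
        f"class-count mismatch: model predicts {pred.shape[1]} classes, "
        f"heatmap provides {gt.shape[1]}"
    )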
@@ -301,15 +338,14 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"id": "2f6924db-e520-49a1-bbe8-6c4956e46314",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-16T00:27:28.832222Z",
"iopub.status.busy": "2024-07-16T00:27:28.831642Z",
"iopub.status.idle": "2024-07-16T00:27:29.000595Z",
"shell.execute_reply": "2024-07-16T00:27:28.998078Z",
"shell.execute_reply.started": "2024-07-16T00:27:28.832157Z"
"iopub.status.busy": "2024-11-19T17:33:10.811729Z",
"iopub.status.idle": "2024-11-19T17:33:10.811955Z",
"shell.execute_reply": "2024-11-19T17:33:10.811858Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.811849Z"
}
},
"outputs": [],
@@ -319,44 +355,54 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"id": "23943e13-6875-49b8-9f18-2ba6528aa673",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-16T00:27:29.004279Z",
"iopub.status.busy": "2024-07-16T00:27:29.003486Z",
"iopub.status.idle": "2024-07-16T00:27:29.595626Z",
"shell.execute_reply": "2024-07-16T00:27:29.594734Z",
"shell.execute_reply.started": "2024-07-16T00:27:29.004200Z"
"iopub.status.busy": "2024-11-19T17:33:10.812924Z",
"iopub.status.idle": "2024-11-19T17:33:10.813260Z",
"shell.execute_reply": "2024-11-19T17:33:10.813104Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.813087Z"
}
},
"outputs": [],
"source": [
"predictions = detector.compute_clip_predictions(clip_annotation.clip)"
"spec = detector.compute_spectrogram(clip_annotation.clip)\n",
"outputs = detector(torch.tensor(spec.values).unsqueeze(0).unsqueeze(0))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": null,
"id": "dd1fe346-0873-4b14-ae1b-92ef1f4f27a5",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.814343Z",
"iopub.status.idle": "2024-11-19T17:33:10.814806Z",
"shell.execute_reply": "2024-11-19T17:33:10.814628Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.814611Z"
}
},
"outputs": [],
"source": [
"_, ax= plt.subplots(figsize=(15, 5))\n",
"spec.plot(ax=ax, add_colorbar=False)\n",
"ax.pcolormesh(spec.time, spec.frequency, outputs.detection_probs.detach().squeeze())"
]
},
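The cell above overlays the raw detection probability map on the spectrogram. A rough, hypothetical post-processing sketch (plain NumPy, not the batdetect2 post-processing API) of how such a 2-D probability map can be reduced to discrete detections with a threshold and a local-maximum test:

import numpy as np

def pick_peaks(prob_map: np.ndarray, threshold: float = 0.5):
    """Return (row, col) indices of local maxima that exceed `threshold`."""
    peaks = []
    rows, cols = prob_map.shape
    for i in range(1, rows - 1):
        for j in range(1, cols - 1):
            # 3x3 neighbourhood around the candidate bin.
            window = prob_map[i - 1 : i + 2, j - 1 : j + 2]
            if prob_map[i, j] >= threshold and prob_map[i, j] == window.max():
                peaks.append((i, j))
    return peaks

# e.g. pick_peaks(outputs.detection_probs.detach().squeeze().numpy(), threshold=0.3)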
{
"cell_type": "code",
"execution_count": null,
"id": "eadd36ef-a04a-4665-b703-cec84cf1673b",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-16T00:28:47.178783Z",
"iopub.status.busy": "2024-07-16T00:28:47.178143Z",
"iopub.status.idle": "2024-07-16T00:28:47.246613Z",
"shell.execute_reply": "2024-07-16T00:28:47.245496Z",
"shell.execute_reply.started": "2024-07-16T00:28:47.178729Z"
"iopub.status.busy": "2024-11-19T17:33:10.815603Z",
"iopub.status.idle": "2024-11-19T17:33:10.816065Z",
"shell.execute_reply": "2024-11-19T17:33:10.815894Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.815877Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Num predicted soundevents: 50\n"
]
}
],
"outputs": [],
"source": [
"print(f\"Num predicted soundevents: {len(predictions.sound_events)}\")"
]
@@ -364,7 +410,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "d3883c04-d91a-4d1d-b677-196c0179dde1",
"id": "e4e54f3e-6ddc-4fe5-8ce0-b527ff6f18ae",
"metadata": {},
"outputs": [],
"source": []
@@ -386,7 +432,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.12.5"
}
},
"nbformat": 4,
140
pyproject.toml
@@ -1,103 +1,97 @@
[tool]
rye = { dev-dependencies = [
    "ipykernel>=6.29.4",
    "setuptools>=69.5.1",
    "pytest>=8.1.1",
] }
[tool.pdm]
[tool.pdm.dev-dependencies]
dev = [
    "pytest>=7.2.2",
]

[project]
name = "batdetect2"
version = "1.0.8"
version = "1.1.1"
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
authors = [
    { "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
    { "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" }
    { "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
    { "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" },
]
dependencies = [
    "librosa>=0.10.1",
    "matplotlib>=3.7.1",
    "numpy>=1.23.5",
    "pandas>=1.5.3",
    "scikit-learn>=1.2.2",
    "scipy>=1.10.1",
    "torch>=1.13.1",
    "torchaudio",
    "torchvision",
    "soundevent[audio,geometry,plot]>=2.0.1",
    "click>=8.1.7",
    "netcdf4>=1.6.5",
    "tqdm>=4.66.2",
    "pytorch-lightning>=2.2.2",
    "cf-xarray>=0.9.0",
    "onnx>=1.16.0",
    "lightning[extra]>=2.2.2",
    "tensorboard>=2.16.2",
    "click>=8.1.7",
    "librosa>=0.10.1",
    "matplotlib>=3.7.1",
    "numpy>=1.23.5",
    "pandas>=1.5.3",
    "scikit-learn>=1.2.2",
    "scipy>=1.10.1",
    "torch>=1.13.1,<2.5.0",
    "torchaudio>=1.13.1,<2.5.0",
    "torchvision>=0.14.0",
    "soundevent[audio,geometry,plot]>=2.3",
    "click>=8.1.7",
    "netcdf4>=1.6.5",
    "tqdm>=4.66.2",
    "pytorch-lightning>=2.2.2",
    "cf-xarray>=0.9.0",
    "onnx>=1.16.0",
    "lightning[extra]>=2.2.2",
    "tensorboard>=2.16.2",
    "omegaconf>=2.3.0",
    "pyyaml>=6.0.2",
    "hydra-core>=1.3.2",
    "numba>=0.60",
]
requires-python = ">=3.9"
requires-python = ">=3.9,<3.13"
readme = "README.md"
license = { text = "CC-by-nc-4" }
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Science/Research",
|
||||
"Natural Language :: English",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"Topic :: Multimedia :: Sound/Audio :: Analysis",
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Science/Research",
|
||||
"Natural Language :: English",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"Topic :: Multimedia :: Sound/Audio :: Analysis",
|
||||
]
|
||||
keywords = [
|
||||
"bat",
|
||||
"echolocation",
|
||||
"deep learning",
|
||||
"audio",
|
||||
"machine learning",
|
||||
"classification",
|
||||
"detection",
|
||||
"bat",
|
||||
"echolocation",
|
||||
"deep learning",
|
||||
"audio",
|
||||
"machine learning",
|
||||
"classification",
|
||||
"detection",
|
||||
]
|
||||
|
||||
[build-system]
requires = ["pdm-pep517>=1.0.0"]
build-backend = "pdm.pep517.api"
requires = ["hatchling"]
build-backend = "hatchling.build"

[project.scripts]
batdetect2 = "batdetect2.cli:cli"

[tool.black]
line-length = 79

[tool.isort]
profile = "black"
line_length = 79
[tool.uv]
dev-dependencies = [
    "debugpy>=1.8.8",
    "hypothesis>=6.118.7",
    "pytest>=7.2.2",
    "ruff>=0.7.3",
    "ipykernel>=6.29.4",
    "setuptools>=69.5.1",
    "basedpyright>=1.28.4",
]

[tool.ruff]
line-length = 79
target-version = "py39"

[[tool.mypy.overrides]]
module = [
    "librosa",
    "pandas",
]
ignore_missing_imports = true
[tool.ruff.format]
docstring-code-format = true
docstring-code-line-length = 79

[tool.pylsp-mypy]
enabled = false
live_mode = true
strict = true
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]

[tool.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention = "numpy"

[tool.pyright]
include = [
    "bat_detect",
    "tests",
]
include = ["batdetect2", "tests"]
venvPath = "."
venv = ".venv"
pythonVersion = "3.9"
pythonPlatform = "All"
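One note on the [project.scripts] table above: installing the package registers a console_scripts entry point, so the `batdetect2` command is a thin wrapper around batdetect2.cli:cli (presumably a Click command, given the click dependency). A small illustrative sketch, assuming the package is installed and Python 3.10+ for the keyword form of entry_points():

from importlib.metadata import entry_points

# Look up the console script declared in pyproject.toml.
for ep in entry_points(group="console_scripts"):
    if ep.name == "batdetect2":
        print(ep.value)   # -> "batdetect2.cli:cli"
        cli = ep.load()   # the same callable the installed `batdetect2` command invokes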
@@ -16,7 +16,6 @@ import batdetect2.train.train_utils as tu
import batdetect2.utils.audio_utils as au

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "audio_path", type=str, help="Input directory for audio"
@@ -65,7 +64,9 @@ if __name__ == "__main__":
    else:
        # load uk data - special case
        print("\nLoading:", args["uk_split"], "\n")
        dataset_name = "uk_" + args["uk_split"]  # should be uk_diff, or uk_same
        dataset_name = (
            "uk_" + args["uk_split"]
        )  # should be uk_diff, or uk_same
        datasets, _ = ts.get_train_test_data(
            args["ann_file"],
            args["audio_path"],
@@ -33,7 +33,6 @@ def filter_anns(anns, start_time, stop_time):


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("audio_file", type=str, help="Path to audio file")
    parser.add_argument("model_path", type=str, help="Path to BatDetect model")
@@ -143,7 +142,9 @@ if __name__ == "__main__":

    # run model and filter detections so only keep ones in relevant time range
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    results = du.process_file(args_cmd["audio_file"], model, run_config, device)
    results = du.process_file(
        args_cmd["audio_file"], model, run_config, device
    )
    pred_anns = filter_anns(
        results["pred_dict"]["annotation"],
        args_cmd["start_time"],
@@ -25,7 +25,9 @@ import batdetect2.utils.plot_utils as viz

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("audio_file", type=str, help="Path to input audio file")
    parser.add_argument(
        "audio_file", type=str, help="Path to input audio file"
    )
    parser.add_argument(
        "model_path", type=str, help="Path to trained BatDetect model"
    )
@@ -198,7 +198,6 @@ def save_summary_image(
    )
    ii = 0
    for row in ax:

        if type(row) != np.ndarray:
            row = np.array([row])

@@ -215,7 +214,9 @@
            )
            col.grid(color="w", alpha=0.3, linewidth=0.3)
            col.set_xticks([])
            col.title.set_text(str(ii + 1) + " " + species_names[order[ii]])
            col.title.set_text(
                str(ii + 1) + " " + species_names[order[ii]]
            )
            col.tick_params(axis="both", which="major", labelsize=7)
            ii += 1
109
tests/conftest.py
Normal file
@@ -0,0 +1,109 @@
import uuid
from pathlib import Path
from typing import Callable, List, Optional

import numpy as np
import pytest
import soundfile as sf
from soundevent import data


@pytest.fixture
def example_data_dir() -> Path:
    pkg_dir = Path(__file__).parent.parent
    example_data_dir = pkg_dir / "example_data"
    assert example_data_dir.exists()
    return example_data_dir


@pytest.fixture
def example_audio_dir(example_data_dir: Path) -> Path:
    example_audio_dir = example_data_dir / "audio"
    assert example_audio_dir.exists()
    return example_audio_dir


@pytest.fixture
def example_anns_dir(example_data_dir: Path) -> Path:
    example_anns_dir = example_data_dir / "anns"
    assert example_anns_dir.exists()
    return example_anns_dir


@pytest.fixture
def example_audio_files(example_audio_dir: Path) -> List[Path]:
    audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]"))
    assert len(audio_files) == 3
    return audio_files


@pytest.fixture
def data_dir() -> Path:
    dir = Path(__file__).parent / "data"
    assert dir.exists()
    return dir


@pytest.fixture
def contrib_dir(data_dir) -> Path:
    dir = data_dir / "contrib"
    assert dir.exists()
    return dir


@pytest.fixture
def wav_factory(tmp_path: Path):
    def _wav_factory(
        path: Optional[Path] = None,
        duration: float = 0.3,
        channels: int = 1,
        samplerate: int = 441_000,
        bit_depth: int = 16,
    ) -> Path:
        path = path or tmp_path / f"{uuid.uuid4()}.wav"
        frames = int(samplerate * duration)
        shape = (frames, channels)
        subtype = f"PCM_{bit_depth}"

        if bit_depth == 16:
            dtype = np.int16
        elif bit_depth == 32:
            dtype = np.int32
        else:
            raise ValueError(f"Unsupported bit depth: {bit_depth}")

        wav = np.random.uniform(
            low=np.iinfo(dtype).min,
            high=np.iinfo(dtype).max,
            size=shape,
        ).astype(dtype)
        sf.write(str(path), wav, samplerate, subtype=subtype)
        return path

    return _wav_factory


@pytest.fixture
def recording_factory(wav_factory: Callable[..., Path]):
    def _recording_factory(
        tags: Optional[list[data.Tag]] = None,
        path: Optional[Path] = None,
        recording_id: Optional[uuid.UUID] = None,
        duration: float = 1,
        channels: int = 1,
        samplerate: int = 256_000,
        time_expansion: float = 1,
    ) -> data.Recording:
        path = path or wav_factory(
            duration=duration,
            channels=channels,
            samplerate=samplerate,
        )
        return data.Recording.from_file(
            path=path,
            uuid=recording_id or uuid.uuid4(),
            time_expansion=time_expansion,
            tags=tags or [],
        )

    return _recording_factory
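A hypothetical example of how a test module under tests/ could consume the fixtures defined above (the test names and assertions are illustrative, not part of the repository):

from pathlib import Path

import soundfile as sf
from soundevent import data


def test_wav_factory_writes_expected_audio(wav_factory):
    # pytest injects the wav_factory fixture; ask it for a short mono file.
    path: Path = wav_factory(duration=0.5, samplerate=256_000, channels=1)
    info = sf.info(str(path))
    assert info.samplerate == 256_000
    assert info.channels == 1
    assert abs(info.duration - 0.5) < 1e-3


def test_recording_factory_applies_time_expansion(recording_factory):
    recording: data.Recording = recording_factory(time_expansion=10)
    assert recording.time_expansion == 10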
BIN tests/data/contrib/jeff37/0166_20240531_223911.wav (Executable file; binary file not shown)
BIN tests/data/contrib/jeff37/0166_20240602_225340.wav (Executable file; binary file not shown)
BIN tests/data/contrib/jeff37/0166_20240603_033731.wav (Executable file; binary file not shown)
BIN tests/data/contrib/jeff37/0166_20240603_033937.wav (Executable file; binary file not shown)
BIN tests/data/contrib/jeff37/0166_20240604_233500.wav (Executable file; binary file not shown)
BIN tests/data/contrib/padpadpadpad/Audiomoth.WAV (Normal file; binary file not shown)
BIN tests/data/contrib/padpadpadpad/AudiomothNoBatCalls.WAV (Normal file; binary file not shown)
BIN tests/data/contrib/padpadpadpad/Echometer.wav (Normal file; binary file not shown)
BIN tests/data/regression/20170701_213954-MYOMYS-LR_0_0.5.wav.npz (Normal file; binary file not shown)
BIN tests/data/regression/20180530_213516-EPTSER-LR_0_0.5.wav.npz (Normal file; binary file not shown)
BIN tests/data/regression/20180627_215323-RHIFER-LR_0_0.5.wav.npz (Normal file; binary file not shown)
@@ -1,14 +1,13 @@
"""Test bat detect module API."""

from pathlib import Path

import os
from glob import glob
from pathlib import Path

import numpy as np
import soundfile as sf
import torch
from torch import nn
import soundfile as sf

from batdetect2 import api

@@ -267,7 +266,6 @@ def test_process_file_with_spec_slices():
    assert len(results["spec_slices"]) == len(detections)


def test_process_file_with_empty_predictions_does_not_fail(
    tmp_path: Path,
):
Some files were not shown because too many files have changed in this diff