Compare commits

...

71 Commits

Author SHA1 Message Date
mbsantiago
98c6da6d42 Removed non needed info on datasets 2025-04-03 16:51:30 +01:00
mbsantiago
fe8d044af2 Using basedpyright 2025-04-03 16:50:58 +01:00
mbsantiago
62fa38557e Minor tweaks 2025-04-03 16:50:49 +01:00
mbsantiago
213b6dfd29 Expose model functions 2025-04-03 16:50:20 +01:00
mbsantiago
bfa6049adc Expose preprocessing functions 2025-04-03 16:50:06 +01:00
mbsantiago
e383a33cbf Improved train module 2025-04-03 16:49:58 +01:00
mbsantiago
7689580a24 Create train config module 2025-04-03 16:49:47 +01:00
mbsantiago
1338ae7431 Better read config nested field 2025-04-03 16:49:21 +01:00
mbsantiago
c2c4ac53fd Formatting 2025-04-03 16:49:11 +01:00
mbsantiago
ff00da9a9a Removed old data module 2025-04-03 16:48:50 +01:00
mbsantiago
22cf47ed39 Moved lightning module to root module 2025-04-03 16:48:31 +01:00
mbsantiago
d9f7304a0f Moved previous code to legacy folders 2025-04-03 16:48:01 +01:00
mbsantiago
451093f2da More structured data module 2025-04-03 16:47:03 +01:00
mbsantiago
30d3a2c92e Start work on expanding cli 2025-04-03 16:46:43 +01:00
mbsantiago
c17a25fa75 Remove dvc dependency 2025-04-02 10:00:57 +01:00
mbsantiago
29f7862153 remove dvc 2025-04-02 09:52:44 +01:00
mbsantiago
dbc3ff9364 Added more up to date numba version 2025-04-02 09:46:50 +01:00
mbsantiago
150305a273 Format code 2025-02-25 11:25:16 +00:00
mbsantiago
904e8f23ea Update pyright 2025-01-28 19:36:05 +00:00
mbsantiago
48e009fa9d WIP 2025-01-28 19:35:57 +00:00
mbsantiago
f7d6516550 WIP 2025-01-23 14:08:55 +00:00
mbsantiago
e9e1f7ce2f Lets goo 2025-01-05 20:29:32 +00:00
mbsantiago
113223be02 Add hydra to deps 2024-11-20 11:39:46 +00:00
mbsantiago
35c916482c Added datasets to data folder 2024-11-20 11:31:21 +00:00
mbsantiago
09bd8cf423 Added dvc 2024-11-20 11:29:14 +00:00
mbsantiago
f6cdd4e87e Starting to create dataset builders 2024-11-19 22:54:26 +00:00
mbsantiago
9cf159efff Reworking model creation 2024-11-19 19:34:54 +00:00
mbsantiago
36c90a600f Ensure train inputs are almost equal 2024-11-18 18:10:58 +00:00
mbsantiago
1f0fb14d89 Minor restructuring 2024-11-16 21:26:18 +00:00
mbsantiago
ee884da8b0 Make sure labels are working in the notebook 2024-11-16 18:23:43 +00:00
mbsantiago
6a9e33c729 Merge branch 'main' into train 2024-11-15 20:10:36 +00:00
mbsantiago
2100a3e483 Bump version: 1.1.0 → 1.1.1 2024-11-11 13:01:57 +00:00
mbsantiago
1d3cd2e305 Update lock 2024-11-11 13:01:55 +00:00
Santiago Martinez Balvanera
d5753b95bb
Merge pull request #39 from macaodha/fix/handle-empty-files-gracefully
fix: Handle Empty Audio Files Gracefully (GH-20)
2024-11-11 12:59:28 +00:00
mbsantiago
69f59ff559 Added the EOFError to the list of expected errors when processing files 2024-11-11 12:44:21 +00:00
mbsantiago
1a11174bc4 Add a test to validate that empty files are handled gracefully 2024-11-11 12:43:46 +00:00
Santiago Martinez Balvanera
c5c9476e52
Merge pull request #38 from macaodha/test/add_audio_files_provided_by_padpadpadpad_to_test_suite
test: Add failing audio files from GH-29 to contrib test suite
2024-11-11 12:23:03 +00:00
mbsantiago
270b3f212d Created test to verify no errors occurred when running on padpadpadpad recordings 2024-11-11 12:13:00 +00:00
mbsantiago
f61d1d8c72 Add audio files provided by @padpadpadpad to the contrib test files 2024-11-11 12:12:38 +00:00
Santiago Martinez Balvanera
4627ddd739
Merge pull request #37 from macaodha/fix/GH-30-torch-deprecation-warning-weights-only
fix: Address PyTorch Model Loading Deprecation Warning (GH-30)
2024-11-11 12:02:26 +00:00
mbsantiago
3477d7b5b4 Run the same test with example data instead of random audio 2024-11-11 11:57:46 +00:00
mbsantiago
394c66a2ee Added test to validate that changing model loading behaviour did not change model predictions 2024-11-11 11:46:27 +00:00
mbsantiago
d085b3212c Added weights_only argument to model loading function 2024-11-11 11:46:06 +00:00
mbsantiago
c393c5c29b Bump version: 1.0.8 → 1.1.0 2024-11-11 11:18:39 +00:00
mbsantiago
3c22ff28a7 add bump2version config 2024-11-11 11:18:37 +00:00
Santiago Martinez Balvanera
7dc28695b2
Merge pull request #36 from macaodha/fix/GH-31-negative-dimension-are-not-allowed
fix: Resolve detect Command Failure with Specific Audio Files (GH-31)
2024-11-10 22:53:45 +00:00
mbsantiago
505cca2dea Original test now passing. Issue seems to be fixed 2024-11-10 22:41:03 +00:00
mbsantiago
7906842a16 Added test to ensure pad_audio function, and utils, are working as expected 2024-11-10 22:39:30 +00:00
mbsantiago
a4b22d6590 Improve the pad_audio function
This function was the culprit of the error. Broke the function into
other helper functions to make the flow easier to follow
2024-11-10 22:39:10 +00:00
mbsantiago
25e0a53ad1 Add hypothesis to dev dependencies for easier testing 2024-11-10 22:38:13 +00:00
mbsantiago
039c002796 Remove unnecessary imports 2024-11-10 22:37:57 +00:00
mbsantiago
c97a87b2a4 Remove numba debug logging for easier debugging 2024-11-10 22:37:45 +00:00
mbsantiago
d93d8284d0 Added a test that replicates the error 2024-11-10 20:06:58 +00:00
Santiago Martinez Balvanera
697b5dbddb
Merge pull request #35 from macaodha/fix/update-python-version-support
feat: Drop Python 3.8 Support, Add Python 3.12 Support
2024-11-10 19:52:42 +00:00
mbsantiago
d5bf8f5ad8 Drop 3.9 and add 3.12 to python-version matrix in test github workflow 2024-11-10 19:46:12 +00:00
mbsantiago
fcbccbe012 Update uv lock 2024-11-10 19:45:01 +00:00
mbsantiago
4917641e2c Drop support for python 3.8 and add for 3.12 2024-11-10 19:44:57 +00:00
Santiago Martinez Balvanera
39c3918103
Merge pull request #34 from macaodha/feat/migrate-to-numpy-2
Feat/migrate to numpy 2
2024-11-10 19:17:33 +00:00
mbsantiago
1ac3808fee Remove the numpy<2 requirement from the dependencies specification 2024-11-10 19:13:44 +00:00
mbsantiago
9e0ad7fd78 address all linting errors from rule NPY201 2024-11-10 19:13:30 +00:00
mbsantiago
95bb0985e7 Add ruff rule to help migrating to numpy 2.0 2024-11-10 19:13:11 +00:00
Santiago Martinez Balvanera
cb088359ae
Merge pull request #33 from macaodha/feat/migrate-to-uv
Feat/migrate to uv
2024-11-10 19:04:47 +00:00
mbsantiago
c5030123aa Restrict pytorch version for python 3.8 compatibility 2024-11-10 18:59:35 +00:00
mbsantiago
1c1fbd8019 Added dev dependencies and updated github actions to use uv 2024-11-10 18:32:50 +00:00
mbsantiago
c65fe1c9f9 change pyproject metadata to use uv and hatch instead of pdm 2024-11-10 18:20:18 +00:00
Santiago Martinez Balvanera
d05bec880a
Merge pull request #32 from ccarrizosa/fix/np_exception
Fix numpy exception handling
2024-11-10 18:15:31 +00:00
ccarrizosa
8597ef0a1c Limit numpy versions to <2 2024-11-10 15:54:16 +01:00
ccarrizosa
2d8a7b67f8 Revert support for newest numpy versions. 2024-11-10 15:54:01 +01:00
ccarrizosa
68351d2224 Fix numpy exception handling 2024-11-09 22:35:11 +01:00
mbsantiago
3f34164028 update gitignore 2024-10-15 22:39:27 +01:00
mbsantiago
d84b7795f6 Updating training preprocess notebook 2024-10-15 22:38:47 +01:00
118 changed files with 11681 additions and 2989 deletions

8
.bumpversion.cfg Normal file
View File

@ -0,0 +1,8 @@
[bumpversion]
current_version = 1.1.1
commit = True
tag = True
[bumpversion:file:batdetect2/__init__.py]
[bumpversion:file:pyproject.toml]

View File

@ -1,34 +1,29 @@
-# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
-# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
 name: Python package
 on:
   push:
-    branches: [ "main" ]
+    branches: ["main"]
   pull_request:
-    branches: [ "main" ]
+    branches: ["main"]
 jobs:
   build:
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
     steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
+      - uses: actions/checkout@v4
+      - name: Install uv
+        uses: astral-sh/setup-uv@v3
         with:
-          python-version: ${{ matrix.python-version }}
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          python -m pip install pytest
-          pip install .
+          enable-cache: true
+          cache-dependency-glob: "uv.lock"
+      - name: Set up Python ${{ matrix.python-version }}
+        run: uv python install ${{ matrix.python-version }}
+      - name: Install the project
+        run: uv sync --all-extras --dev
       - name: Test with pytest
-        run: |
-          pytest
+        run: uv run pytest

View File

@ -1,11 +1,3 @@
-# This workflow will upload a Python Package using Twine when a release is created
-# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
-# This workflow uses actions that are not certified by GitHub.
-# They are provided by a third-party and are governed by
-# separate terms of service, privacy policy, and support
-# documentation.
 name: Upload Python Package
 on:
@ -17,23 +9,22 @@ permissions:
 jobs:
   deploy:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python
         uses: actions/setup-python@v3
         with:
-          python-version: '3.x'
+          python-version: "3.x"
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
           pip install build
       - name: Build package
         run: python -m build
      - name: Publish package
         uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
         with:
           user: __token__
           password: ${{ secrets.PYPI_API_TOKEN }}

4
.gitignore vendored
View File

@ -103,13 +103,11 @@ experiments/*
 .ipynb_checkpoints
 *.ipynb
-# Bump2version
-.bumpversion.cfg
 # DO Include
 !batdetect2_notebook.ipynb
 !batdetect2/models/checkpoints/*.pth.tar
 !tests/data/*.wav
 !notebooks/*.ipynb
+!tests/data/**/*.wav
 notebooks/lightning_logs
 example_data/preprocessed

View File

@ -1 +1,6 @@
-__version__ = '1.0.8'
+import logging
+numba_logger = logging.getLogger("numba")
+numba_logger.setLevel(logging.WARNING)
+__version__ = "1.1.1"

View File

@ -1,9 +1,13 @@
 from batdetect2.cli.base import cli
 from batdetect2.cli.compat import detect
+from batdetect2.cli.data import data
+from batdetect2.cli.train import train
 __all__ = [
     "cli",
     "detect",
+    "data",
+    "train",
 ]

22
batdetect2/cli/ascii.py Normal file
View File

@ -0,0 +1,22 @@
BATDETECT_ASCII_ART = """ .
=#%: .%%#
:%%%: .%%%%.
%%%%.-===::%%%%*
=%%%%+++++++%%%#.
-: .%%%#====+++#%%# .-
.+***= . =++. : .=*+#%*= :***.
=+****+++==:%+#=+% *##%%%%*=##*#**-=
++***+**+=: ##.. +##%%########**++
.++*****#*+- :*:++ ##%#%%%%%####**++
.++***+**++++- :#%%%%%####*##***+=
.+++***+#+++*########%%%##%#+*****++:
.=++++++*+++##%##%%####%%##*:+****+=
=++++++====*#%%#%###%%###- +***+++.
.+*++++= =+==##########= :****++.
=++*+:. .:=#####= .++**++-
.****: . -+**++=
*###= .****==
.#*#- **#*:
-### -*##.
+*= *#*
"""

View File

@ -1,18 +1,14 @@
 """BatDetect2 command line interface."""
-import os
 import click
+# from batdetect2.cli.ascii import BATDETECT_ASCII_ART
 __all__ = [
     "cli",
 ]
-CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
 INFO_STR = """
 BatDetect2 - Detection and Classification
 Assumes audio files are mono, not stereo.
@ -25,3 +21,4 @@ BatDetect2 - Detection and Classification
 def cli():
     """BatDetect2 - Bat Call Detection and Classification."""
     click.echo(INFO_STR)
+    # click.echo(BATDETECT_ASCII_ART)

View File

@ -1,14 +1,11 @@
-"""BatDetect2 command line interface."""
 import click
 from batdetect2 import api
+from batdetect2.cli.base import cli
 from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
 from batdetect2.types import ProcessingConfiguration
 from batdetect2.utils.detector_utils import save_results_to_file
-from batdetect2.cli.base import cli
 @cli.command()
 @click.argument(
@ -114,10 +111,9 @@ def detect(
 ):
             results_path = audio_file.replace(audio_dir, ann_dir)
             save_results_to_file(results, results_path)
-        except (RuntimeError, ValueError, LookupError) as err:
+        except (RuntimeError, ValueError, LookupError, EOFError) as err:
             error_files.append(audio_file)
-            click.secho(f"Error processing file!: {err}", fg="red")
-            raise err
+            click.secho(f"Error processing file {audio_file}: {err}", fg="red")
     click.echo(f"\nResults saved to: {ann_dir}")

40
batdetect2/cli/data.py Normal file
View File

@ -0,0 +1,40 @@
from pathlib import Path
from typing import Optional
import click
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
__all__ = ["data"]
@cli.group()
def data(): ...
@data.command()
@click.argument(
"dataset_config",
type=click.Path(exists=True),
)
@click.option(
"--field",
type=str,
help="If the dataset info is in a nested field please specify here.",
)
@click.option(
"--base-dir",
type=click.Path(exists=True),
help="The base directory to which all recording and annotations paths are relative to.",
)
def summary(
dataset_config: Path,
field: Optional[str] = None,
base_dir: Optional[Path] = None,
):
base_dir = base_dir or Path.cwd()
dataset = load_dataset_from_config(
dataset_config, field=field, base_dir=base_dir
)
print(f"Number of annotated clips: {len(dataset.clip_annotations)}")

188
batdetect2/cli/train.py Normal file
View File

@ -0,0 +1,188 @@
from pathlib import Path
from typing import Optional
import click
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.preprocess import (
load_preprocessing_config,
)
from batdetect2.train import (
load_label_config,
load_target_config,
preprocess_annotations,
)
__all__ = ["train"]
@cli.group()
def train(): ...
@train.command()
@click.argument(
"dataset_config",
type=click.Path(exists=True),
)
@click.argument(
"output",
type=click.Path(),
)
@click.option(
"--dataset-field",
type=str,
help=(
"Specifies the key to access the dataset information within the "
"dataset configuration file, if the information is nested inside a "
"dictionary. If the dataset information is at the top level of the "
"config file, you don't need to specify this."
),
)
@click.option(
"--base-dir",
type=click.Path(exists=True),
help=(
"The main directory where your audio recordings and annotation "
"files are stored. This helps the program find your data, "
"especially if the paths in your dataset configuration file "
"are relative."
),
)
@click.option(
"--preprocess-config",
type=click.Path(exists=True),
help=(
"Path to the preprocessing configuration file. This file tells "
"the program how to prepare your audio data before training, such "
"as resampling or applying filters."
),
)
@click.option(
"--preprocess-config-field",
type=str,
help=(
"If the preprocessing settings are inside a nested dictionary "
"within the preprocessing configuration file, specify the key "
"here to access them. If the preprocessing settings are at the "
"top level, you don't need to specify this."
),
)
@click.option(
"--label-config",
type=click.Path(exists=True),
help=(
"Path to the label generation configuration file. This file "
"contains settings for how to create labels from your "
"annotations, which the model uses to learn."
),
)
@click.option(
"--label-config-field",
type=str,
help=(
"If the label generation settings are inside a nested dictionary "
"within the label configuration file, specify the key here. If "
"the settings are at the top level, leave this blank."
),
)
@click.option(
"--target-config",
type=click.Path(exists=True),
help=(
"Path to the training target configuration file. This file "
"specifies what sounds the model should learn to predict."
),
)
@click.option(
"--target-config-field",
type=str,
help=(
"If the target settings are inside a nested dictionary "
"within the target configuration file, specify the key here. "
"If the settings are at the top level, you don't need to specify this."
),
)
@click.option(
"--force",
is_flag=True,
help=(
"If a preprocessed file already exists, this option tells the "
"program to overwrite it with the new preprocessed data. Use "
"this if you want to re-do the preprocessing even if the files "
"already exist."
),
)
@click.option(
"--num-workers",
type=int,
help=(
"The maximum number of computer cores to use when processing "
"your audio data. Using more cores can speed up the preprocessing, "
"but don't use more than your computer has available. By default, "
"the program will use all available cores."
),
)
def preprocess(
dataset_config: Path,
output: Path,
base_dir: Optional[Path] = None,
preprocess_config: Optional[Path] = None,
target_config: Optional[Path] = None,
label_config: Optional[Path] = None,
force: bool = False,
num_workers: Optional[int] = None,
target_config_field: Optional[str] = None,
preprocess_config_field: Optional[str] = None,
label_config_field: Optional[str] = None,
dataset_field: Optional[str] = None,
):
output = Path(output)
base_dir = base_dir or Path.cwd()
preprocess = (
load_preprocessing_config(
preprocess_config,
field=preprocess_config_field,
)
if preprocess_config
else None
)
target = (
load_target_config(
target_config,
field=target_config_field,
)
if target_config
else None
)
label = (
load_label_config(
label_config,
field=label_config_field,
)
if label_config
else None
)
dataset = load_dataset_from_config(
dataset_config,
field=dataset_field,
base_dir=base_dir,
)
if not output.exists():
output.mkdir(parents=True)
preprocess_annotations(
dataset.clip_annotations,
output_dir=output,
replace=force,
preprocessing_config=preprocess,
label_config=label,
target_config=target,
max_workers=num_workers,
)
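A rough Python-level equivalent of this `train preprocess` command, assuming only the functions already imported at the top of this file; the paths are placeholders and the optional configs are omitted, mirroring the CLI's behaviour of passing `None` when no config file is given:

```python
from pathlib import Path

from batdetect2.data import load_dataset_from_config
from batdetect2.train import preprocess_annotations

# Placeholder dataset config; a real run would point at an actual file.
dataset = load_dataset_from_config("dataset.yaml", base_dir=Path.cwd())

output = Path("preprocessed")
output.mkdir(parents=True, exist_ok=True)

# Optional preprocessing/target/label configs left out, as the CLI does
# when the corresponding options are not provided.
preprocess_annotations(
    dataset.clip_annotations,
    output_dir=output,
    replace=False,
)
```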

View File

@ -9,15 +9,16 @@ import numpy as np
 from pydantic import BaseModel, Field
 from soundevent import data
 from soundevent.geometry import compute_bounds
-from soundevent.types import ClassMapper
 from batdetect2 import types
+from batdetect2.data.labels import ClassMapper
 PathLike = Union[Path, str, os.PathLike]
 __all__ = [
     "convert_to_annotation_group",
-    "load_annotation_project",
+    "load_file_annotation",
+    "annotation_to_sound_event",
 ]
 SPECIES_TAG_KEY = "species"
@ -195,18 +196,30 @@ def annotation_to_sound_event(
         uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
         sound_event=sound_event,
         tags=[
-            data.Tag(key=label_key, value=annotation.label),
-            data.Tag(key=event_key, value=annotation.event),
-            data.Tag(key=individual_key, value=str(annotation.individual)),
+            data.Tag(
+                term=data.term_from_key(label_key),
+                value=annotation.label,
+            ),
+            data.Tag(
+                term=data.term_from_key(event_key),
+                value=annotation.event,
+            ),
+            data.Tag(
+                term=data.term_from_key(individual_key),
+                value=str(annotation.individual),
+            ),
         ],
     )
 def file_annotation_to_clip(
     file_annotation: FileAnnotation,
-    audio_dir: PathLike = Path.cwd(),
+    audio_dir: Optional[PathLike] = None,
+    label_key: str = "class",
 ) -> data.Clip:
     """Convert file annotation to recording."""
+    audio_dir = audio_dir or Path.cwd()
     full_path = Path(audio_dir) / file_annotation.id
     if not full_path.exists():
@ -215,6 +228,12 @@ def file_annotation_to_clip(
     recording = data.Recording.from_file(
         full_path,
         time_expansion=file_annotation.time_exp,
+        tags=[
+            data.Tag(
+                term=data.term_from_key(label_key),
+                value=file_annotation.label,
+            )
+        ],
     )
     return data.Clip(
@ -241,7 +260,11 @@ def file_annotation_to_clip_annotation(
         uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
         clip=clip,
         notes=notes,
-        tags=[data.Tag(key=label_key, value=file_annotation.label)],
+        tags=[
+            data.Tag(
+                term=data.term_from_key(label_key), value=file_annotation.label
+            )
+        ],
         sound_events=[
             annotation_to_sound_event(
                 annotation,
@ -281,52 +304,3 @@ def list_file_annotations(path: PathLike) -> List[Path]:
     """List all annotations in a directory."""
     path = Path(path)
     return [file for file in path.glob("*.json")]
-def load_annotation_project(
-    path: PathLike,
-    name: Optional[str] = None,
-    audio_dir: PathLike = Path.cwd(),
-) -> data.AnnotationProject:
-    """Convert annotations to annotation project."""
-    paths = list_file_annotations(path)
-    if name is None:
-        name = str(path)
-    annotations = []
-    tasks = []
-    for p in paths:
-        try:
-            file_annotation = load_file_annotation(p)
-        except FileNotFoundError:
-            continue
-        try:
-            clip = file_annotation_to_clip(
-                file_annotation,
-                audio_dir=audio_dir,
-            )
-        except FileNotFoundError:
-            continue
-        annotations.append(
-            file_annotation_to_clip_annotation(
-                file_annotation,
-                clip,
-            )
-        )
-        tasks.append(
-            file_annotation_to_annotation_task(
-                file_annotation,
-                clip,
-            )
-        )
-    return data.AnnotationProject(
-        name=name,
-        clip_annotations=annotations,
-        tasks=tasks,
-    )

151
batdetect2/compat/params.py Normal file
View File

@ -0,0 +1,151 @@
from batdetect2.preprocess import (
AmplitudeScaleConfig,
AudioConfig,
FrequencyConfig,
LogScaleConfig,
PcenScaleConfig,
PreprocessingConfig,
ResampleConfig,
Scales,
SpecSizeConfig,
SpectrogramConfig,
STFTConfig,
)
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
from batdetect2.terms import TagInfo
from batdetect2.train.preprocess import (
HeatmapsConfig,
TargetConfig,
TrainPreprocessingConfig,
)
def get_spectrogram_scale(scale: str) -> Scales:
if scale == "pcen":
return PcenScaleConfig()
if scale == "log":
return LogScaleConfig()
return AmplitudeScaleConfig()
def get_preprocessing_config(params: dict) -> PreprocessingConfig:
return PreprocessingConfig(
audio=AudioConfig(
resample=ResampleConfig(
samplerate=params["target_samp_rate"],
mode="poly",
),
scale=params["scale_raw_audio"],
center=params["scale_raw_audio"],
duration=None,
),
spectrogram=SpectrogramConfig(
stft=STFTConfig(
window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"],
window_fn="hann",
),
frequencies=FrequencyConfig(
min_freq=params["min_freq"],
max_freq=params["max_freq"],
),
scale=get_spectrogram_scale(params["spec_scale"]),
denoise=params["denoise_spec_avg"],
size=SpecSizeConfig(
height=params["spec_height"],
resize_factor=params["resize_factor"],
),
max_scale=params["max_scale_spec"],
),
)
def get_training_preprocessing_config(
params: dict,
) -> TrainPreprocessingConfig:
generic = params["generic_class"][0]
preprocessing = get_preprocessing_config(params)
freq_bin_width, time_bin_width = get_spectrogram_resolution(
preprocessing.spectrogram
)
return TrainPreprocessingConfig(
preprocessing=preprocessing,
target=TargetConfig(
classes=[
TagInfo(key="class", value=class_name, label=class_name)
for class_name in params["class_names"]
],
generic_class=TagInfo(
key="class",
value=generic,
label=generic,
),
include=[
TagInfo(key="event", value=event)
for event in params["events_of_interest"]
],
exclude=[
TagInfo(key="class", value=value)
for value in params["classes_to_ignore"]
],
),
heatmaps=HeatmapsConfig(
position="bottom-left",
time_scale=1 / time_bin_width,
frequency_scale=1 / freq_bin_width,
sigma=params["target_sigma"],
),
)
# 'standardize_classs_names_ip',
# 'convert_to_genus',
# 'genus_mapping',
# 'standardize_classs_names',
# 'genus_names',
# ['data_dir',
# 'ann_dir',
# 'train_split',
# 'model_name',
# 'num_filters',
# 'experiment',
# 'model_file_name',
# 'op_im_dir',
# 'op_im_dir_test',
# 'notes',
# 'spec_divide_factor',
# 'detection_overlap',
# 'ignore_start_end',
# 'detection_threshold',
# 'nms_kernel_size',
# 'nms_top_k_per_sec',
# 'aug_prob',
# 'augment_at_train',
# 'augment_at_train_combine',
# 'echo_max_delay',
# 'stretch_squeeze_delta',
# 'mask_max_time_perc',
# 'mask_max_freq_perc',
# 'spec_amp_scaling',
# 'aug_sampling_rates',
# 'train_loss',
# 'det_loss_weight',
# 'size_loss_weight',
# 'class_loss_weight',
# 'individual_loss_weight',
# 'emb_dim',
# 'lr',
# 'batch_size',
# 'num_workers',
# 'num_epochs',
# 'num_eval_epochs',
# 'device',
# 'save_test_image_during_train',
# 'save_test_image_after_train',
# 'train_sets',
# 'test_sets',
# 'class_inv_freq',
# 'ip_height']
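To illustrate the legacy-parameter bridge above, a hedged sketch of calling `get_preprocessing_config` with a hand-built params dict; the values echo the old defaults seen elsewhere in this diff, and `resize_factor=0.5` is purely illustrative:

```python
from batdetect2.compat.params import get_preprocessing_config

# Keys match those read by get_preprocessing_config; values are illustrative.
legacy_params = {
    "target_samp_rate": 256000,
    "scale_raw_audio": False,
    "fft_win_length": 512 / 256000.0,
    "fft_overlap": 0.75,
    "min_freq": 10000,
    "max_freq": 120000,
    "spec_scale": "pcen",
    "denoise_spec_avg": True,
    "spec_height": 128,
    "resize_factor": 0.5,
    "max_scale_spec": False,
}

config = get_preprocessing_config(legacy_params)
print(config.spectrogram.frequencies)
```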

35
batdetect2/configs.py Normal file
View File

@ -0,0 +1,35 @@
from typing import Any, Optional, Type, TypeVar
import yaml
from pydantic import BaseModel, ConfigDict
from soundevent.data import PathLike
class BaseConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
T = TypeVar("T", bound=BaseModel)
def get_object_field(obj: dict, field: str) -> Any:
if "." not in field:
return obj[field]
field, rest = field.split(".", 1)
subobj = obj[field]
return get_object_field(subobj, rest)
def load_config(
path: PathLike,
schema: Type[T],
field: Optional[str] = None,
) -> T:
with open(path, "r") as file:
config = yaml.safe_load(file)
if field:
config = get_object_field(config, field)
return schema.model_validate(config)
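A small usage sketch of the dotted-field lookup that `load_config` relies on; the YAML snippet and the `experiment.dataset.name` path are invented for illustration:

```python
import yaml

from batdetect2.configs import get_object_field

# Invented nested config, purely to show how "a.b.c" paths are resolved.
raw = yaml.safe_load(
    """
experiment:
  dataset:
    name: demo
"""
)

assert get_object_field(raw, "experiment.dataset.name") == "demo"
```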

View File

@ -0,0 +1,14 @@
from batdetect2.data.annotations import (
AnnotatedDataset,
load_annotated_dataset,
)
from batdetect2.data.data import load_dataset, load_dataset_from_config
from batdetect2.data.types import Dataset
__all__ = [
"AnnotatedDataset",
"Dataset",
"load_annotated_dataset",
"load_dataset",
"load_dataset_from_config",
]

View File

@ -0,0 +1,36 @@
import json
from pathlib import Path
from typing import Literal, Union
from batdetect2.configs import BaseConfig
__all__ = [
"AOEFAnnotationFile",
"AnnotationFormats",
"BatDetect2AnnotationFile",
"BatDetect2AnnotationFiles",
]
class BatDetect2AnnotationFiles(BaseConfig):
format: Literal["batdetect2"] = "batdetect2"
path: Path
class BatDetect2AnnotationFile(BaseConfig):
format: Literal["batdetect2_file"] = "batdetect2_file"
path: Path
class AOEFAnnotationFile(BaseConfig):
format: Literal["aoef"] = "aoef"
path: Path
AnnotationFormats = Union[
BatDetect2AnnotationFiles,
BatDetect2AnnotationFile,
AOEFAnnotationFile,
]

View File

@ -0,0 +1,55 @@
from pathlib import Path
from typing import Optional, Union
from soundevent import data
from batdetect2.data.annotations.aeof import (
AOEFAnnotations,
load_aoef_annotated_dataset,
)
from batdetect2.data.annotations.batdetect2_files import (
BatDetect2FilesAnnotations,
load_batdetect2_files_annotated_dataset,
)
from batdetect2.data.annotations.batdetect2_merged import (
BatDetect2MergedAnnotations,
load_batdetect2_merged_annotated_dataset,
)
from batdetect2.data.annotations.types import AnnotatedDataset
__all__ = [
"load_annotated_dataset",
"AnnotatedDataset",
"AOEFAnnotations",
"BatDetect2FilesAnnotations",
"BatDetect2MergedAnnotations",
"AnnotationFormats",
]
AnnotationFormats = Union[
BatDetect2MergedAnnotations,
BatDetect2FilesAnnotations,
AOEFAnnotations,
]
def load_annotated_dataset(
dataset: AnnotatedDataset,
base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
if isinstance(dataset, AOEFAnnotations):
return load_aoef_annotated_dataset(dataset, base_dir=base_dir)
if isinstance(dataset, BatDetect2MergedAnnotations):
return load_batdetect2_merged_annotated_dataset(
dataset, base_dir=base_dir
)
if isinstance(dataset, BatDetect2FilesAnnotations):
return load_batdetect2_files_annotated_dataset(
dataset,
base_dir=base_dir,
)
raise NotImplementedError(f"Unknown annotation format: {dataset.name}")

View File

@ -0,0 +1,37 @@
from pathlib import Path
from typing import Literal, Optional
from soundevent import data, io
from batdetect2.data.annotations.types import AnnotatedDataset
__all__ = [
"AOEFAnnotations",
"load_aoef_annotated_dataset",
]
class AOEFAnnotations(AnnotatedDataset):
format: Literal["aoef"] = "aoef"
annotations_path: Path
def load_aoef_annotated_dataset(
dataset: AOEFAnnotations,
base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
audio_dir = dataset.audio_dir
path = dataset.annotations_path
if base_dir:
audio_dir = base_dir / audio_dir
path = base_dir / path
loaded = io.load(path, audio_dir=audio_dir)
if not isinstance(loaded, (data.AnnotationSet, data.AnnotationProject)):
raise ValueError(
f"The AOEF file at {path} does not contain a set of annotations"
)
return loaded

View File

@ -0,0 +1,80 @@
import os
from pathlib import Path
from typing import Literal, Optional, Union
from soundevent import data
from batdetect2.data.annotations.legacy import (
file_annotation_to_annotation_task,
file_annotation_to_clip,
file_annotation_to_clip_annotation,
list_file_annotations,
load_file_annotation,
)
from batdetect2.data.annotations.types import AnnotatedDataset
PathLike = Union[Path, str, os.PathLike]
__all__ = [
"load_batdetect2_files_annotated_dataset",
"BatDetect2FilesAnnotations",
]
class BatDetect2FilesAnnotations(AnnotatedDataset):
format: Literal["batdetect2"] = "batdetect2"
annotations_dir: Path
def load_batdetect2_files_annotated_dataset(
dataset: BatDetect2FilesAnnotations,
base_dir: Optional[PathLike] = None,
) -> data.AnnotationProject:
"""Convert annotations to annotation project."""
audio_dir = dataset.audio_dir
path = dataset.annotations_dir
if base_dir:
audio_dir = base_dir / audio_dir
path = base_dir / path
paths = list_file_annotations(path)
annotations = []
tasks = []
for p in paths:
try:
file_annotation = load_file_annotation(p)
except FileNotFoundError:
continue
try:
clip = file_annotation_to_clip(
file_annotation,
audio_dir=audio_dir,
)
except FileNotFoundError:
continue
annotations.append(
file_annotation_to_clip_annotation(
file_annotation,
clip,
)
)
tasks.append(
file_annotation_to_annotation_task(
file_annotation,
clip,
)
)
return data.AnnotationProject(
name=dataset.name,
description=dataset.description,
clip_annotations=annotations,
tasks=tasks,
)

View File

@ -0,0 +1,64 @@
import json
import os
from pathlib import Path
from typing import Literal, Optional, Union
from soundevent import data
from batdetect2.data.annotations.legacy import (
FileAnnotation,
file_annotation_to_annotation_task,
file_annotation_to_clip,
file_annotation_to_clip_annotation,
)
from batdetect2.data.annotations.types import AnnotatedDataset
PathLike = Union[Path, str, os.PathLike]
__all__ = [
"BatDetect2MergedAnnotations",
"load_batdetect2_merged_annotated_dataset",
]
class BatDetect2MergedAnnotations(AnnotatedDataset):
format: Literal["batdetect2_file"] = "batdetect2_file"
annotations_path: Path
def load_batdetect2_merged_annotated_dataset(
dataset: BatDetect2MergedAnnotations,
base_dir: Optional[PathLike] = None,
) -> data.AnnotationProject:
audio_dir = dataset.audio_dir
path = dataset.annotations_path
if base_dir:
audio_dir = base_dir / audio_dir
path = base_dir / path
content = json.loads(Path(path).read_text())
annotations = []
tasks = []
for ann in content:
try:
ann = FileAnnotation.model_validate(ann)
except ValueError:
continue
try:
clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
except FileNotFoundError:
continue
annotations.append(file_annotation_to_clip_annotation(ann, clip))
tasks.append(file_annotation_to_annotation_task(ann, clip))
return data.AnnotationProject(
name=dataset.name,
description=dataset.description,
clip_annotations=annotations,
tasks=tasks,
)

View File

@ -0,0 +1,304 @@
"""Compatibility functions between old and new data structures."""
import os
import uuid
from pathlib import Path
from typing import Callable, List, Optional, Union
import numpy as np
from pydantic import BaseModel, Field
from soundevent import data
from soundevent.geometry import compute_bounds
from soundevent.types import ClassMapper
from batdetect2 import types
PathLike = Union[Path, str, os.PathLike]
__all__ = [
"convert_to_annotation_group",
]
SPECIES_TAG_KEY = "species"
ECHOLOCATION_EVENT = "Echolocation"
UNKNOWN_CLASS = "__UNKNOWN__"
NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]
ClassFn = Callable[[data.Recording], int]
IndividualFn = Callable[[data.SoundEventAnnotation], int]
def get_recording_class_name(recording: data.Recording) -> str:
"""Get the class name for a recording."""
tag = data.find_tag(recording.tags, SPECIES_TAG_KEY)
if tag is None:
return UNKNOWN_CLASS
return tag.value
def get_annotation_notes(annotation: data.ClipAnnotation) -> str:
"""Get the notes for a ClipAnnotation."""
all_notes = [
*annotation.notes,
*annotation.clip.recording.notes,
]
messages = [note.message for note in all_notes if note.message is not None]
return "\n".join(messages)
def convert_to_annotation_group(
annotation: data.ClipAnnotation,
class_mapper: ClassMapper,
event_fn: EventFn = lambda _: ECHOLOCATION_EVENT,
class_fn: ClassFn = lambda _: 0,
individual_fn: IndividualFn = lambda _: 0,
) -> types.AudioLoaderAnnotationGroup:
"""Convert a ClipAnnotation to an AudioLoaderAnnotationGroup."""
recording = annotation.clip.recording
start_times = []
end_times = []
low_freqs = []
high_freqs = []
class_ids = []
x_inds = []
y_inds = []
individual_ids = []
annotations: List[types.Annotation] = []
class_id_file = class_fn(recording)
for sound_event in annotation.sound_events:
geometry = sound_event.sound_event.geometry
if geometry is None:
continue
start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
class_id = class_mapper.transform(sound_event) or -1
event = event_fn(sound_event) or ""
individual_id = individual_fn(sound_event) or -1
start_times.append(start_time)
end_times.append(end_time)
low_freqs.append(low_freq)
high_freqs.append(high_freq)
class_ids.append(class_id)
individual_ids.append(individual_id)
# NOTE: This will be computed later so we just put a placeholder
# here for now.
x_inds.append(0)
y_inds.append(0)
annotations.append(
{
"start_time": start_time,
"end_time": end_time,
"low_freq": low_freq,
"high_freq": high_freq,
"class_prob": 1.0,
"det_prob": 1.0,
"individual": "0",
"event": event,
"class_id": class_id, # type: ignore
}
)
return {
"id": str(recording.path),
"duration": recording.duration,
"issues": False,
"file_path": str(recording.path),
"time_exp": recording.time_expansion,
"class_name": get_recording_class_name(recording),
"notes": get_annotation_notes(annotation),
"annotated": True,
"start_times": np.array(start_times),
"end_times": np.array(end_times),
"low_freqs": np.array(low_freqs),
"high_freqs": np.array(high_freqs),
"class_ids": np.array(class_ids),
"x_inds": np.array(x_inds),
"y_inds": np.array(y_inds),
"individual_ids": np.array(individual_ids),
"annotation": annotations,
"class_id_file": class_id_file,
}
class Annotation(BaseModel):
"""Annotation class to hold batdetect annotations."""
label: str = Field(alias="class")
event: str
individual: int = 0
start_time: float
end_time: float
low_freq: float
high_freq: float
class FileAnnotation(BaseModel):
"""FileAnnotation class to hold batdetect annotations for a file."""
id: str
duration: float
time_exp: float = 1
label: str = Field(alias="class_name")
annotation: List[Annotation]
annotated: bool = False
issues: bool = False
notes: str = ""
def load_file_annotation(path: PathLike) -> FileAnnotation:
"""Load annotation from batdetect format."""
path = Path(path)
return FileAnnotation.model_validate_json(path.read_text())
def annotation_to_sound_event(
annotation: Annotation,
recording: data.Recording,
label_key: str = "class",
event_key: str = "event",
individual_key: str = "individual",
) -> data.SoundEventAnnotation:
"""Convert annotation to sound event annotation."""
sound_event = data.SoundEvent(
uuid=uuid.uuid5(
NAMESPACE,
f"{recording.hash}_{annotation.start_time}_{annotation.end_time}",
),
recording=recording,
geometry=data.BoundingBox(
coordinates=[
annotation.start_time,
annotation.low_freq,
annotation.end_time,
annotation.high_freq,
],
),
)
return data.SoundEventAnnotation(
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
sound_event=sound_event,
tags=[
data.Tag(
term=data.term_from_key(label_key),
value=annotation.label,
),
data.Tag(
term=data.term_from_key(event_key),
value=annotation.event,
),
data.Tag(
term=data.term_from_key(individual_key),
value=str(annotation.individual),
),
],
)
def file_annotation_to_clip(
file_annotation: FileAnnotation,
audio_dir: Optional[PathLike] = None,
label_key: str = "class",
) -> data.Clip:
"""Convert file annotation to recording."""
audio_dir = audio_dir or Path.cwd()
full_path = Path(audio_dir) / file_annotation.id
if not full_path.exists():
raise FileNotFoundError(f"File {full_path} not found.")
recording = data.Recording.from_file(
full_path,
time_expansion=file_annotation.time_exp,
tags=[
data.Tag(
term=data.term_from_key(label_key),
value=file_annotation.label,
)
],
)
return data.Clip(
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip"),
recording=recording,
start_time=0,
end_time=recording.duration,
)
def file_annotation_to_clip_annotation(
file_annotation: FileAnnotation,
clip: data.Clip,
label_key: str = "class",
event_key: str = "event",
individual_key: str = "individual",
) -> data.ClipAnnotation:
"""Convert file annotation to clip annotation."""
notes = []
if file_annotation.notes:
notes.append(data.Note(message=file_annotation.notes))
return data.ClipAnnotation(
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
clip=clip,
notes=notes,
tags=[
data.Tag(
term=data.term_from_key(label_key), value=file_annotation.label
)
],
sound_events=[
annotation_to_sound_event(
annotation,
clip.recording,
label_key=label_key,
event_key=event_key,
individual_key=individual_key,
)
for annotation in file_annotation.annotation
],
)
def file_annotation_to_annotation_task(
file_annotation: FileAnnotation,
clip: data.Clip,
) -> data.AnnotationTask:
status_badges = []
if file_annotation.issues:
status_badges.append(
data.StatusBadge(state=data.AnnotationState.rejected)
)
elif file_annotation.annotated:
status_badges.append(
data.StatusBadge(state=data.AnnotationState.completed)
)
return data.AnnotationTask(
uuid=uuid.uuid5(uuid.NAMESPACE_URL, f"{file_annotation.id}_task"),
clip=clip,
status_badges=status_badges,
)
def list_file_annotations(path: PathLike) -> List[Path]:
"""List all annotations in a directory."""
path = Path(path)
return [file for file in path.glob("*.json")]
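For context, a minimal record in the legacy batdetect2 JSON format, parsed with the `FileAnnotation` model defined above; the file name, species, and numbers are invented:

```python
from batdetect2.data.annotations.legacy import FileAnnotation

# Invented example of the legacy per-file annotation structure.
example = """
{
  "id": "20230101_000000.wav",
  "duration": 3.0,
  "time_exp": 1,
  "class_name": "Myotis daubentonii",
  "annotated": true,
  "issues": false,
  "notes": "",
  "annotation": [
    {
      "class": "Myotis daubentonii",
      "event": "Echolocation",
      "individual": 0,
      "start_time": 0.25,
      "end_time": 0.27,
      "low_freq": 35000,
      "high_freq": 80000
    }
  ]
}
"""

file_annotation = FileAnnotation.model_validate_json(example)
print(file_annotation.label, len(file_annotation.annotation))
```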

View File

@ -0,0 +1,41 @@
from pathlib import Path
from typing import Literal, Union
from batdetect2.configs import BaseConfig
__all__ = [
"AnnotatedDataset",
"BatDetect2MergedAnnotations",
]
class AnnotatedDataset(BaseConfig):
"""Represents a single, cohesive source of audio recordings and annotations.
A source typically groups recordings originating from a specific context,
such as a single project, site, deployment, or recordist. All audio files
belonging to a source should be located within a single directory,
specified by `audio_dir`.
Annotations associated with these recordings are defined by the
`annotations` field, which supports various formats (e.g., AOEF files,
specific CSV
structures).
Crucially, file paths referenced within the annotation data *must* be
relative to the `audio_dir`. This ensures that the dataset definition
remains portable across different systems and base directories.
Attributes:
name: A unique identifier for this data source.
description: Detailed information about the source, including recording
methods, annotation procedures, equipment used, potential biases,
or any important caveats for users.
audio_dir: The file system path to the directory containing the audio
recordings for this source.
"""
name: str
audio_dir: Path
description: str = ""

37
batdetect2/data/data.py Normal file
View File

@ -0,0 +1,37 @@
from pathlib import Path
from typing import Optional
from soundevent import data
from batdetect2.configs import load_config
from batdetect2.data.annotations import load_annotated_dataset
from batdetect2.data.types import Dataset
__all__ = [
"load_dataset",
"load_dataset_from_config",
]
def load_dataset(
dataset: Dataset,
base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
clip_annotations = []
for source in dataset.sources:
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
clip_annotations.extend(annotated_source.clip_annotations)
return data.AnnotationSet(clip_annotations=clip_annotations)
def load_dataset_from_config(
path: data.PathLike,
field: Optional[str] = None,
base_dir: Optional[Path] = None,
):
config = load_config(
path=path,
schema=Dataset,
field=field,
)
return load_dataset(config, base_dir=base_dir)

View File

@ -1,33 +0,0 @@
from typing import Callable, Generic, Iterable, List, TypeVar
from soundevent import data
from torch.utils.data import Dataset
__all__ = [
"ClipDataset",
]
E = TypeVar("E")
class ClipDataset(Dataset, Generic[E]):
clips: List[data.Clip]
transform: Callable[[data.Clip], E]
def __init__(
self,
clips: Iterable[data.Clip],
transform: Callable[[data.Clip], E],
name: str = "ClipDataset",
):
self.clips = list(clips)
self.transform = transform
self.name = name
def __len__(self) -> int:
return len(self.clips)
def __getitem__(self, idx: int) -> E:
return self.transform(self.clips[idx])

View File

@ -1,392 +0,0 @@
"""Module containing functions for preprocessing audio clips."""
from typing import Optional, Union
from pathlib import Path
import librosa
import librosa.core.spectrum
import numpy as np
import xarray as xr
from numpy.typing import DTypeLike
from pydantic import BaseModel, Field
from scipy.signal import resample_poly
from soundevent import audio, data, arrays
from soundevent.arrays import operations as ops
__all__ = [
"PreprocessingConfig",
"preprocess_audio_clip",
]
TARGET_SAMPLERATE_HZ = 256000
SCALE_RAW_AUDIO = False
FFT_WIN_LENGTH_S = 512 / 256000.0
FFT_OVERLAP = 0.75
MAX_FREQ_HZ = 120000
MIN_FREQ_HZ = 10000
DEFAULT_DURATION = 1
SPEC_HEIGHT = 128
SPEC_WIDTH = 256
SPEC_SCALE = "pcen"
SPEC_TIME_PERIOD = DEFAULT_DURATION / SPEC_WIDTH
DENOISE_SPEC_AVG = True
MAX_SCALE_SPEC = False
class PreprocessingConfig(BaseModel):
"""Configuration for preprocessing data."""
target_samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
scale_audio: bool = Field(default=SCALE_RAW_AUDIO)
fft_win_length: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
fft_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
spec_scale: str = Field(default=SPEC_SCALE)
denoise_spec_avg: bool = DENOISE_SPEC_AVG
max_scale_spec: bool = MAX_SCALE_SPEC
duration: Optional[float] = DEFAULT_DURATION
spec_height: int = SPEC_HEIGHT
spec_time_period: float = SPEC_TIME_PERIOD
@classmethod
def from_file(
cls,
path: Union[str, Path],
) -> "PreprocessingConfig":
"""Load configuration from a file.
Parameters
----------
path
Path to the configuration file.
Returns
-------
PreprocessingConfig
The configuration loaded from the file.
Raises
------
FileNotFoundError
If the configuration file does not exist.
pydantic.ValidationError
If the configuration file is invalid.
"""
path = Path(path)
if not path.is_file():
raise FileNotFoundError(f"Config file not found: {path}")
return cls.model_validate_json(path.read_text())
def to_file(self, path: Union[str, Path]) -> None:
"""Save configuration to a file."""
path = Path(path)
if not path.parent.exists():
path.parent.mkdir(parents=True)
path.write_text(self.model_dump_json())
def preprocess_audio_clip(
clip: data.Clip,
config: PreprocessingConfig = PreprocessingConfig(),
) -> xr.DataArray:
"""Preprocesses audio clip to generate spectrogram.
Parameters
----------
clip
The audio clip to preprocess.
config
Configuration for preprocessing.
Returns
-------
xr.DataArray
Preprocessed spectrogram.
"""
wav = load_clip_audio(
clip,
target_sampling_rate=config.target_samplerate,
scale=config.scale_audio,
)
spec = compute_spectrogram(
wav,
fft_win_length=config.fft_win_length,
fft_overlap=config.fft_overlap,
max_freq=config.max_freq,
min_freq=config.min_freq,
spec_scale=config.spec_scale,
denoise_spec_avg=config.denoise_spec_avg,
max_scale_spec=config.max_scale_spec,
)
if config.duration is not None:
spec = adjust_spec_duration(clip, spec, config.duration)
duration = arrays.get_dim_width(spec, dim="time")
return ops.resize(
spec,
time=int(np.ceil(duration / config.spec_time_period)),
frequency=config.spec_height,
dtype=np.float32,
)
def adjust_spec_duration(
clip: data.Clip,
spec: xr.DataArray,
duration: float,
) -> xr.DataArray:
current_duration = clip.end_time - clip.start_time
if current_duration == duration:
return spec
if current_duration > duration:
return arrays.crop_dim(
spec,
dim="time",
start=clip.start_time,
stop=clip.start_time + duration,
)
return arrays.extend_dim(
spec,
dim="time",
start=clip.start_time,
stop=clip.start_time + duration,
)
def load_clip_audio(
clip: data.Clip,
target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
scale: bool = SCALE_RAW_AUDIO,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
wav = audio.load_clip(clip).sel(channel=0).astype(dtype)
wav = resample_audio(wav, target_sampling_rate, dtype=dtype)
if scale:
wav = ops.center(wav)
wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))
return wav.astype(dtype)
def resample_audio(
wav: xr.DataArray,
target_samplerate: int = TARGET_SAMPLERATE_HZ,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
if "time" not in wav.dims:
raise ValueError("Audio must have a time dimension")
time_axis: int = wav.get_axis_num("time") # type: ignore
start, stop = arrays.get_dim_range(wav, dim="time")
step = arrays.get_dim_step(wav, dim="time")
original_samplerate = int(1 / step)
if original_samplerate == target_samplerate:
return wav.astype(dtype)
gcd = np.gcd(original_samplerate, target_samplerate)
resampled = resample_poly(
wav.values,
target_samplerate // gcd,
original_samplerate // gcd,
axis=time_axis,
)
resampled_times = np.linspace(
start,
stop + step,
len(resampled),
endpoint=False,
dtype=dtype,
)
return xr.DataArray(
data=resampled.astype(dtype),
dims=wav.dims,
coords={
**wav.coords,
"time": arrays.create_time_dim_from_array(
resampled_times,
samplerate=target_samplerate,
),
},
attrs=wav.attrs,
)
def compute_spectrogram(
wav: xr.DataArray,
fft_win_length: float = FFT_WIN_LENGTH_S,
fft_overlap: float = FFT_OVERLAP,
max_freq: int = MAX_FREQ_HZ,
min_freq: int = MIN_FREQ_HZ,
spec_scale: str = SPEC_SCALE,
denoise_spec_avg: bool = True,
max_scale_spec: bool = False,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
spec = gen_mag_spectrogram(
wav,
window_len=fft_win_length,
overlap_perc=fft_overlap,
dtype=dtype,
)
spec = arrays.crop_dim(
spec,
dim="frequency",
start=min_freq,
stop=max_freq,
).astype(dtype)
spec = scale_spectrogram(spec, scale=spec_scale)
if denoise_spec_avg:
spec = denoise_spectrogram(spec)
if max_scale_spec:
spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
return spec.astype(dtype)
def gen_mag_spectrogram(
wave: xr.DataArray,
window_len: float,
overlap_perc: float,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
start_time, end_time = arrays.get_dim_range(wave, dim="time")
step = arrays.get_dim_step(wave, dim="time")
sampling_rate = 1 / step
hop_len = window_len * (1 - overlap_perc)
nfft = int(window_len * sampling_rate)
noverlap = int(overlap_perc * nfft)
# compute spec
spec, _ = librosa.core.spectrum._spectrogram(
y=wave.data,
power=1,
n_fft=nfft,
hop_length=nfft - noverlap,
center=False,
)
return xr.DataArray(
data=spec.astype(dtype),
dims=["frequency", "time"],
coords={
"frequency": arrays.create_frequency_dim_from_array(
np.linspace(
0,
sampling_rate / 2,
spec.shape[0],
endpoint=False,
dtype=dtype,
),
step=sampling_rate / nfft,
),
"time": arrays.create_time_dim_from_array(
np.linspace(
start_time,
end_time - (window_len - hop_len),
spec.shape[1],
endpoint=False,
dtype=dtype,
),
step=hop_len,
),
},
attrs={
**wave.attrs,
"original_samplerate": sampling_rate,
"nfft": nfft,
"noverlap": noverlap,
},
)
def denoise_spectrogram(
spec: xr.DataArray,
) -> xr.DataArray:
return xr.DataArray(
data=(spec - spec.mean("time")).clip(0),
dims=spec.dims,
coords=spec.coords,
attrs=spec.attrs,
)
def scale_spectrogram(
spec: xr.DataArray,
scale: str = SPEC_SCALE,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
samplerate = spec.attrs["original_samplerate"]
if scale == "pcen":
smoothing_constant = get_pcen_smoothing_constant(samplerate / 10)
return audio.pcen(
spec * (2**31),
smooth=smoothing_constant,
).astype(dtype)
if scale == "log":
return log_scale(spec, dtype=dtype)
return spec
def log_scale(
spec: xr.DataArray,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
samplerate = spec.attrs["original_samplerate"]
nfft = spec.attrs["nfft"]
log_scaling = (
2.0
* (1.0 / samplerate)
* (1.0 / (np.abs(np.hanning(nfft)) ** 2).sum())
)
return xr.DataArray(
data=np.log1p(log_scaling * spec).astype(dtype),
dims=spec.dims,
coords=spec.coords,
attrs=spec.attrs,
)
def get_pcen_smoothing_constant(
sr: int,
time_constant: float = 0.4,
hop_length: int = 512,
) -> float:
t_frames = time_constant * sr / float(hop_length)
return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)

29
batdetect2/data/types.py Normal file
View File

@ -0,0 +1,29 @@
from typing import Annotated, List
from pydantic import Field
from batdetect2.configs import BaseConfig
from batdetect2.data.annotations import AnnotationFormats
class Dataset(BaseConfig):
"""Represents a collection of one or more DatasetSources.
In the context of batdetect2, a Dataset aggregates multiple `DatasetSource`
instances. It serves as the primary unit for defining data splits,
typically used for model training, validation, or testing phases.
Attributes:
name: A descriptive name for the overall dataset
(e.g., "UK Training Set").
description: A detailed explanation of the dataset's purpose,
composition, how it was assembled, or any specific characteristics.
sources: A list containing the `DatasetSource` objects included in this
dataset.
"""
name: str
description: str
sources: List[
Annotated[AnnotationFormats, Field(..., discriminator="format")]
]
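A hedged example of a dataset configuration matching this schema, validated directly against the pydantic model; the source names, directories, and annotation paths are placeholders:

```python
import yaml

from batdetect2.data import Dataset

# Placeholder dataset definition; "format" selects the annotation loader.
raw = yaml.safe_load(
    """
name: UK Training Set
description: Example dataset definition with two sources.
sources:
  - format: aoef
    name: survey-2023
    audio_dir: audio/survey-2023
    annotations_path: annotations/survey-2023.json
  - format: batdetect2
    name: legacy-files
    audio_dir: audio/legacy
    annotations_dir: annotations/legacy
"""
)

dataset = Dataset.model_validate(raw)
print([source.name for source in dataset.sources])
```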

View File

@ -1,4 +1,5 @@
 """Functions to compute features from predictions."""
 from typing import Dict, List, Optional
 import numpy as np
@ -219,7 +220,6 @@ def compute_call_interval(
     return round(prediction["start_time"] - previous["end_time"], 5)
 # NOTE: The order of the features in this dictionary is important. The
 # features are extracted in this order and the order of the columns in the
 # output csv file is determined by this order. In order to avoid breaking

View File

@ -206,7 +206,10 @@ class Net2DFastNoAttn(nn.Module):
             num_filts // 4, 2, kernel_size=1, padding=0
         )
         self.conv_classes_op = nn.Conv2d(
-            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0,
+            num_filts // 4,
+            self.num_classes + 1,
+            kernel_size=1,
+            padding=0,
         )
         if self.emb_dim > 0:

View File

@ -5,7 +5,10 @@ from typing import List, Optional, Union
 from pydantic import BaseModel, Field, computed_field
-from batdetect2.train.train_utils import get_genus_mapping, get_short_class_names
+from batdetect2.train.legacy.train_utils import (
+    get_genus_mapping,
+    get_short_class_names,
+)
 from batdetect2.types import ProcessingConfiguration, SpectrogramParameters
 TARGET_SAMPLERATE_HZ = 256000

View File

@ -1,4 +1,5 @@
 """Post-processing of the output of the model."""
 from typing import List, Tuple, Union
 import numpy as np

View File

@ -1,16 +1,66 @@
+from typing import List
 import numpy as np
-from sklearn.metrics import (
-    accuracy_score,
-    auc,
-    balanced_accuracy_score,
-    roc_curve,
-)
+import pandas as pd
+from sklearn.metrics import auc, roc_curve
+from soundevent import data
+from soundevent.evaluation import match_geometries
+from batdetect2.train.targets import build_encoder, get_class_names
+def match_predictions_and_annotations(
+    clip_annotation: data.ClipAnnotation,
+    clip_prediction: data.ClipPrediction,
+) -> List[data.Match]:
+    annotated_sound_events = [
+        sound_event_annotation
+        for sound_event_annotation in clip_annotation.sound_events
+        if sound_event_annotation.sound_event.geometry is not None
+    ]
+    predicted_sound_events = [
+        sound_event_prediction
+        for sound_event_prediction in clip_prediction.sound_events
+        if sound_event_prediction.sound_event.geometry is not None
+    ]
+    annotated_geometries: List[data.Geometry] = [
+        sound_event.sound_event.geometry
+        for sound_event in annotated_sound_events
+        if sound_event.sound_event.geometry is not None
+    ]
+    predicted_geometries: List[data.Geometry] = [
+        sound_event.sound_event.geometry
+        for sound_event in predicted_sound_events
+        if sound_event.sound_event.geometry is not None
+    ]
+    matches = []
+    for id1, id2, affinity in match_geometries(
+        annotated_geometries,
+        predicted_geometries,
+    ):
+        target = annotated_sound_events[id1] if id1 is not None else None
+        source = predicted_sound_events[id2] if id2 is not None else None
+        matches.append(
+            data.Match(source=source, target=target, affinity=affinity)
+        )
+    return matches
+def build_evaluation_dataframe(matches: List[data.Match]) -> pd.DataFrame:
+    ret = []
+    for match in matches:
+        pass
 def compute_error_auc(op_str, gt, pred, prob):
     # classification error
-    pred_int = (pred > prob).astype(np.int)
+    pred_int = (pred > prob).astype(np.int32)
     class_acc = (pred_int == gt).mean() * 100.0
     # ROC - area under curve
@ -25,7 +75,6 @@ def compute_error_auc(op_str, gt, pred, prob):
 def calc_average_precision(recall, precision):
     precision[np.isnan(precision)] = 0
     recall[np.isnan(recall)] = 0
@ -91,7 +140,6 @@ def compute_pre_rec(
     pred_class = []
     file_ids = []
     for pid, pp in enumerate(preds):
         # filter predicted calls that are too near the start or end of the file
         file_dur = gts[pid]["duration"]
         valid_inds = (pp["start_times"] >= ignore_start_end) & (
@ -141,7 +189,6 @@ def compute_pre_rec(
     gt_generic_class = []
     num_positives = 0
     for gg in gts:
         # filter ground truth calls that are too near the start or end of the file
         file_dur = gg["duration"]
         valid_inds = (gg["start_times"] >= ignore_start_end) & (
@ -205,7 +252,6 @@ def compute_pre_rec(
             # valid detection that has not already been assigned
             if valid_det and (gt_assigned[gt_id][det_ind] == 0):
                 count_as_true_pos = True
                 if eval_mode == "top_class" and (
                     gt_class[gt_id][det_ind] != pred_class[ind]

View File

@ -12,15 +12,14 @@ import pandas as pd
import torch
from sklearn.ensemble import RandomForestClassifier
-import batdetect2.train.evaluate as evl
-import batdetect2.train.train_utils as tu
+import batdetect2.evaluate.legacy.evaluate_models as evl
+import batdetect2.train.legacy.train_utils as tu
import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as pu
from batdetect2.detector import parameters
def get_blank_annotation(ip_str):
res = {}
res["class_name"] = ""
res["duration"] = -1
@ -77,7 +76,6 @@ def create_genus_mapping(gt_test, preds, class_names):
def load_tadarida_pred(ip_dir, dataset, file_of_interest):
res, ann = get_blank_annotation("Generated by Tadarida")
# create the annotations in the correct format
@ -120,7 +118,6 @@ def load_sonobat_meta(
class_names,
only_accepted_species=True,
):
sp_dict = {}
for ss in class_names:
sp_key = ss.split(" ")[0][:3] + ss.split(" ")[1][:3]
@ -182,7 +179,6 @@ def load_sonobat_meta(
def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
# create the annotations in the correct format
res, ann = get_blank_annotation("Generated by Sonobat")
res_c = copy.deepcopy(res)
@ -221,7 +217,6 @@ def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
def bb_overlap(bb_g_in, bb_p_in):
freq_scale = 10000000.0  # ensure that both axis are roughly the same range
bb_g = [
bb_g_in["start_time"],
@ -465,7 +460,6 @@ def check_classes_in_train(gt_list, class_names):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"op_dir",

View File

@ -8,10 +8,10 @@ import torch.utils.data
from torch.optim.lr_scheduler import CosineAnnealingLR
import batdetect2.detector.parameters as parameters
-import batdetect2.train.audio_dataloader as adl
+import batdetect2.train.legacy.audio_dataloader as adl
+import batdetect2.train.legacy.train_model as tm
+import batdetect2.train.legacy.train_utils as tu
import batdetect2.train.losses as losses
-import batdetect2.train.train_model as tm
-import batdetect2.train.train_utils as tu
import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as pu
from batdetect2 import types

View File

@ -1,11 +1,92 @@
-from batdetect2.models.feature_extractors import (
+from enum import Enum
+from typing import Optional, Tuple
+from soundevent.data import PathLike
+from batdetect2.configs import BaseConfig, load_config
+from batdetect2.models.backbones import (
Net2DFast,
Net2DFastNoAttn,
Net2DFastNoCoordConv,
+Net2DPlain,
)
+from batdetect2.models.heads import BBoxHead, ClassifierHead
+from batdetect2.models.typing import BackboneModel
__all__ = [
+"BBoxHead",
+"ClassifierHead",
+"ModelConfig",
+"ModelType",
"Net2DFast",
"Net2DFastNoAttn",
"Net2DFastNoCoordConv",
+"build_architecture",
+"load_model_config",
]
class ModelType(str, Enum):
Net2DFast = "Net2DFast"
Net2DFastNoAttn = "Net2DFastNoAttn"
Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
Net2DPlain = "Net2DPlain"
class ModelConfig(BaseConfig):
name: ModelType = ModelType.Net2DFast
input_height: int = 128
encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
bottleneck_channels: int = 256
decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
out_channels: int = 32
def load_model_config(
path: PathLike, field: Optional[str] = None
) -> ModelConfig:
return load_config(path, schema=ModelConfig, field=field)
def build_architecture(
config: Optional[ModelConfig] = None,
) -> BackboneModel:
config = config or ModelConfig()
if config.name == ModelType.Net2DFast:
return Net2DFast(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
if config.name == ModelType.Net2DFastNoAttn:
return Net2DFastNoAttn(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
if config.name == ModelType.Net2DFastNoCoordConv:
return Net2DFastNoCoordConv(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
if config.name == ModelType.Net2DPlain:
return Net2DPlain(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
raise ValueError(f"Unknown model type: {config.name}")

View File

@ -0,0 +1,185 @@
from typing import Sequence, Tuple
import torch
import torch.nn.functional as F
from batdetect2.models.blocks import (
ConvBlock,
Decoder,
DownscalingLayer,
Encoder,
SelfAttention,
UpscalingLayer,
VerticalConv,
)
from batdetect2.models.typing import BackboneModel
__all__ = [
"Net2DFast",
"Net2DFastNoAttn",
"Net2DFastNoCoordConv",
]
class Net2DPlain(BackboneModel):
downscaling_layer_type: DownscalingLayer = "ConvBlockDownStandard"
upscaling_layer_type: UpscalingLayer = "ConvBlockUpStandard"
def __init__(
self,
input_height: int = 128,
encoder_channels: Sequence[int] = (1, 32, 64, 128),
bottleneck_channels: int = 256,
decoder_channels: Sequence[int] = (256, 64, 32, 32),
out_channels: int = 32,
):
super().__init__()
self.input_height = input_height
self.encoder_channels = tuple(encoder_channels)
self.decoder_channels = tuple(decoder_channels)
self.out_channels = out_channels
if len(encoder_channels) != len(decoder_channels):
raise ValueError(
f"Mismatched encoder and decoder channel lists. "
f"The encoder has {len(encoder_channels)} channels "
f"(implying {len(encoder_channels) - 1} layers), "
f"while the decoder has {len(decoder_channels)} channels "
f"(implying {len(decoder_channels) - 1} layers). "
f"These lengths must be equal."
)
self.divide_factor = 2 ** (len(encoder_channels) - 1)
if self.input_height % self.divide_factor != 0:
raise ValueError(
f"Input height ({self.input_height}) must be divisible by "
f"the divide factor ({self.divide_factor}). "
f"This ensures proper upscaling after downscaling to recover "
f"the original input height."
)
self.encoder = Encoder(
channels=encoder_channels,
input_height=self.input_height,
layer_type=self.downscaling_layer_type,
)
self.conv_same_1 = ConvBlock(
in_channels=encoder_channels[-1],
out_channels=bottleneck_channels,
)
# bottleneck
self.conv_vert = VerticalConv(
in_channels=bottleneck_channels,
out_channels=bottleneck_channels,
input_height=self.input_height // (2**self.encoder.depth),
)
self.decoder = Decoder(
channels=decoder_channels,
input_height=self.input_height,
layer_type=self.upscaling_layer_type,
)
self.conv_same_2 = ConvBlock(
in_channels=decoder_channels[-1],
out_channels=out_channels,
)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)
# encoder
residuals = self.encoder(spec)
residuals[-1] = self.conv_same_1(residuals[-1])
# bottleneck
x = self.conv_vert(residuals[-1])
x = x.repeat([1, 1, residuals[-1].shape[-2], 1])
# decoder
x = self.decoder(x, residuals=residuals)
# Restore original size
x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)
return self.conv_same_2(x)
class Net2DFast(Net2DPlain):
downscaling_layer_type = "ConvBlockDownCoordF"
upscaling_layer_type = "ConvBlockUpF"
def __init__(
self,
input_height: int = 128,
encoder_channels: Sequence[int] = (1, 32, 64, 128),
bottleneck_channels: int = 256,
decoder_channels: Sequence[int] = (256, 64, 32, 32),
out_channels: int = 32,
):
super().__init__(
input_height=input_height,
encoder_channels=encoder_channels,
bottleneck_channels=bottleneck_channels,
decoder_channels=decoder_channels,
out_channels=out_channels,
)
self.att = SelfAttention(bottleneck_channels, bottleneck_channels)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)
# encoder
residuals = self.encoder(spec)
residuals[-1] = self.conv_same_1(residuals[-1])
# bottleneck
x = self.conv_vert(residuals[-1])
x = self.att(x)
x = x.repeat([1, 1, residuals[-1].shape[-2], 1])
# decoder
x = self.decoder(x, residuals=residuals)
# Restore original size
x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)
return self.conv_same_2(x)
class Net2DFastNoAttn(Net2DPlain):
downscaling_layer_type = "ConvBlockDownCoordF"
upscaling_layer_type = "ConvBlockUpF"
class Net2DFastNoCoordConv(Net2DFast):
downscaling_layer_type = "ConvBlockDownStandard"
upscaling_layer_type = "ConvBlockUpStandard"
def pad_adjust(
spec: torch.Tensor,
factor: int = 32,
) -> Tuple[torch.Tensor, int, int]:
h, w = spec.shape[2:]
h_pad = -h % factor
w_pad = -w % factor
return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad
def restore_pad(
x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
) -> torch.Tensor:
# Restore original size
if h_pad > 0:
x = x[:, :, :-h_pad, :]
if w_pad > 0:
x = x[:, :, :, :-w_pad]
return x
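The `-h % factor` trick above computes the smallest non-negative padding that makes each spatial dimension a multiple of `factor`; `restore_pad` then crops it away again. A quick check of the arithmetic (a sketch that uses the two functions above):

import torch

spec = torch.zeros(1, 1, 100, 510)
padded, h_pad, w_pad = pad_adjust(spec, factor=32)
# -100 % 32 == 28 and -510 % 32 == 2, so the padded shape is (1, 1, 128, 512).
assert (h_pad, w_pad) == (28, 2)
assert restore_pad(padded, h_pad=h_pad, w_pad=w_pad).shape == spec.shape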

View File

@ -4,18 +4,32 @@ All these classes are subclasses of `torch.nn.Module` and can be used to build
complex neural network architectures.
"""
-from typing import Tuple
+import sys
+from typing import Iterable, List, Literal, Sequence, Tuple
import torch
import torch.nn.functional as F
from torch import nn
if sys.version_info >= (3, 10):
from itertools import pairwise
else:
def pairwise(iterable: Sequence) -> Iterable:
for x, y in zip(iterable[:-1], iterable[1:]):
yield x, y
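The `pairwise` backport mirrors `itertools.pairwise` from Python 3.10+: it yields consecutive pairs, which is exactly how the channel tuples further down are turned into (in_channels, out_channels) pairs, one per layer. For example:

assert list(pairwise((1, 32, 64, 128))) == [(1, 32), (32, 64), (64, 128)]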
__all__ = [
-"SelfAttention",
+"ConvBlock",
"ConvBlockDownCoordF",
"ConvBlockDownStandard",
"ConvBlockUpF",
"ConvBlockUpStandard",
+"SelfAttention",
+"VerticalConv",
+"DownscalingLayer",
+"UpscalingLayer",
]
@ -25,16 +39,21 @@ class SelfAttention(nn.Module):
This module implements self-attention mechanism.
"""
-def __init__(self, ip_dim: int, att_dim: int):
+def __init__(
+self,
+in_channels: int,
+attention_channels: int,
+temperature: float = 1.0,
+):
super().__init__()
-# Note, does not encode position information (absolute or realtive)
-self.temperature = 1.0
-self.att_dim = att_dim
-self.key_fun = nn.Linear(ip_dim, att_dim)
-self.val_fun = nn.Linear(ip_dim, att_dim)
-self.que_fun = nn.Linear(ip_dim, att_dim)
-self.pro_fun = nn.Linear(att_dim, ip_dim)
+# Note, does not encode position information (absolute or relative)
+self.temperature = temperature
+self.att_dim = attention_channels
+self.key_fun = nn.Linear(in_channels, attention_channels)
+self.value_fun = nn.Linear(in_channels, attention_channels)
+self.query_fun = nn.Linear(in_channels, attention_channels)
+self.pro_fun = nn.Linear(attention_channels, in_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.squeeze(2).permute(0, 2, 1)
@ -43,11 +62,11 @@ class SelfAttention(nn.Module):
x, self.key_fun.weight.T
) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
query = torch.matmul(
-x, self.que_fun.weight.T
-) + self.que_fun.bias.unsqueeze(0).unsqueeze(0)
+x, self.query_fun.weight.T
+) + self.query_fun.bias.unsqueeze(0).unsqueeze(0)
value = torch.matmul(
-x, self.val_fun.weight.T
-) + self.val_fun.bias.unsqueeze(0).unsqueeze(0)
+x, self.value_fun.weight.T
+) + self.value_fun.bias.unsqueeze(0).unsqueeze(0)
kk_qq = torch.bmm(key, query.permute(0, 2, 1)) / (
self.temperature * self.att_dim
@ -63,6 +82,66 @@ class SelfAttention(nn.Module):
return op
class ConvBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
pad_size: int = 1,
):
super().__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
padding=pad_size,
)
self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.relu_(self.conv_bn(self.conv(x)))
class VerticalConv(nn.Module):
"""Convolutional layer over full height.
This layer applies a convolution that captures information across the
entire height of the input image. It uses a kernel with the same height as
the input, effectively condensing the vertical information into a single
output row.
More specifically:
* **Input:** (B, C, H, W) where B is the batch size, C is the number of
input channels, H is the image height, and W is the image width.
* **Kernel:** (C', H, 1) where C' is the number of output channels.
* **Output:** (B, C', 1, W) - The height dimension is 1 because the
convolution integrates information from all rows of the input.
This process effectively extracts features that span the full height of
the input image.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
input_height: int,
):
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(input_height, 1),
padding=0,
)
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.relu_(self.bn(self.conv(x)))
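A small shape check for the block above (sketch): with a full-height kernel the frequency axis collapses to a single row while the time axis is preserved.

import torch

layer = VerticalConv(in_channels=256, out_channels=256, input_height=16)
x = torch.randn(2, 256, 16, 64)  # (B, C, H, W)
assert layer(x).shape == (2, 256, 1, 64)  # (B, C', 1, W)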
class ConvBlockDownCoordF(nn.Module):
"""Convolutional Block with Downsampling and Coord Feature.
@ -72,27 +151,27 @@ class ConvBlockDownCoordF(nn.Module):
def __init__(
self,
-in_chn: int,
-out_chn: int,
-ip_height: int,
-k_size: int = 3,
+in_channels: int,
+out_channels: int,
+input_height: int,
+kernel_size: int = 3,
pad_size: int = 1,
stride: int = 1,
):
super().__init__()
self.coords = nn.Parameter(
-torch.linspace(-1, 1, ip_height)[None, None, ..., None],
+torch.linspace(-1, 1, input_height)[None, None, ..., None],
requires_grad=False,
)
self.conv = nn.Conv2d(
-in_chn + 1,
-out_chn,
-kernel_size=k_size,
+in_channels + 1,
+out_channels,
+kernel_size=kernel_size,
padding=pad_size,
stride=stride,
)
-self.conv_bn = nn.BatchNorm2d(out_chn)
+self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
@ -110,26 +189,28 @@ class ConvBlockDownStandard(nn.Module):
def __init__(
self,
-in_chn,
-out_chn,
-k_size=3,
-pad_size=1,
-stride=1,
+in_channels: int,
+out_channels: int,
+kernel_size: int = 3,
+pad_size: int = 1,
+stride: int = 1,
):
super(ConvBlockDownStandard, self).__init__()
self.conv = nn.Conv2d(
-in_chn,
-out_chn,
-kernel_size=k_size,
+in_channels,
+out_channels,
+kernel_size=kernel_size,
padding=pad_size,
stride=stride,
)
-self.conv_bn = nn.BatchNorm2d(out_chn)
+self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
x = F.max_pool2d(self.conv(x), 2, 2)
-x = F.relu(self.conv_bn(x), inplace=True)
-return x
+return F.relu(self.conv_bn(x), inplace=True)
DownscalingLayer = Literal["ConvBlockDownStandard", "ConvBlockDownCoordF"]
class ConvBlockUpF(nn.Module):
@ -141,10 +222,10 @@ class ConvBlockUpF(nn.Module):
def __init__(
self,
-in_chn: int,
-out_chn: int,
-ip_height: int,
-k_size: int = 3,
+in_channels: int,
+out_channels: int,
+input_height: int,
+kernel_size: int = 3,
pad_size: int = 1,
up_mode: str = "bilinear",
up_scale: Tuple[int, int] = (2, 2),
@ -154,15 +235,18 @@ class ConvBlockUpF(nn.Module):
self.up_scale = up_scale
self.up_mode = up_mode
self.coords = nn.Parameter(
-torch.linspace(-1, 1, ip_height * up_scale[0])[
+torch.linspace(-1, 1, input_height * up_scale[0])[
None, None, ..., None
],
requires_grad=False,
)
self.conv = nn.Conv2d(
-in_chn + 1, out_chn, kernel_size=k_size, padding=pad_size
+in_channels + 1,
+out_channels,
+kernel_size=kernel_size,
+padding=pad_size,
)
-self.conv_bn = nn.BatchNorm2d(out_chn)
+self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
op = F.interpolate(
@ -189,9 +273,9 @@ class ConvBlockUpStandard(nn.Module):
def __init__(
self,
-in_chn: int,
-out_chn: int,
-k_size: int = 3,
+in_channels: int,
+out_channels: int,
+kernel_size: int = 3,
pad_size: int = 1,
up_mode: str = "bilinear",
up_scale: Tuple[int, int] = (2, 2),
@ -200,9 +284,12 @@ class ConvBlockUpStandard(nn.Module):
self.up_scale = up_scale
self.up_mode = up_mode
self.conv = nn.Conv2d(
-in_chn, out_chn, kernel_size=k_size, padding=pad_size
+in_channels,
+out_channels,
+kernel_size=kernel_size,
+padding=pad_size,
)
-self.conv_bn = nn.BatchNorm2d(out_chn)
+self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
op = F.interpolate(
@ -217,3 +304,143 @@ class ConvBlockUpStandard(nn.Module):
op = self.conv(op)
op = F.relu(self.conv_bn(op), inplace=True)
return op
UpscalingLayer = Literal["ConvBlockUpStandard", "ConvBlockUpF"]
def build_downscaling_layer(
in_channels: int,
out_channels: int,
input_height: int,
layer_type: DownscalingLayer,
) -> nn.Module:
if layer_type == "ConvBlockDownStandard":
return ConvBlockDownStandard(
in_channels=in_channels,
out_channels=out_channels,
)
if layer_type == "ConvBlockDownCoordF":
return ConvBlockDownCoordF(
in_channels=in_channels,
out_channels=out_channels,
input_height=input_height,
)
raise ValueError(
f"Invalid downscaling layer type {layer_type}. "
f"Valid values: ConvBlockDownCoordF, ConvBlockDownStandard"
)
class Encoder(nn.Module):
def __init__(
self,
channels: Sequence[int] = (1, 32, 62, 128),
input_height: int = 128,
layer_type: Literal[
"ConvBlockDownStandard", "ConvBlockDownCoordF"
] = "ConvBlockDownStandard",
):
super().__init__()
self.channels = channels
self.input_height = input_height
self.layers = nn.ModuleList(
[
build_downscaling_layer(
in_channels=in_channels,
out_channels=out_channels,
input_height=input_height // (2**layer_num),
layer_type=layer_type,
)
for layer_num, (in_channels, out_channels) in enumerate(
pairwise(channels)
)
]
)
self.depth = len(self.layers)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
outputs = []
for layer in self.layers:
x = layer(x)
outputs.append(x)
return outputs
def build_upscaling_layer(
in_channels: int,
out_channels: int,
input_height: int,
layer_type: UpscalingLayer,
) -> nn.Module:
if layer_type == "ConvBlockUpStandard":
return ConvBlockUpStandard(
in_channels=in_channels,
out_channels=out_channels,
)
if layer_type == "ConvBlockUpF":
return ConvBlockUpF(
in_channels=in_channels,
out_channels=out_channels,
input_height=input_height,
)
raise ValueError(
f"Invalid upscaling layer type {layer_type}. "
f"Valid values: ConvBlockUpStandard, ConvBlockUpF"
)
class Decoder(nn.Module):
def __init__(
self,
channels: Sequence[int] = (256, 62, 32, 32),
input_height: int = 128,
layer_type: Literal[
"ConvBlockUpStandard", "ConvBlockUpF"
] = "ConvBlockUpStandard",
):
super().__init__()
self.channels = channels
self.input_height = input_height
self.depth = len(self.channels) - 1
self.layers = nn.ModuleList(
[
build_upscaling_layer(
in_channels=in_channels,
out_channels=out_channels,
input_height=input_height
// (2 ** (self.depth - layer_num)),
layer_type=layer_type,
)
for layer_num, (in_channels, out_channels) in enumerate(
pairwise(channels)
)
]
)
def forward(
self,
x: torch.Tensor,
residuals: List[torch.Tensor],
) -> torch.Tensor:
if len(residuals) != len(self.layers):
raise ValueError(
f"Incorrect number of residuals provided. "
f"Expected {len(self.layers)} (matching the number of layers), "
f"but got {len(residuals)}."
)
for layer, res in zip(self.layers, residuals[::-1]):
x = layer(x + res)
return x
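How the encoder outputs feed the decoder, a standalone sketch of the wiring used by `Net2DPlain` in `batdetect2.models.backbones`, with the default (standard) layer types and a 128-pixel-high input:

import torch

from batdetect2.models.blocks import ConvBlock, Decoder, Encoder, VerticalConv

encoder = Encoder(channels=(1, 32, 64, 128), input_height=128)
bottleneck = ConvBlock(in_channels=128, out_channels=256)
vertical = VerticalConv(in_channels=256, out_channels=256, input_height=128 // 8)
decoder = Decoder(channels=(256, 64, 32, 32), input_height=128)

spec = torch.randn(1, 1, 128, 512)
residuals = encoder(spec)  # shapes: (1,32,64,256), (1,64,32,128), (1,128,16,64)
residuals[-1] = bottleneck(residuals[-1])           # -> (1, 256, 16, 64)
x = vertical(residuals[-1])                         # -> (1, 256, 1, 64)
x = x.repeat([1, 1, residuals[-1].shape[-2], 1])    # broadcast back over height
out = decoder(x, residuals=residuals)               # -> (1, 32, 128, 512)
assert out.shape == (1, 32, 128, 512)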

View File

@ -0,0 +1,15 @@
import sys
from typing import Iterable, List, Literal, Sequence
import torch
from torch import nn
from batdetect2.models.blocks import ConvBlockUpF, ConvBlockUpStandard
if sys.version_info >= (3, 10):
from itertools import pairwise
else:
def pairwise(iterable: Sequence) -> Iterable:
for x, y in zip(iterable[:-1], iterable[1:]):
yield x, y

View File

@ -1,139 +0,0 @@
from typing import Type
import pytorch_lightning as L
import torch
import xarray as xr
from soundevent import data
from torch import nn, optim
from batdetect2.data.preprocessing import (
preprocess_audio_clip,
PreprocessingConfig,
)
from batdetect2.data.labels import ClassMapper
from batdetect2.models.feature_extractors import Net2DFast
from batdetect2.models.post_process import (
PostprocessConfig,
postprocess_model_outputs,
)
from batdetect2.models.typing import FeatureExtractorModel, ModelOutput
from batdetect2.train import losses
from batdetect2.train.dataset import TrainExample
class DetectorModel(L.LightningModule):
def __init__(
self,
class_mapper: ClassMapper,
feature_extractor_class: Type[FeatureExtractorModel] = Net2DFast,
learning_rate: float = 1e-3,
input_height: int = 128,
num_features: int = 32,
preprocessing_config: PreprocessingConfig = PreprocessingConfig(),
postprocessing_config: PostprocessConfig = PostprocessConfig(),
):
super().__init__()
self.save_hyperparameters()
self.preprocessing_config = preprocessing_config
self.postprocessing_config = postprocessing_config
self.class_mapper = class_mapper
self.learning_rate = learning_rate
self.input_height = input_height
self.num_features = num_features
self.num_classes = class_mapper.num_classes
self.feature_extractor = feature_extractor_class(
input_height=input_height,
num_features=num_features,
)
self.classifier = nn.Conv2d(
self.feature_extractor.num_features // 4,
self.num_classes + 1,
kernel_size=1,
padding=0,
)
self.bbox = nn.Conv2d(
self.feature_extractor.num_features // 4,
2,
kernel_size=1,
padding=0,
)
def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore
features = self.feature_extractor(spec)
classification_logits = self.classifier(features)
classification_probs = torch.softmax(classification_logits, dim=1)
detection_probs = classification_probs[:, :-1].sum(dim=1, keepdim=True)
return ModelOutput(
detection_probs=detection_probs,
size_preds=self.bbox(features),
class_probs=classification_probs[:, :-1],
features=features,
)
def compute_spectrogram(self, clip: data.Clip) -> xr.DataArray:
return preprocess_audio_clip(
clip,
config=self.preprocessing_config,
)
def compute_clip_features(self, clip: data.Clip) -> torch.Tensor:
spectrogram = self.compute_spectrogram(clip)
return self.feature_extractor(
torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
)
def compute_clip_predictions(self, clip: data.Clip) -> data.ClipPrediction:
spectrogram = self.compute_spectrogram(clip)
spec_tensor = (
torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
)
outputs = self(spec_tensor)
return postprocess_model_outputs(
outputs,
[clip],
class_mapper=self.class_mapper,
config=self.postprocessing_config,
)[0]
def compute_loss(
self,
outputs: ModelOutput,
batch: TrainExample,
) -> torch.Tensor:
detection_loss = losses.focal_loss(
outputs.detection_probs,
batch.detection_heatmap,
)
size_loss = losses.bbox_size_loss(
outputs.size_preds,
batch.size_heatmap,
)
valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
classification_loss = losses.focal_loss(
outputs.class_probs,
batch.class_heatmap,
valid_mask=valid_mask,
)
return detection_loss + size_loss + classification_loss
def training_step( # type: ignore
self,
batch: TrainExample,
):
outputs = self.forward(batch.spec)
loss = self.compute_loss(outputs, batch)
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
return [optimizer], [scheduler]

View File

@ -0,0 +1,15 @@
import sys
from typing import Iterable, List, Literal, Sequence
import torch
from torch import nn
from batdetect2.models.blocks import ConvBlockDownCoordF, ConvBlockDownStandard
if sys.version_info >= (3, 10):
from itertools import pairwise
else:
def pairwise(iterable: Sequence) -> Iterable:
for x, y in zip(iterable[:-1], iterable[1:]):
yield x, y

View File

@ -1,319 +0,0 @@
from typing import Tuple
import torch
import torch.fft
import torch.nn.functional as F
from torch import nn
from batdetect2.models.blocks import (
ConvBlockDownCoordF,
ConvBlockDownStandard,
ConvBlockUpF,
ConvBlockUpStandard,
SelfAttention,
)
from batdetect2.models.typing import FeatureExtractorModel
__all__ = [
"Net2DFast",
"Net2DFastNoAttn",
"Net2DFastNoCoordConv",
]
class Net2DFast(FeatureExtractorModel):
def __init__(
self,
num_features: int,
input_height: int = 128,
):
super().__init__()
self.num_features = num_features
self.input_height = input_height
self.bottleneck_height = self.input_height // 32
# encoder
self.conv_dn_0 = ConvBlockDownCoordF(
1,
self.num_features // 4,
self.input_height,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_1 = ConvBlockDownCoordF(
self.num_features // 4,
self.num_features // 2,
self.input_height // 2,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_2 = ConvBlockDownCoordF(
self.num_features // 2,
self.num_features,
self.input_height // 4,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_3 = nn.Conv2d(
self.num_features,
self.num_features * 2,
3,
padding=1,
)
self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)
# bottleneck
self.conv_1d = nn.Conv2d(
self.num_features * 2,
self.num_features * 2,
(self.input_height // 8, 1),
padding=0,
)
self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
self.att = SelfAttention(self.num_features * 2, self.num_features * 2)
# decoder
self.conv_up_2 = ConvBlockUpF(
self.num_features * 2,
self.num_features // 2,
self.input_height // 8,
)
self.conv_up_3 = ConvBlockUpF(
self.num_features // 2,
self.num_features // 4,
self.input_height // 4,
)
self.conv_up_4 = ConvBlockUpF(
self.num_features // 4,
self.num_features // 4,
self.input_height // 2,
)
self.conv_op = nn.Conv2d(
self.num_features // 4,
self.num_features // 4,
kernel_size=3,
padding=1,
)
self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
def pad_adjust(self, spec: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
h, w = spec.shape[2:]
h_pad = (32 - h % 32) % 32
w_pad = (32 - w % 32) % 32
return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad
def forward(self, spec: torch.Tensor) -> torch.Tensor:
# encoder
spec, h_pad, w_pad = self.pad_adjust(spec)
x1 = self.conv_dn_0(spec)
x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2)
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
# bottleneck
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
x = self.att(x)
x = x.repeat([1, 1, self.bottleneck_height * 4, 1])
# decoder
x = self.conv_up_2(x + x3)
x = self.conv_up_3(x + x2)
x = self.conv_up_4(x + x1)
# Restore original size
if h_pad > 0:
x = x[:, :, :-h_pad, :]
if w_pad > 0:
x = x[:, :, :, :-w_pad]
return F.relu_(self.conv_op_bn(self.conv_op(x)))
class Net2DFastNoAttn(FeatureExtractorModel):
def __init__(
self,
num_features: int,
input_height: int = 128,
):
super().__init__()
self.num_features = num_features
self.input_height = input_height
self.bottleneck_height = self.input_height // 32
self.conv_dn_0 = ConvBlockDownCoordF(
1,
self.num_features // 4,
self.input_height,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_1 = ConvBlockDownCoordF(
self.num_features // 4,
self.num_features // 2,
self.input_height // 2,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_2 = ConvBlockDownCoordF(
self.num_features // 2,
self.num_features,
self.input_height // 4,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_3 = nn.Conv2d(
self.num_features,
self.num_features * 2,
3,
padding=1,
)
self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)
self.conv_1d = nn.Conv2d(
self.num_features * 2,
self.num_features * 2,
(self.input_height // 8, 1),
padding=0,
)
self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
self.conv_up_2 = ConvBlockUpF(
self.num_features * 2,
self.num_features // 2,
self.input_height // 8,
)
self.conv_up_3 = ConvBlockUpF(
self.num_features // 2,
self.num_features // 4,
self.input_height // 4,
)
self.conv_up_4 = ConvBlockUpF(
self.num_features // 4,
self.num_features // 4,
self.input_height // 2,
)
self.conv_op = nn.Conv2d(
self.num_features // 4,
self.num_features // 4,
kernel_size=3,
padding=1,
)
self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
x1 = self.conv_dn_0(spec)
x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2)
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
x = x.repeat([1, 1, self.bottleneck_height * 4, 1])
x = self.conv_up_2(x + x3)
x = self.conv_up_3(x + x2)
x = self.conv_up_4(x + x1)
return F.relu_(self.conv_op_bn(self.conv_op(x)))
class Net2DFastNoCoordConv(FeatureExtractorModel):
def __init__(
self,
num_features: int,
input_height: int = 128,
):
super().__init__()
self.num_features = num_features
self.input_height = input_height
self.bottleneck_height = self.input_height // 32
self.conv_dn_0 = ConvBlockDownStandard(
1,
self.num_features // 4,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_1 = ConvBlockDownStandard(
self.num_features // 4,
self.num_features // 2,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_2 = ConvBlockDownStandard(
self.num_features // 2,
self.num_features,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_3 = nn.Conv2d(
self.num_features,
self.num_features * 2,
3,
padding=1,
)
self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)
self.conv_1d = nn.Conv2d(
self.num_features * 2,
self.num_features * 2,
(self.input_height // 8, 1),
padding=0,
)
self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
self.att = SelfAttention(self.num_features * 2, self.num_features * 2)
self.conv_up_2 = ConvBlockUpStandard(
self.num_features * 2,
self.num_features // 2,
self.input_height // 8,
)
self.conv_up_3 = ConvBlockUpStandard(
self.num_features // 2,
self.num_features // 4,
self.input_height // 4,
)
self.conv_up_4 = ConvBlockUpStandard(
self.num_features // 4,
self.num_features // 4,
self.input_height // 2,
)
self.conv_op = nn.Conv2d(
self.num_features // 4,
self.num_features // 4,
kernel_size=3,
padding=1,
)
self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
x1 = self.conv_dn_0(spec)
x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2)
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
x = self.att(x)
x = x.repeat([1, 1, self.bottleneck_height * 4, 1])
x = self.conv_up_2(x + x3)
x = self.conv_up_3(x + x2)
x = self.conv_up_4(x + x1)
return F.relu_(self.conv_op_bn(self.conv_op(x)))

View File

@ -0,0 +1,51 @@
from typing import NamedTuple
import torch
from torch import nn
__all__ = ["ClassifierHead"]
class Output(NamedTuple):
detection: torch.Tensor
classification: torch.Tensor
class ClassifierHead(nn.Module):
def __init__(self, num_classes: int, in_channels: int):
super().__init__()
self.num_classes = num_classes
self.in_channels = in_channels
self.classifier = nn.Conv2d(
self.in_channels,
# Add one to account for the background class
self.num_classes + 1,
kernel_size=1,
padding=0,
)
def forward(self, features: torch.Tensor) -> Output:
logits = self.classifier(features)
probs = torch.softmax(logits, dim=1)
detection_probs = probs[:, :-1].sum(dim=1, keepdim=True)
return Output(
detection=detection_probs,
classification=probs[:, :-1],
)
class BBoxHead(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.bbox = nn.Conv2d(
in_channels=self.in_channels,
out_channels=2,
kernel_size=1,
padding=0,
)
def forward(self, features: torch.Tensor) -> torch.Tensor:
return self.bbox(features)
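A quick sketch of how the two heads consume a backbone feature map (dummy tensors only; the extra softmax slot is the implicit background class):

import torch

features = torch.randn(1, 32, 128, 512)  # (B, out_channels, H, W) from the backbone

classifier = ClassifierHead(num_classes=6, in_channels=32)
detection, classification = classifier(features)
assert detection.shape == (1, 1, 128, 512)       # P(any call) = 1 - P(background)
assert classification.shape == (1, 6, 128, 512)  # per-class probabilities

bbox = BBoxHead(in_channels=32)
assert bbox(features).shape == (1, 2, 128, 512)  # predicted width / height per cell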

View File

@ -1,12 +1,12 @@
from abc import ABC, abstractmethod
-from typing import NamedTuple
+from typing import NamedTuple, Tuple
import torch
import torch.nn as nn
__all__ = [
"ModelOutput",
-"FeatureExtractorModel",
+"BackboneModel",
]
@ -41,16 +41,31 @@ class ModelOutput(NamedTuple):
"""Tensor with intermediate features."""
-class FeatureExtractorModel(ABC, nn.Module):
+class BackboneModel(ABC, nn.Module):
input_height: int
"""Height of the input spectrogram."""
-num_features: int
-"""Dimension of the feature tensor."""
+encoder_channels: Tuple[int, ...]
"""Tuple specifying the number of channels for each convolutional layer
in the encoder. The length of this tuple determines the number of
encoder layers."""
decoder_channels: Tuple[int, ...]
"""Tuple specifying the number of channels for each convolutional layer
in the decoder. The length of this tuple determines the number of
decoder layers."""
bottleneck_channels: int
"""Number of channels in the bottleneck layer, which connects the
encoder and decoder."""
out_channels: int
"""Number of channels in the final output feature map produced by the
backbone model."""
@abstractmethod
def forward(self, spec: torch.Tensor) -> torch.Tensor:
-"""Forward pass of the encoder model."""
+"""Forward pass of the model."""
class DetectionModel(ABC, nn.Module):

batdetect2/modules.py Normal file
View File

@ -0,0 +1,181 @@
from pathlib import Path
from typing import Optional
import lightning as L
import torch
from pydantic import Field
from soundevent import data
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from batdetect2.configs import BaseConfig
from batdetect2.evaluate.evaluate import match_predictions_and_annotations
from batdetect2.models import (
BBoxHead,
ClassifierHead,
ModelConfig,
build_architecture,
)
from batdetect2.models.typing import ModelOutput
from batdetect2.post_process import (
PostprocessConfig,
postprocess_model_outputs,
)
from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
from batdetect2.train.config import TrainingConfig
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.losses import compute_loss
from batdetect2.train.targets import (
TargetConfig,
build_decoder,
build_encoder,
get_class_names,
)
__all__ = [
"DetectorModel",
]
class ModuleConfig(BaseConfig):
train: TrainingConfig = Field(default_factory=TrainingConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
architecture: ModelConfig = Field(default_factory=ModelConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocessing: PostprocessConfig = Field(
default_factory=PostprocessConfig
)
class DetectorModel(L.LightningModule):
config: ModuleConfig
def __init__(
self,
config: Optional[ModuleConfig] = None,
):
super().__init__()
self.config = config or ModuleConfig()
self.save_hyperparameters()
self.backbone = build_architecture(self.config.architecture)
self.classifier = ClassifierHead(
num_classes=len(self.config.targets.classes),
in_channels=self.backbone.out_channels,
)
self.bbox = BBoxHead(in_channels=self.backbone.out_channels)
conf = self.config.train.loss.classification
self.class_weights = (
torch.tensor(conf.class_weights) if conf.class_weights else None
)
# Training targets
self.class_names = get_class_names(self.config.targets.classes)
self.encoder = build_encoder(
self.config.targets.classes,
replacement_rules=self.config.targets.replace,
)
self.decoder = build_decoder(self.config.targets.classes)
self.validation_predictions = []
self.example_input_array = torch.randn([1, 1, 128, 512])
def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore
features = self.backbone(spec)
detection_probs, classification_probs = self.classifier(features)
size_preds = self.bbox(features)
return ModelOutput(
detection_probs=detection_probs,
size_preds=size_preds,
class_probs=classification_probs,
features=features,
)
def training_step(self, batch: TrainExample):
outputs = self.forward(batch.spec)
losses = compute_loss(
batch,
outputs,
conf=self.config.train.loss,
class_weights=self.class_weights,
)
self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
self.log("train/loss/detection", losses.total, logger=True)
self.log("train/loss/size", losses.total, logger=True)
self.log("train/loss/classification", losses.total, logger=True)
return losses.total
def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
outputs = self.forward(batch.spec)
losses = compute_loss(
batch,
outputs,
conf=self.config.train.loss,
class_weights=self.class_weights,
)
self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
self.log("val/loss/detection", losses.total, logger=True)
self.log("val/loss/size", losses.total, logger=True)
self.log("val/loss/classification", losses.total, logger=True)
dataloaders = self.trainer.val_dataloaders
assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset
assert isinstance(dataset, LabeledDataset)
clip_annotation = dataset.get_clip_annotation(batch_idx)
clip_prediction = postprocess_model_outputs(
outputs,
clips=[clip_annotation.clip],
classes=self.class_names,
decoder=self.decoder,
config=self.config.postprocessing,
)[0]
matches = match_predictions_and_annotations(
clip_annotation,
clip_prediction,
)
self.validation_predictions.extend(matches)
def on_validation_epoch_end(self) -> None:
self.validation_predictions.clear()
def configure_optimizers(self):
conf = self.config.train.optimizer
optimizer = Adam(self.parameters(), lr=conf.learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max)
return [optimizer], [scheduler]
def process_clip(
self,
clip: data.Clip,
audio_dir: Optional[Path] = None,
) -> data.ClipPrediction:
spec = preprocess_audio_clip(
clip,
config=self.config.preprocessing,
audio_dir=audio_dir,
)
tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)
outputs = self.forward(tensor)
return postprocess_model_outputs(
outputs,
clips=[clip],
classes=self.class_names,
decoder=self.decoder,
config=self.config.postprocessing,
)[0]
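A minimal end-to-end sketch of the module above (paths are placeholders; the all-default ModuleConfig is used when no configuration is supplied):

from pathlib import Path

from soundevent import data

from batdetect2.modules import DetectorModel

model = DetectorModel()  # all-default ModuleConfig
model.eval()

recording = data.Recording.from_file(Path("example.wav"))  # hypothetical file
clip = data.Clip(
    recording=recording,
    start_time=0,
    end_time=recording.duration,
)
prediction = model.process_clip(clip)  # -> data.ClipPrediction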

View File

@ -2,10 +2,10 @@
from typing import List, Optional, Tuple, Union, cast
+import matplotlib.ticker as tick
import numpy as np
import torch
from matplotlib import axes, patches
-import matplotlib.ticker as tick
from matplotlib import pyplot as plt
from batdetect2.detector.parameters import DEFAULT_PROCESSING_CONFIGURATIONS
@ -102,7 +102,6 @@ def spectrogram(
return ax
def spectrogram_with_detections(
spec: Union[torch.Tensor, np.ndarray],
dets: List[Annotation],

View File

@ -17,6 +17,6 @@ def create_ax(
) -> axes.Axes:
"""Create a new axis if none is provided"""
if ax is None:
_, ax = plt.subplots(figsize=figsize, **kwargs)  # type: ignore
return ax  # type: ignore

View File

@ -1,19 +1,20 @@
"""Module for postprocessing model outputs.""" """Module for postprocessing model outputs."""
from typing import Callable, List, Tuple, Union from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
from pydantic import BaseModel, Field
import numpy as np import numpy as np
import torch import torch
from pydantic import Field
from soundevent import data from soundevent import data
from torch import nn from torch import nn
from batdetect2.data.labels import ClassMapper from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.typing import ModelOutput from batdetect2.models.typing import ModelOutput
__all__ = [ __all__ = [
"postprocess_model_outputs",
"PostprocessConfig", "PostprocessConfig",
"load_postprocess_config",
"postprocess_model_outputs",
] ]
NMS_KERNEL_SIZE = 9 NMS_KERNEL_SIZE = 9
@ -21,7 +22,7 @@ DETECTION_THRESHOLD = 0.01
TOP_K_PER_SEC = 200 TOP_K_PER_SEC = 200
class PostprocessConfig(BaseModel): class PostprocessConfig(BaseConfig):
"""Configuration for postprocessing model outputs.""" """Configuration for postprocessing model outputs."""
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0) nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
@ -31,14 +32,29 @@ class PostprocessConfig(BaseModel):
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
-TagFunction = Callable[[int], List[data.Tag]]
def load_postprocess_config(
path: data.PathLike,
field: Optional[str] = None,
) -> PostprocessConfig:
return load_config(path, schema=PostprocessConfig, field=field)
class RawPrediction(NamedTuple):
start_time: float
end_time: float
low_freq: float
high_freq: float
detection_score: float
class_scores: Dict[str, float]
features: np.ndarray
def postprocess_model_outputs(
outputs: ModelOutput,
clips: List[data.Clip],
-class_mapper: ClassMapper,
-config: PostprocessConfig,
+classes: List[str],
+decoder: Callable[[str], List[data.Tag]],
+config: Optional[PostprocessConfig] = None,
) -> List[data.ClipPrediction]:
"""Postprocesses model outputs to generate clip predictions.
@ -68,6 +84,9 @@ def postprocess_model_outputs(
ValueError
If the number of predictions does not match the number of clips.
"""
config = config or PostprocessConfig()
num_predictions = len(outputs.detection_probs)
if num_predictions == 0:
@ -108,7 +127,8 @@ def postprocess_model_outputs(
size_preds,
class_probs,
features,
-class_mapper=class_mapper,
+classes=classes,
+decoder=decoder,
min_freq=config.min_freq,
max_freq=config.max_freq,
detection_threshold=config.detection_threshold,
@ -124,6 +144,82 @@ def postprocess_model_outputs(
return predictions
def compute_predictions_from_outputs(
start: float,
end: float,
scores: torch.Tensor,
y_pos: torch.Tensor,
x_pos: torch.Tensor,
size_preds: torch.Tensor,
class_probs: torch.Tensor,
features: torch.Tensor,
classes: List[str],
min_freq: int = 10000,
max_freq: int = 120000,
detection_threshold: float = DETECTION_THRESHOLD,
) -> List[RawPrediction]:
_, freq_bins, time_bins = size_preds.shape
sorted_indices = torch.argsort(x_pos)
valid_indices = sorted_indices[
scores[sorted_indices] > detection_threshold
]
scores = scores[valid_indices]
x_pos = x_pos[valid_indices]
y_pos = y_pos[valid_indices]
predictions: List[RawPrediction] = []
for score, x, y in zip(scores, x_pos, y_pos):
width, height = size_preds[:, y, x]
class_prob = class_probs[:, y, x].detach().numpy()
feats = features[:, y, x].detach().numpy()
start_time = np.interp(
x.item(),
[0, time_bins],
[start, end],
)
end_time = np.interp(
x.item() + width.item(),
[0, time_bins],
[start, end],
)
low_freq = np.interp(
y.item(),
[0, freq_bins],
[max_freq, min_freq],
)
high_freq = np.interp(
y.item() - height.item(),
[0, freq_bins],
[max_freq, min_freq],
)
start_time, end_time = sorted([float(start_time), float(end_time)])
low_freq, high_freq = sorted([float(low_freq), float(high_freq)])
predictions.append(
RawPrediction(
start_time=start_time,
end_time=end_time,
low_freq=low_freq,
high_freq=high_freq,
detection_score=score.item(),
features=feats,
class_scores={
class_name: prob
for class_name, prob in zip(classes, class_prob)
},
)
)
return predictions
def compute_sound_events_from_outputs(
clip: data.Clip,
scores: torch.Tensor,
@ -132,7 +228,8 @@ def compute_sound_events_from_outputs(
size_preds: torch.Tensor,
class_probs: torch.Tensor,
features: torch.Tensor,
-class_mapper: ClassMapper,
+classes: List[str],
+decoder: Callable[[str], List[data.Tag]],
min_freq: int = 10000,
max_freq: int = 120000,
detection_threshold: float = DETECTION_THRESHOLD,
@ -181,12 +278,13 @@ def compute_sound_events_from_outputs(
predicted_tags: List[data.PredictedTag] = []
for label_id, class_score in enumerate(class_prob):
-corresponding_tags = class_mapper.inverse_transform(label_id)
+class_name = classes[label_id]
+corresponding_tags = decoder(class_name)
predicted_tags.extend(
[
data.PredictedTag(
tag=tag,
-score=class_score.item(),
+score=max(min(class_score.item(), 1), 0),
)
for tag in corresponding_tags
]
@ -207,7 +305,7 @@ def compute_sound_events_from_outputs(
),
features=[
data.Feature(
-name=f"batdetect2_{i}",
+term=data.term_from_key(f"batdetect2_{i}"),
value=value.item(),
)
for i, value in enumerate(feature)
@ -217,7 +315,7 @@ def compute_sound_events_from_outputs(
predictions.append(
data.SoundEventPrediction(
sound_event=sound_event,
-score=score.item(),
+score=max(min(score.item(), 1), 0),
tags=predicted_tags,
)
)

View File

@ -0,0 +1,68 @@
"""Module containing functions for preprocessing audio clips."""
from typing import Optional
import xarray as xr
from soundevent import data
from batdetect2.preprocess.audio import (
AudioConfig,
ResampleConfig,
load_clip_audio,
)
from batdetect2.preprocess.config import (
PreprocessingConfig,
load_preprocessing_config,
)
from batdetect2.preprocess.spectrogram import (
AmplitudeScaleConfig,
FrequencyConfig,
LogScaleConfig,
PcenScaleConfig,
Scales,
SpecSizeConfig,
SpectrogramConfig,
STFTConfig,
compute_spectrogram,
)
__all__ = [
"AmplitudeScaleConfig",
"AudioConfig",
"FrequencyConfig",
"LogScaleConfig",
"PcenScaleConfig",
"PreprocessingConfig",
"ResampleConfig",
"STFTConfig",
"Scales",
"SpecSizeConfig",
"SpectrogramConfig",
"load_preprocessing_config",
"preprocess_audio_clip",
]
def preprocess_audio_clip(
clip: data.Clip,
config: Optional[PreprocessingConfig] = None,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Preprocesses audio clip to generate spectrogram.
Parameters
----------
clip
The audio clip to preprocess.
config
Configuration for preprocessing. Defaults to `PreprocessingConfig()` when omitted.
audio_dir
Optional directory in which to look for the clip's audio file.
Returns
-------
xr.DataArray
Preprocessed spectrogram.
"""
config = config or PreprocessingConfig()
wav = load_clip_audio(clip, config=config.audio, audio_dir=audio_dir)
return compute_spectrogram(wav, config=config.spectrogram)
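Usage sketch for the helper above ("config.yaml", the "preprocessing" field name and "example.wav" are placeholders for wherever the configuration and audio actually live):

from soundevent import data

from batdetect2.preprocess import load_preprocessing_config, preprocess_audio_clip

config = load_preprocessing_config("config.yaml", field="preprocessing")

recording = data.Recording.from_file("example.wav")
clip = data.Clip(recording=recording, start_time=0, end_time=recording.duration)

spec = preprocess_audio_clip(clip, config=config)  # xr.DataArray (frequency, time)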

View File

@ -0,0 +1,61 @@
import numpy as np
def extend_width(
array: np.ndarray,
extra: int,
axis: int = -1,
value: float = 0,
) -> np.ndarray:
dims = len(array.shape)
axis = axis % dims
pad = [[0, 0] if index != axis else [0, extra] for index in range(dims)]
return np.pad(
array,
pad,
mode="constant",
constant_values=value,
)
def make_width_divisible(
array: np.ndarray,
factor: int,
axis: int = -1,
value: float = 0,
) -> np.ndarray:
width = array.shape[axis]
if width % factor == 0:
return array
extra = (-width) % factor
return extend_width(array, extra, axis=axis, value=value)
def adjust_width(
array: np.ndarray,
width: int,
axis: int = -1,
value: float = 0,
) -> np.ndarray:
dims = len(array.shape)
axis = axis % dims
current_width = array.shape[axis]
if current_width == width:
return array
if current_width < width:
return extend_width(
array,
extra=width - current_width,
axis=axis,
value=value,
)
slices = [
slice(None, None) if index != axis else slice(None, width)
for index in range(dims)
]
return array[tuple(slices)]
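These helpers only ever zero-pad or crop the chosen axis, which makes them safe for snapping spectrogram widths to the network's divisibility requirement. For example:

import numpy as np

x = np.ones((128, 510))
assert make_width_divisible(x, factor=32).shape == (128, 512)  # pads 2 columns of zeros
assert adjust_width(x, width=256).shape == (128, 256)          # crops down
assert adjust_width(x, width=600).shape == (128, 600)          # zero-pads up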

View File

@ -0,0 +1,199 @@
from typing import Optional
import numpy as np
import xarray as xr
from numpy.typing import DTypeLike
from pydantic import Field
from scipy.signal import resample, resample_poly
from soundevent import arrays, audio, data
from soundevent.arrays import operations as ops
from batdetect2.configs import BaseConfig
TARGET_SAMPLERATE_HZ = 256_000
SCALE_RAW_AUDIO = False
DEFAULT_DURATION = None
class ResampleConfig(BaseConfig):
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
mode: str = "poly"
class AudioConfig(BaseConfig):
resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
scale: bool = SCALE_RAW_AUDIO
center: bool = True
duration: Optional[float] = DEFAULT_DURATION
def load_file_audio(
path: data.PathLike,
config: Optional[AudioConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
recording = data.Recording.from_file(path)
return load_recording_audio(
recording,
config=config,
dtype=dtype,
audio_dir=audio_dir,
)
def load_recording_audio(
recording: data.Recording,
config: Optional[AudioConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
clip = data.Clip(
recording=recording,
start_time=0,
end_time=recording.duration,
)
return load_clip_audio(
clip,
config=config,
dtype=dtype,
audio_dir=audio_dir,
)
def load_clip_audio(
clip: data.Clip,
config: Optional[AudioConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
config = config or AudioConfig()
wav = (
audio.load_clip(clip, audio_dir=audio_dir).sel(channel=0).astype(dtype)
)
if config.duration is not None:
wav = adjust_audio_duration(wav, duration=config.duration)
if config.resample:
wav = resample_audio(
wav,
samplerate=config.resample.samplerate,
dtype=dtype,
)
if config.center:
wav = ops.center(wav)
if config.scale:
wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))
return wav.astype(dtype)
def adjust_audio_duration(
wave: xr.DataArray,
duration: float,
) -> xr.DataArray:
start_time, end_time = arrays.get_dim_range(wave, dim="time")
current_duration = end_time - start_time
if current_duration == duration:
return wave
if current_duration > duration:
return arrays.crop_dim(
wave,
dim="time",
start=start_time,
stop=start_time + duration,
)
return arrays.extend_dim(
wave,
dim="time",
start=start_time,
stop=start_time + duration,
)
def resample_audio(
wav: xr.DataArray,
samplerate: int = TARGET_SAMPLERATE_HZ,
mode: str = "poly",
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
if "time" not in wav.dims:
raise ValueError("Audio must have a time dimension")
time_axis: int = wav.get_axis_num("time") # type: ignore
step = arrays.get_dim_step(wav, dim="time")
original_samplerate = int(1 / step)
if original_samplerate == samplerate:
return wav.astype(dtype)
if mode == "poly":
resampled = resample_audio_poly(
wav,
sr_orig=original_samplerate,
sr_new=samplerate,
axis=time_axis,
)
elif mode == "fourier":
resampled = resample_audio_fourier(
wav,
sr_orig=original_samplerate,
sr_new=samplerate,
axis=time_axis,
)
else:
raise NotImplementedError(f"Resampling mode '{mode}' not implemented")
start, stop = arrays.get_dim_range(wav, dim="time")
times = np.linspace(
start,
stop + step,
len(resampled),
endpoint=False,
dtype=dtype,
)
return xr.DataArray(
data=resampled.astype(dtype),
dims=wav.dims,
coords={
**wav.coords,
"time": arrays.create_time_dim_from_array(
times,
samplerate=samplerate,
),
},
attrs=wav.attrs,
)
def resample_audio_poly(
array: xr.DataArray,
sr_orig: int,
sr_new: int,
axis: int = -1,
) -> np.ndarray:
gcd = np.gcd(sr_orig, sr_new)
return resample_poly(
array.values,
sr_new // gcd,
sr_orig // gcd,
axis=axis,
)
def resample_audio_fourier(
array: xr.DataArray,
sr_orig: int,
sr_new: int,
axis: int = -1,
) -> np.ndarray:
ratio = sr_new / sr_orig
return resample(array, int(array.shape[axis] * ratio), axis=axis) # type: ignore
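The polyphase branch reduces the rate ratio by the greatest common divisor before calling `scipy.signal.resample_poly`, so e.g. going from a 384 kHz recording to the 256 kHz target becomes an up-2 / down-3 resampling (a sketch of the arithmetic):

import numpy as np

sr_orig, sr_new = 384_000, 256_000
gcd = np.gcd(sr_orig, sr_new)           # 128_000
up, down = sr_new // gcd, sr_orig // gcd
assert (up, down) == (2, 3)
# resample_audio_poly(wav, sr_orig=sr_orig, sr_new=sr_new) therefore returns
# an array with 2/3 of the original number of samples along the time axis.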

View File

@ -0,0 +1,31 @@
from typing import Optional
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import (
AudioConfig,
)
from batdetect2.preprocess.spectrogram import (
SpectrogramConfig,
)
__all__ = [
"PreprocessingConfig",
"load_preprocessing_config",
]
class PreprocessingConfig(BaseConfig):
"""Configuration for preprocessing data."""
audio: AudioConfig = Field(default_factory=AudioConfig)
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
def load_preprocessing_config(
path: PathLike,
field: Optional[str] = None,
) -> PreprocessingConfig:
return load_config(path, schema=PreprocessingConfig, field=field)

View File

@ -0,0 +1,323 @@
from typing import Literal, Optional, Union
import librosa
import librosa.core.spectrum
import numpy as np
import xarray as xr
from numpy.typing import DTypeLike
from pydantic import Field
from soundevent import arrays, audio
from soundevent.arrays import operations as ops
from batdetect2.configs import BaseConfig
class STFTConfig(BaseConfig):
window_duration: float = Field(default=0.002, gt=0)
window_overlap: float = Field(default=0.75, ge=0, lt=1)
window_fn: str = "hann"
class FrequencyConfig(BaseConfig):
max_freq: int = Field(default=120_000, gt=0)
min_freq: int = Field(default=10_000, gt=0)
class SpecSizeConfig(BaseConfig):
height: int = 128
"""Height of the spectrogram in pixels. This value determines the
number of frequency bands and corresponds to the vertical dimension
of the spectrogram."""
resize_factor: Optional[float] = 0.5
"""Factor by which to resize the spectrogram along the time axis.
A value of 0.5 reduces the temporal dimension by half, while a
value of 2.0 doubles it. If None, no resizing is performed."""
class LogScaleConfig(BaseConfig):
name: Literal["log"] = "log"
class PcenScaleConfig(BaseConfig):
name: Literal["pcen"] = "pcen"
time_constant: float = 0.4
hop_length: int = 512
gain: float = 0.98
bias: float = 2
power: float = 0.5
class AmplitudeScaleConfig(BaseConfig):
name: Literal["amplitude"] = "amplitude"
Scales = Union[LogScaleConfig, PcenScaleConfig, AmplitudeScaleConfig]
class SpectrogramConfig(BaseConfig):
stft: STFTConfig = Field(default_factory=STFTConfig)
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
scale: Scales = Field(
default_factory=PcenScaleConfig,
discriminator="name",
)
size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
denoise: bool = True
max_scale: bool = False
def compute_spectrogram(
wav: xr.DataArray,
config: Optional[SpectrogramConfig] = None,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
config = config or SpectrogramConfig()
spec = stft(
wav,
window_duration=config.stft.window_duration,
window_overlap=config.stft.window_overlap,
window_fn=config.stft.window_fn,
dtype=dtype,
)
spec = crop_spectrogram_frequencies(
spec,
min_freq=config.frequencies.min_freq,
max_freq=config.frequencies.max_freq,
)
spec = scale_spectrogram(spec, scale=config.scale)
if config.denoise:
spec = denoise_spectrogram(spec)
if config.size:
spec = resize_spectrogram(
spec,
height=config.size.height,
resize_factor=config.size.resize_factor,
)
if config.max_scale:
spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
return spec.astype(dtype)
def crop_spectrogram_frequencies(
spec: xr.DataArray,
min_freq: int = 10_000,
max_freq: int = 120_000,
) -> xr.DataArray:
return arrays.crop_dim(
spec,
dim="frequency",
start=min_freq,
stop=max_freq,
).astype(spec.dtype)
def stft(
wave: xr.DataArray,
window_duration: float,
window_overlap: float,
window_fn: str = "hann",
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
start_time, end_time = arrays.get_dim_range(wave, dim="time")
step = arrays.get_dim_step(wave, dim="time")
sampling_rate = 1 / step
nfft = int(window_duration * sampling_rate)
noverlap = int(window_overlap * nfft)
hop_len = nfft - noverlap
hop_duration = hop_len / sampling_rate
spec, _ = librosa.core.spectrum._spectrogram(
y=wave.data.astype(dtype),
power=1,
n_fft=nfft,
hop_length=nfft - noverlap,
center=False,
window=window_fn,
)
return xr.DataArray(
data=spec.astype(dtype),
dims=["frequency", "time"],
coords={
"frequency": arrays.create_frequency_dim_from_array(
np.linspace(
0,
sampling_rate / 2,
spec.shape[0],
endpoint=False,
dtype=dtype,
),
step=sampling_rate / nfft,
),
"time": arrays.create_time_dim_from_array(
np.linspace(
start_time,
end_time - (window_duration - hop_duration),
spec.shape[1],
endpoint=False,
dtype=dtype,
),
step=hop_duration,
),
},
attrs={
**wave.attrs,
"original_samplerate": sampling_rate,
"nfft": nfft,
"noverlap": noverlap,
},
)
def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
return xr.DataArray(
data=(spec - spec.mean("time")).clip(0),
dims=spec.dims,
coords=spec.coords,
attrs=spec.attrs,
)
def scale_spectrogram(
spec: xr.DataArray,
scale: Scales,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
if scale.name == "log":
return scale_log(spec, dtype=dtype)
if scale.name == "pcen":
return scale_pcen(
spec,
time_constant=scale.time_constant,
hop_length=scale.hop_length,
gain=scale.gain,
power=scale.power,
bias=scale.bias,
)
return spec
def scale_pcen(
spec: xr.DataArray,
time_constant: float = 0.4,
hop_length: int = 512,
gain: float = 0.98,
bias: float = 2,
power: float = 0.5,
) -> xr.DataArray:
samplerate = spec.attrs["original_samplerate"]
t_frames = time_constant * samplerate / (float(hop_length) * 10)
smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
return audio.pcen(
spec * (2**31),
smooth=smoothing_constant,
gain=gain,
bias=bias,
power=power,
).astype(spec.dtype)
def scale_log(
spec: xr.DataArray,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
samplerate = spec.attrs["original_samplerate"]
nfft = spec.attrs["nfft"]
log_scaling = 2 / (samplerate * (np.abs(np.hanning(nfft)) ** 2).sum())
return xr.DataArray(
data=np.log1p(log_scaling * spec).astype(dtype),
dims=spec.dims,
coords=spec.coords,
attrs=spec.attrs,
)
def resize_spectrogram(
spec: xr.DataArray,
height: int = 128,
resize_factor: Optional[float] = 0.5,
) -> xr.DataArray:
resize_factor = resize_factor or 1
current_width = spec.sizes["time"]
return ops.resize(
spec,
time=int(resize_factor * current_width),
frequency=height,
dtype=np.float32,
)
def adjust_spectrogram_width(
spec: xr.DataArray,
divide_factor: int = 32,
time_period: float = 0.001,
) -> xr.DataArray:
time_width = spec.sizes["time"]
if time_width % divide_factor == 0:
return spec
target_size = int(
np.ceil(spec.sizes["time"] / divide_factor) * divide_factor
)
extra_duration = (target_size - time_width) * time_period
_, stop = arrays.get_dim_range(spec, dim="time")
resized = ops.extend_dim(
spec,
dim="time",
stop=stop + extra_duration,
)
return resized
def duration_to_spec_width(
duration: float,
samplerate: int,
window_duration: float,
window_overlap: float,
) -> int:
samples = int(duration * samplerate)
fft_len = int(window_duration * samplerate)
fft_overlap = int(window_overlap * fft_len)
hop_len = fft_len - fft_overlap
width = (samples - fft_len + hop_len) / hop_len
return int(np.floor(width))
def spec_width_to_samples(
width: int,
samplerate: int,
window_duration: float,
window_overlap: float,
) -> int:
fft_len = int(window_duration * samplerate)
fft_overlap = int(window_overlap * fft_len)
hop_len = fft_len - fft_overlap
return width * hop_len + fft_len - hop_len
def get_spectrogram_resolution(
config: SpectrogramConfig,
) -> tuple[float, float]:
max_freq = config.frequencies.max_freq
min_freq = config.frequencies.min_freq
assert config.size is not None
spec_height = config.size.height
resize_factor = config.size.resize_factor or 1
freq_bin_width = (max_freq - min_freq) / spec_height
hop_duration = config.stft.window_duration * (
1 - config.stft.window_overlap
)
return freq_bin_width, hop_duration / resize_factor
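To make the relationships above concrete, here is a worked sketch using the default configuration; the 256 kHz samplerate and 0.5 s clip duration are assumed for illustration. Note that the resulting 859.375 Hz per frequency bin is the same constant that appears as frequency_scale = 1 / 859.375 in the label heatmaps.

# Assumed inputs for illustration: 256 kHz samplerate, 0.5 s clip.
config = SpectrogramConfig()
freq_res, time_res = get_spectrogram_resolution(config)
# freq_res == (120_000 - 10_000) / 128 == 859.375 Hz per row
# time_res == 0.002 * (1 - 0.75) / 0.5 == 0.001 s per column after resizing
width = duration_to_spec_width(
    duration=0.5,
    samplerate=256_000,
    window_duration=config.stft.window_duration,
    window_overlap=config.stft.window_overlap,
)
# nfft == 512, hop == 128, so width == (128_000 - 512 + 128) // 128 == 997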


@ -0,0 +1,76 @@
from typing import Union
import numpy as np
import torch
from torch.nn import functional as F
def extend_width(
array: Union[np.ndarray, torch.Tensor],
extra: int,
axis: int = -1,
value: float = 0,
) -> torch.Tensor:
if not isinstance(array, torch.Tensor):
array = torch.Tensor(array)
dims = len(array.shape)
axis = axis % dims
pad = [
[0, 0] if index != axis else [0, extra]
for index in range(axis, dims)[::-1]
]
return F.pad(
array,
[x for y in pad for x in y],
value=value,
)
def make_width_divisible(
array: Union[np.ndarray, torch.Tensor],
factor: int,
axis: int = -1,
value: float = 0,
) -> torch.Tensor:
if not isinstance(array, torch.Tensor):
array = torch.Tensor(array)
width = array.shape[axis]
if width % factor == 0:
return array
extra = (-width) % factor
return extend_width(array, extra, axis=axis, value=value)
def adjust_width(
array: Union[np.ndarray, torch.Tensor],
width: int,
axis: int = -1,
value: float = 0,
) -> torch.Tensor:
if not isinstance(array, torch.Tensor):
array = torch.Tensor(array)
dims = len(array.shape)
axis = axis % dims
current_width = array.shape[axis]
if current_width == width:
return array
if current_width < width:
return extend_width(
array,
extra=width - current_width,
axis=axis,
value=value,
)
slices = [
slice(None, None) if index != axis else slice(None, width)
for index in range(dims)
]
return array[tuple(slices)]
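A brief sketch of how these helpers compose; the (1, 128, 997) spectrogram shape is a hypothetical example:

import torch

spec = torch.zeros(1, 128, 997)               # hypothetical spectrogram shape
spec = make_width_divisible(spec, factor=32)  # zero-pads 997 -> 1024
assert spec.shape == (1, 128, 1024)
spec = adjust_width(spec, width=512)          # crops back down to 512
assert spec.shape == (1, 128, 512)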

batdetect2/terms.py

@ -0,0 +1,88 @@
from inspect import getmembers
from typing import Optional
from pydantic import BaseModel
from soundevent import data, terms
__all__ = [
"call_type",
"individual",
"get_term_from_info",
"get_tag_from_info",
"TermInfo",
"TagInfo",
]
class TermInfo(BaseModel):
label: Optional[str]
name: Optional[str]
uri: Optional[str]
class TagInfo(BaseModel):
value: str
term: Optional[TermInfo] = None
key: Optional[str] = None
label: Optional[str] = None
call_type = data.Term(
name="soundevent:call_type",
label="Call Type",
definition="A broad categorization of animal vocalizations based on their intended function or purpose (e.g., social, distress, mating, territorial, echolocation).",
)
individual = data.Term(
name="soundevent:individual",
label="Individual",
definition="An id for an individual animal. In the context of bioacoustic annotation, this term is used to label vocalizations that are attributed to a specific individual.",
)
ALL_TERMS = [
*getmembers(terms, lambda x: isinstance(x, data.Term)),
call_type,
individual,
]
def get_term_from_info(term_info: TermInfo) -> data.Term:
for term in ALL_TERMS:
if term_info.name and term_info.name == term.name:
return term
if term_info.label and term_info.label == term.label:
return term
if term_info.uri and term_info.uri == term.uri:
return term
if term_info.name is None:
if term_info.label is None:
raise ValueError("At least one of name or label must be provided.")
term_info.name = (
f"soundevent:{term_info.label.lower().replace(' ', '_')}"
)
if term_info.label is None:
term_info.label = term_info.name
return data.Term(
name=term_info.name,
label=term_info.label,
uri=term_info.uri,
definition="Unknown",
)
def get_tag_from_info(tag_info: TagInfo) -> data.Tag:
if tag_info.term:
term = get_term_from_info(tag_info.term)
elif tag_info.key:
term = data.term_from_key(tag_info.key)
else:
raise ValueError("Either term or key must be provided in tag info.")
return data.Tag(term=term, value=tag_info.value)


@ -0,0 +1,48 @@
from batdetect2.train.augmentations import (
AugmentationsConfig,
add_echo,
augment_example,
load_agumentation_config,
mask_frequency,
mask_time,
mix_examples,
scale_volume,
select_subclip,
warp_spectrogram,
)
from batdetect2.train.config import TrainingConfig, load_train_config
from batdetect2.train.dataset import (
LabeledDataset,
SubclipConfig,
TrainExample,
)
from batdetect2.train.labels import LabelConfig, load_label_config
from batdetect2.train.preprocess import preprocess_annotations
from batdetect2.train.targets import TargetConfig, load_target_config
from batdetect2.train.train import TrainerConfig, load_trainer_config, train
__all__ = [
"AugmentationsConfig",
"LabelConfig",
"LabeledDataset",
"SubclipConfig",
"TargetConfig",
"TrainExample",
"TrainerConfig",
"TrainingConfig",
"add_echo",
"augment_example",
"load_agumentation_config",
"load_label_config",
"load_target_config",
"load_train_config",
"load_trainer_config",
"mask_frequency",
"mask_time",
"mix_examples",
"preprocess_annotations",
"scale_volume",
"select_subclip",
"train",
"warp_spectrogram",
]


@ -1,941 +0,0 @@
"""Functions and dataloaders for training and testing the model."""
import copy
from typing import List, Optional, Tuple
import librosa
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio
import batdetect2.utils.audio_utils as au
from batdetect2.types import (
Annotation,
AudioLoaderAnnotationGroup,
AudioLoaderParameters,
FileAnnotation,
)
def generate_gt_heatmaps(
spec_op_shape: Tuple[int, int],
sampling_rate: float,
ann: AudioLoaderAnnotationGroup,
class_names: List[str],
fft_win_length: float,
fft_overlap: float,
max_freq: float,
min_freq: float,
resize_factor: float,
target_sigma: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AudioLoaderAnnotationGroup]:
"""Generate ground truth heatmaps from annotations.
Parameters
----------
spec_op_shape : Tuple[int, int]
Shape of the input spectrogram.
sampling_rate : int
Sampling rate of the input audio in Hz.
ann : AnnotationGroup
Dictionary containing the annotation information.
params : HeatmapParameters
Parameters controlling the generation of the heatmaps.
Returns
-------
y_2d_det : np.ndarray
2D heatmap of the presence of an event.
y_2d_size : np.ndarray
2D heatmap of the size of the bounding box associated to event.
y_2d_classes : np.ndarray
3D array containing the ground-truth class probabilities for each
pixel.
ann_aug : AnnotationGroup
A dictionary containing the annotation information of the
annotations that are within the input spectrogram, augmented with
the x and y indices of their pixel location in the input spectrogram.
"""
# spec may be resized on input into the network
num_classes = len(class_names)
op_height = spec_op_shape[0]
op_width = spec_op_shape[1]
freq_per_bin = (max_freq - min_freq) / op_height
# start and end times
x_pos_start = au.time_to_x_coords(
ann["start_times"],
sampling_rate,
fft_win_length,
fft_overlap,
)
x_pos_start = (resize_factor * x_pos_start).astype(np.int32)
x_pos_end = au.time_to_x_coords(
ann["end_times"],
sampling_rate,
fft_win_length,
fft_overlap,
)
x_pos_end = (resize_factor * x_pos_end).astype(np.int32)
# location on y axis i.e. frequency
y_pos_low = (ann["low_freqs"] - min_freq) / freq_per_bin
y_pos_low = (op_height - y_pos_low).astype(np.int32)
y_pos_high = (ann["high_freqs"] - min_freq) / freq_per_bin
y_pos_high = (op_height - y_pos_high).astype(np.int32)
bb_widths = x_pos_end - x_pos_start
bb_heights = y_pos_low - y_pos_high
# Only include annotations that are within the input spectrogram
valid_inds = np.where(
(x_pos_start >= 0)
& (x_pos_start < op_width)
& (y_pos_low >= 0)
& (y_pos_low < (op_height - 1))
)[0]
ann_aug: AudioLoaderAnnotationGroup = {
**ann,
"start_times": ann["start_times"][valid_inds],
"end_times": ann["end_times"][valid_inds],
"high_freqs": ann["high_freqs"][valid_inds],
"low_freqs": ann["low_freqs"][valid_inds],
"class_ids": ann["class_ids"][valid_inds],
"individual_ids": ann["individual_ids"][valid_inds],
"x_inds": x_pos_start[valid_inds],
"y_inds": y_pos_low[valid_inds],
}
# if the number of calls is only 1, then it is unique
# TODO would be better if we found these unique calls at the merging stage
if len(ann_aug["individual_ids"]) == 1:
ann_aug["individual_ids"][0] = 0
y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)
# num classes and "background" class
y_2d_classes: np.ndarray = np.zeros(
(num_classes + 1, op_height, op_width), dtype=np.float32
)
# create 2D ground truth heatmaps
for ii in valid_inds:
draw_gaussian(
y_2d_det[0, :],
(x_pos_start[ii], y_pos_low[ii]),
target_sigma,
)
y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]
cls_id = ann["class_ids"][ii]
if cls_id > -1:
draw_gaussian(
y_2d_classes[cls_id, :],
(x_pos_start[ii], y_pos_low[ii]),
target_sigma,
)
# be careful as this will have a 1.0 places where we have event but
# dont know gt class this will be masked in training anyway
y_2d_classes[num_classes, :] = 1.0 - y_2d_classes.sum(0)
y_2d_classes = y_2d_classes / y_2d_classes.sum(0)[np.newaxis, ...]
y_2d_classes[np.isnan(y_2d_classes)] = 0.0
return y_2d_det, y_2d_size, y_2d_classes, ann_aug
def draw_gaussian(
heatmap: np.ndarray,
center: Tuple[int, int],
sigmax: float,
sigmay: Optional[float] = None,
) -> bool:
"""Draw a 2D gaussian into the heatmap.
If the gaussian center is outside the heatmap, then the gaussian is not
drawn.
Parameters
----------
heatmap : np.ndarray
The heatmap to draw into. Should be of shape (height, width).
center : Tuple[int, int]
The center of the gaussian in (x, y) format.
sigmax : float
The standard deviation of the gaussian in the x direction.
sigmay : Optional[float], optional
The standard deviation of the gaussian in the y direction. If None,
then sigmay = sigmax, by default None.
Returns
-------
bool
True if the gaussian was drawn, False if it was not (because
the center was outside the heatmap).
"""
# center is (x, y)
# this edits the heatmap inplace
if sigmay is None:
sigmay = sigmax
tmp_size = np.maximum(sigmax, sigmay) * 3
mu_x = int(center[0] + 0.5)
mu_y = int(center[1] + 0.5)
w, h = heatmap.shape[0], heatmap.shape[1]
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0:
return False
size = 2 * tmp_size + 1
x = np.arange(0, size, 1, np.float32)
y = x[:, np.newaxis]
x0 = y0 = size // 2
# g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
g = np.exp(
-((x - x0) ** 2) / (2 * sigmax**2) - ((y - y0) ** 2) / (2 * sigmay**2)
)
g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
img_x = max(0, ul[0]), min(br[0], h)
img_y = max(0, ul[1]), min(br[1], w)
heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]] = np.maximum(
heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]],
g[g_y[0] : g_y[1], g_x[0] : g_x[1]],
)
return True
def pad_aray(ip_array: np.ndarray, pad_size: int) -> np.ndarray:
"""Pad array with -1s."""
return np.hstack((ip_array, np.ones(pad_size, dtype=np.int32) * -1))
def warp_spec_aug(
spec: torch.Tensor,
ann: AudioLoaderAnnotationGroup,
stretch_squeeze_delta: float,
) -> torch.Tensor:
"""Warp spectrogram by randomly stretching and squeezing.
Parameters
----------
spec: torch.Tensor
Spectrogram to warp.
ann: AnnotationGroup
Annotation group for the spectrogram. Must be provided to sync
the start and stop times with the spectrogram after warping.
stretch_squeeze_delta: float
Maximum amount to stretch or squeeze the spectrogram.
Returns
-------
torch.Tensor
Warped spectrogram.
Notes
-----
This function modifies the annotation group in place.
"""
# Augment spectrogram by randomly stretch and squeezing
# NOTE this also changes the start and stop time in place
delta = stretch_squeeze_delta
op_size = (spec.shape[1], spec.shape[2])
resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0
resize_amt = int(spec.shape[2] * resize_fract_r)
if resize_amt >= spec.shape[2]:
spec_r = torch.cat(
(
spec,
torch.zeros(
(1, spec.shape[1], resize_amt - spec.shape[2]),
dtype=spec.dtype,
),
),
dim=2,
)
else:
spec_r = spec[:, :, :resize_amt]
# Resize the spectrogram
spec = F.interpolate(
spec_r.unsqueeze(0),
size=op_size,
mode="bilinear",
align_corners=False,
).squeeze(0)
# Update the start and stop times
ann["start_times"] *= 1.0 / resize_fract_r
ann["end_times"] *= 1.0 / resize_fract_r
return spec
def mask_time_aug(
spec: torch.Tensor,
mask_max_time_perc: float,
) -> torch.Tensor:
"""Mask out random blocks of time.
Will randomly mask out a block of time in the spectrogram. The block
will be between 0.0 and `mask_max_time_perc` of the total time.
A random number of blocks will be masked out between 1 and 3.
Parameters
----------
spec: torch.Tensor
Spectrogram to mask.
mask_max_time_perc: float
Maximum percentage of time to mask out.
Returns
-------
torch.Tensor
Spectrogram with masked out time blocks.
Notes
-----
This function is based on the implementation in::
SpecAugment: A Simple Data Augmentation Method for Automatic Speech
Recognition
"""
fm = torchaudio.transforms.TimeMasking(
int(spec.shape[1] * mask_max_time_perc)
)
for _ in range(np.random.randint(1, 4)):
spec = fm(spec)
return spec
def mask_freq_aug(
spec: torch.Tensor,
mask_max_freq_perc: float,
) -> torch.Tensor:
"""Mask out random blocks of frequency.
Will randomly mask out a block of frequency in the spectrogram. The block
will be between 0.0 and `mask_max_freq_perc` of the total frequency.
A random number of blocks will be masked out between 1 and 3.
Parameters
----------
spec: torch.Tensor
Spectrogram to mask.
mask_max_freq_perc: float
Maximum percentage of frequency to mask out.
Returns
-------
torch.Tensor
Spectrogram with masked out frequency blocks.
Notes
-----
This function is based on the implementation in::
SpecAugment: A Simple Data Augmentation Method for Automatic Speech
Recognition
"""
fm = torchaudio.transforms.FrequencyMasking(
int(spec.shape[1] * mask_max_freq_perc)
)
for _ in range(np.random.randint(1, 4)):
spec = fm(spec)
return spec
def scale_vol_aug(
spec: torch.Tensor,
spec_amp_scaling: float,
) -> torch.Tensor:
"""Scale the volume of the spectrogram.
Parameters
----------
spec: torch.Tensor
Spectrogram to scale.
spec_amp_scaling: float
Maximum scaling factor.
Returns
-------
torch.Tensor
"""
return spec * np.random.random() * spec_amp_scaling
def echo_aug(
audio: np.ndarray,
sampling_rate: float,
echo_max_delay: float,
) -> np.ndarray:
"""Add echo to audio.
Parameters
----------
audio: np.ndarray
Audio to add echo to.
sampling_rate: float
Sampling rate of the audio.
echo_max_delay: float
Maximum delay of the echo in seconds.
Returns
-------
np.ndarray
Audio with echo added.
"""
sample_offset = (
int(echo_max_delay * np.random.random() * sampling_rate) + 1
)
# NOTE: This seems to be wrong, as the echo should be added to the
# end of the audio, not the beginning.
audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
return audio
def resample_aug(
audio: np.ndarray,
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
resize_factor: float,
spec_divide_factor: float,
spec_train_width: int,
aug_sampling_rates: List[int],
) -> Tuple[np.ndarray, float, float]:
"""Resample audio augmentation.
Will resample the audio to a random sampling rate from the list of
sampling rates in `aug_sampling_rates`.
Parameters
----------
audio: np.ndarray
Audio to resample.
sampling_rate: float
Original sampling rate of the audio.
fft_win_length: float
Length of the FFT window in seconds.
fft_overlap: float
Amount of overlap between FFT windows.
resize_factor: float
Factor to resize the spectrogram by.
spec_divide_factor: float
Factor to divide the spectrogram by.
spec_train_width: int
Width of the spectrogram.
aug_sampling_rates: List[int]
List of sampling rates to resample to.
Returns
-------
audio : np.ndarray
Resampled audio.
sampling_rate : float
New sampling rate.
duration : float
Duration of the audio in seconds.
"""
sampling_rate_old = sampling_rate
sampling_rate = np.random.choice(aug_sampling_rates)
audio = librosa.resample(
audio,
orig_sr=sampling_rate_old,
target_sr=sampling_rate,
res_type="polyphase",
)
audio = au.pad_audio(
audio,
sampling_rate,
fft_win_length,
fft_overlap,
resize_factor,
spec_divide_factor,
spec_train_width,
)
duration = audio.shape[0] / float(sampling_rate)
return audio, sampling_rate, duration
def resample_audio(
num_samples: int,
sampling_rate: float,
audio2: np.ndarray,
sampling_rate2: float,
) -> Tuple[np.ndarray, float]:
"""Resample audio.
Parameters
----------
num_samples: int
Expected number of samples for the output audio.
sampling_rate: float
Original sampling rate of the audio.
audio2: np.ndarray
Audio to resample.
sampling_rate2: float
Target sampling rate of the audio.
Returns
-------
audio2 : np.ndarray
Resampled audio.
sampling_rate2 : float
New sampling rate.
"""
# resample to target sampling rate
if sampling_rate != sampling_rate2:
audio2 = librosa.resample(
audio2,
orig_sr=sampling_rate2,
target_sr=sampling_rate,
res_type="polyphase",
)
sampling_rate2 = sampling_rate
# pad or trim to the correct length
if audio2.shape[0] < num_samples:
audio2 = np.hstack(
(
audio2,
np.zeros((num_samples - audio2.shape[0]), dtype=audio2.dtype),
)
)
elif audio2.shape[0] > num_samples:
audio2 = audio2[:num_samples]
return audio2, sampling_rate2
def combine_audio_aug(
audio: np.ndarray,
sampling_rate: float,
ann: AudioLoaderAnnotationGroup,
audio2: np.ndarray,
sampling_rate2: float,
ann2: AudioLoaderAnnotationGroup,
) -> Tuple[np.ndarray, AudioLoaderAnnotationGroup]:
"""Combine two audio files.
Will combine two audio files by resampling them to the same sampling rate
and then combining them with a random weight. The annotations will be
combined by taking the union of the two sets of annotations.
Parameters
----------
audio: np.ndarray
First Audio to combine.
sampling_rate: int
Sampling rate of the first audio.
ann: AnnotationGroup
Annotations for the first audio.
audio2: np.ndarray
Second Audio to combine.
sampling_rate2: int
Sampling rate of the second audio.
ann2: AnnotationGroup
Annotations for the second audio.
Returns
-------
audio : np.ndarray
Combined audio.
ann : AnnotationGroup
Combined annotations.
"""
# resample so they are the same
audio2, sampling_rate2 = resample_audio(
audio.shape[0],
sampling_rate,
audio2,
sampling_rate2,
)
# # set mean and std to be the same
# audio2 = (audio2 - audio2.mean())
# audio2 = (audio2/audio2.std())*audio.std()
# audio2 = audio2 + audio.mean()
if (
ann.get("annotated", False)
and (ann2.get("annotated", False))
and (sampling_rate2 == sampling_rate)
and (audio.shape[0] == audio2.shape[0])
):
comb_weight = 0.3 + np.random.random() * 0.4
audio = comb_weight * audio + (1 - comb_weight) * audio2
inds = np.argsort(np.hstack((ann["start_times"], ann2["start_times"])))
for kk in ann.keys():
# when combining calls from different files, assume they come
# from different individuals
if kk == "individual_ids":
if (ann[kk] > -1).sum() > 0:
ann2[kk][ann2[kk] > -1] += (
np.max(ann[kk][ann[kk] > -1]) + 1
)
if (kk != "class_id_file") and (kk != "annotated"):
ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds]
return audio, ann
def _prepare_annotation(
annotation: Annotation,
class_names: List[str],
) -> Annotation:
try:
class_id = class_names.index(annotation["class"])
except ValueError:
class_id = -1
ann: Annotation = {
**annotation,
"class_id": class_id,
}
if "individual" in ann:
ann["individual"] = int(ann["individual"]) # type: ignore
return ann
def _prepare_file_annotation(
annotation: FileAnnotation,
class_names: List[str],
classes_to_ignore: List[str],
) -> AudioLoaderAnnotationGroup:
annotations = [
_prepare_annotation(ann, class_names)
for ann in annotation["annotation"]
if ann["class"] not in classes_to_ignore
]
try:
class_id_file = class_names.index(annotation["class_name"])
except ValueError:
class_id_file = -1
ret: AudioLoaderAnnotationGroup = {
"id": annotation["id"],
"annotated": annotation["annotated"],
"duration": annotation["duration"],
"issues": annotation["issues"],
"time_exp": annotation["time_exp"],
"class_name": annotation["class_name"],
"notes": annotation["notes"],
"annotation": annotations,
"start_times": np.array([ann["start_time"] for ann in annotations]),
"end_times": np.array([ann["end_time"] for ann in annotations]),
"high_freqs": np.array([ann["high_freq"] for ann in annotations]),
"low_freqs": np.array([ann["low_freq"] for ann in annotations]),
"class_ids": np.array(
[ann.get("class_id", -1) for ann in annotations]
),
"individual_ids": np.array([ann["individual"] for ann in annotations]),
"class_id_file": class_id_file,
}
return ret
class AudioLoader(torch.utils.data.Dataset):
"""Main AudioLoader for training and testing."""
def __init__(
self,
data_anns_ip: List[FileAnnotation],
params: AudioLoaderParameters,
dataset_name: Optional[str] = None,
is_train: bool = False,
return_spec_for_viz: bool = False,
):
self.is_train = is_train
self.params = params
self.return_spec_for_viz = return_spec_for_viz
self.data_anns: List[AudioLoaderAnnotationGroup] = [
_prepare_file_annotation(
ann,
params["class_names"],
params["classes_to_ignore"],
)
for ann in data_anns_ip
]
ann_cnt = [len(aa["annotation"]) for aa in self.data_anns]
self.max_num_anns = 2 * np.max(
ann_cnt
) # x2 because we may be combining files during training
print("\n")
if dataset_name is not None:
print("Dataset : " + dataset_name)
if self.is_train:
print("Split type : train")
else:
print("Split type : test")
print("Num files : " + str(len(self.data_anns)))
print("Num calls : " + str(np.sum(ann_cnt)))
def get_file_and_anns(
self,
index: Optional[int] = None,
) -> Tuple[np.ndarray, float, float, AudioLoaderAnnotationGroup]:
"""Get an audio file and its annotations.
Parameters
----------
index : int, optional
Index of the file to be loaded. If None, a random file is chosen.
Returns
-------
audio_raw : np.ndarray
Loaded audio file.
sampling_rate : float
Sampling rate of the audio file.
duration : float
Duration of the audio file in seconds.
ann : AnnotationGroup
AnnotationGroup object containing the annotations for the audio file.
"""
# if no file specified, choose random one
if index is None:
index = np.random.randint(0, len(self.data_anns))
audio_file = self.data_anns[index]["file_path"]
sampling_rate, audio_raw = au.load_audio(
audio_file,
self.data_anns[index]["time_exp"],
self.params["target_samp_rate"],
self.params["scale_raw_audio"],
)
# copy annotation
ann = copy.deepcopy(self.data_anns[index])
# ann["annotated"] = self.data_anns[index]["annotated"]
# ann["class_id_file"] = self.data_anns[index]["class_id_file"]
# keys = [
# "start_times",
# "end_times",
# "high_freqs",
# "low_freqs",
# "class_ids",
# "individual_ids",
# ]
# for kk in keys:
# ann[kk] = self.data_anns[index][kk].copy()
# if train then grab a random crop
if self.is_train:
nfft = int(self.params["fft_win_length"] * sampling_rate)
noverlap = int(self.params["fft_overlap"] * nfft)
length_samples = (
self.params["spec_train_width"] * (nfft - noverlap) + noverlap
)
if audio_raw.shape[0] - length_samples > 0:
sample_crop = np.random.randint(
audio_raw.shape[0] - length_samples
)
else:
sample_crop = 0
audio_raw = audio_raw[sample_crop : sample_crop + length_samples]
ann["start_times"] = ann["start_times"] - sample_crop / float(
sampling_rate
)
ann["end_times"] = ann["end_times"] - sample_crop / float(
sampling_rate
)
# pad audio
if self.is_train:
op_spec_target_size = self.params["spec_train_width"]
else:
op_spec_target_size = None
audio_raw = au.pad_audio(
audio_raw,
sampling_rate,
self.params["fft_win_length"],
self.params["fft_overlap"],
self.params["resize_factor"],
self.params["spec_divide_factor"],
op_spec_target_size,
)
duration = audio_raw.shape[0] / float(sampling_rate)
# sort based on time
inds = np.argsort(ann["start_times"])
for kk in ann.keys():
if (kk != "class_id_file") and (kk != "annotated"):
ann[kk] = ann[kk][inds]
return audio_raw, sampling_rate, duration, ann
def __getitem__(self, index):
"""Get an item from the dataset."""
# load audio file
audio, sampling_rate, duration, ann = self.get_file_and_anns(index)
# augment on raw audio
if self.is_train and self.params["augment_at_train"]:
# augment - combine with random audio file
if (
self.params["augment_at_train_combine"]
and np.random.random() < self.params["aug_prob"]
):
(
audio2,
sampling_rate2,
_,
ann2,
) = self.get_file_and_anns()
audio, ann = combine_audio_aug(
audio, sampling_rate, ann, audio2, sampling_rate2, ann2
)
# simulate echo by adding delayed copy of the file
if np.random.random() < self.params["aug_prob"]:
audio = echo_aug(
audio,
sampling_rate,
echo_max_delay=self.params["echo_max_delay"],
)
# resample the audio
# if np.random.random() < self.params["aug_prob"]:
# audio, sampling_rate, duration = resample_aug(
# audio, sampling_rate, self.params
# )
# create spectrogram
spec, _ = au.generate_spectrogram(
audio,
sampling_rate,
params=dict(
fft_win_length=self.params["fft_win_length"],
fft_overlap=self.params["fft_overlap"],
max_freq=self.params["max_freq"],
min_freq=self.params["min_freq"],
spec_scale=self.params["spec_scale"],
denoise_spec_avg=self.params["denoise_spec_avg"],
max_scale_spec=self.params["max_scale_spec"],
),
)
rsf = self.params["resize_factor"]
spec_op_shape = (
int(self.params["spec_height"] * rsf),
int(spec.shape[1] * rsf),
)
# resize the spec
spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0)
spec = F.interpolate(
spec,
size=spec_op_shape,
mode="bilinear",
align_corners=False,
).squeeze(0)
# augment spectrogram
if self.is_train and self.params["augment_at_train"]:
if np.random.random() < self.params["aug_prob"]:
spec = scale_vol_aug(
spec,
spec_amp_scaling=self.params["spec_amp_scaling"],
)
if np.random.random() < self.params["aug_prob"]:
spec = warp_spec_aug(
spec,
ann,
stretch_squeeze_delta=self.params["stretch_squeeze_delta"],
)
if np.random.random() < self.params["aug_prob"]:
spec = mask_time_aug(
spec,
mask_max_time_perc=self.params["mask_max_time_perc"],
)
if np.random.random() < self.params["aug_prob"]:
spec = mask_freq_aug(
spec,
mask_max_freq_perc=self.params["mask_max_freq_perc"],
)
outputs = {}
outputs["spec"] = spec
if self.return_spec_for_viz:
outputs["spec_for_viz"] = torch.from_numpy(spec_for_viz).unsqueeze(
0
)
# create ground truth heatmaps
(
outputs["y_2d_det"],
outputs["y_2d_size"],
outputs["y_2d_classes"],
ann_aug,
) = generate_gt_heatmaps(
spec_op_shape,
sampling_rate,
ann,
class_names=self.params["class_names"],
fft_win_length=self.params["fft_win_length"],
fft_overlap=self.params["fft_overlap"],
max_freq=self.params["max_freq"],
min_freq=self.params["min_freq"],
resize_factor=self.params["resize_factor"],
target_sigma=self.params["target_sigma"],
)
# hack to get around requirement that all vectors are the same length
# in the output batch
pad_size = self.max_num_anns - len(ann_aug["individual_ids"])
outputs["is_valid"] = pad_aray(
np.ones(len(ann_aug["individual_ids"])), pad_size
)
keys = [
"class_ids",
"individual_ids",
"x_inds",
"y_inds",
"start_times",
"end_times",
"low_freqs",
"high_freqs",
]
for kk in keys:
outputs[kk] = pad_aray(ann_aug[kk], pad_size)
# convert to pytorch
for kk in outputs.keys():
if type(outputs[kk]) != torch.Tensor:
outputs[kk] = torch.from_numpy(outputs[kk])
# scalars
outputs["class_id_file"] = ann["class_id_file"]
outputs["annotated"] = ann["annotated"]
outputs["duration"] = duration
outputs["sampling_rate"] = sampling_rate
outputs["file_id"] = index
return outputs
def __len__(self):
"""Denotes the total number of samples."""
return len(self.data_anns)


@ -1,136 +1,214 @@
from typing import Callable, Optional, Union
import numpy as np
import xarray as xr
from pydantic import Field
from soundevent import arrays, data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram
from batdetect2.preprocess.arrays import adjust_width
Augmentation = Callable[[xr.Dataset], xr.Dataset]
__all__ = [
    "AugmentationsConfig",
    "load_agumentation_config",
    "select_subclip",
    "mix_examples",
    "add_echo",
    "scale_volume",
    "warp_spectrogram",
    "mask_time",
    "mask_frequency",
    "augment_example",
]
class BaseAugmentationConfig(BaseConfig):
    enable: bool = True
    probability: float = 0.2
def select_subclip(
    example: xr.Dataset,
    start_time: Optional[float] = None,
    duration: Optional[float] = None,
    width: Optional[int] = None,
    random: bool = False,
) -> xr.Dataset:
    """Select a random subclip from a clip."""
    step = arrays.get_dim_step(example, "time")  # type: ignore
    start, stop = arrays.get_dim_range(example, "time")  # type: ignore
    if width is None:
        if duration is None:
            raise ValueError("Either duration or width must be provided")
        width = int(np.floor(duration / step))
    if duration is None:
        duration = width * step
    if start_time is None:
        if random:
            start_time = np.random.uniform(start, max(stop - duration, start))
        else:
            start_time = start
    if start_time + duration > stop:
        return example
    start_index = arrays.get_coord_index(
        example,  # type: ignore
        "time",
        start_time,
    )
    end_index = start_index + width - 1
    start_time = example.time.values[start_index]
    end_time = example.time.values[end_index]
    return example.sel(
        time=slice(start_time, end_time),
        audio_time=slice(start_time, end_time + step),
    )
class MixAugmentationConfig(BaseAugmentationConfig):
    min_weight: float = 0.3
    max_weight: float = 0.7
def mix_examples(
    example: xr.Dataset,
    other: xr.Dataset,
    weight: Optional[float] = None,
    min_weight: float = 0.3,
    max_weight: float = 0.7,
    config: Optional[PreprocessingConfig] = None,
) -> xr.Dataset:
    """Combine two audio clips."""
    config = config or PreprocessingConfig()
    if weight is None:
        weight = np.random.uniform(min_weight, max_weight)
    audio1 = example["audio"]
    audio2 = adjust_width(other["audio"].values, len(audio1))
    combined = weight * audio1 + (1 - weight) * audio2
    spectrogram = compute_spectrogram(
        combined.rename({"audio_time": "time"}),
        config=config.spectrogram,
    ).data
    # NOTE: The subclip's spectrogram might be slightly longer than the
    # spectrogram computed from the subclip's audio. This is due to a
    # simplification in the subclip process: It doesn't account for the
    # spectrogram parameters to precisely determine the corresponding audio
    # samples. To work around this, we pad the computed spectrogram with zeros
    # as needed.
    previous_width = len(example["time"])
    spectrogram = adjust_width(spectrogram, previous_width)
    detection_heatmap = xr.apply_ufunc(
        np.maximum,
        example["detection"],
        adjust_width(other["detection"].values, previous_width),
    )
    class_heatmap = xr.apply_ufunc(
        np.maximum,
        example["class"],
        adjust_width(other["class"].values, previous_width),
    )
    size_heatmap = example["size"] + adjust_width(
        other["size"].values, previous_width
    )
    return xr.Dataset(
        {
            "audio": combined,
            "spectrogram": xr.DataArray(
                data=spectrogram,
                dims=example["spectrogram"].dims,
                coords=example["spectrogram"].coords,
            ),
            "detection": detection_heatmap,
            "class": class_heatmap,
            "size": size_heatmap,
        },
        attrs=example.attrs,
    )
class EchoAugmentationConfig(BaseAugmentationConfig):
    max_delay: float = 0.005
    min_weight: float = 0.0
    max_weight: float = 1.0
def add_echo(
    example: xr.Dataset,
    delay: Optional[float] = None,
    weight: Optional[float] = None,
    min_weight: float = 0.1,
    max_weight: float = 1.0,
    max_delay: float = 0.005,
    config: Optional[PreprocessingConfig] = None,
) -> xr.Dataset:
    """Add a delay to the audio."""
    config = config or PreprocessingConfig()
    if delay is None:
        delay = np.random.uniform(0, max_delay)
    if weight is None:
        weight = np.random.uniform(min_weight, max_weight)
    audio = example["audio"]
    step = arrays.get_dim_step(audio, "audio_time")
    audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0)
    audio = audio + weight * audio_delay
    spectrogram = compute_spectrogram(
        audio.rename({"audio_time": "time"}),
        config=config.spectrogram,
    ).data
    # NOTE: The subclip's spectrogram might be slightly longer than the
    # spectrogram computed from the subclip's audio. This is due to a
    # simplification in the subclip process: It doesn't account for the
    # spectrogram parameters to precisely determine the corresponding audio
    # samples. To work around this, we pad the computed spectrogram with zeros
    # as needed.
    spectrogram = adjust_width(
        spectrogram,
        example["spectrogram"].sizes["time"],
    )
    return example.assign(
        audio=audio,
        spectrogram=xr.DataArray(
            data=spectrogram,
            dims=example["spectrogram"].dims,
            coords=example["spectrogram"].coords,
        ),
    )
class VolumeAugmentationConfig(BaseAugmentationConfig):
    min_scaling: float = 0.0
    max_scaling: float = 2.0
def scale_volume(
    example: xr.Dataset,
    factor: Optional[float] = None,
    max_scaling: float = 2,
    min_scaling: float = 0,
@ -139,106 +217,227 @@ def scale_volume(
    if factor is None:
        factor = np.random.uniform(min_scaling, max_scaling)
    return example.assign(spectrogram=example["spectrogram"] * factor)
class WarpAugmentationConfig(BaseAugmentationConfig):
    delta: float = 0.04
def warp_spectrogram(
    example: xr.Dataset,
    factor: Optional[float] = None,
    delta: float = 0.04,
) -> xr.Dataset:
    """Warp a spectrogram."""
    if factor is None:
        factor = np.random.uniform(1 - delta, 1 + delta)
    start_time, end_time = arrays.get_dim_range(example, "time")  # type: ignore
    duration = end_time - start_time
    new_time = np.linspace(
        start_time,
        start_time + duration * factor,
        example.time.size,
    )
    spectrogram = (
        example["spectrogram"]
        .interp(
            coords={"time": new_time},
            method="linear",
            kwargs=dict(
                fill_value=0,
            ),
        )
        .clip(min=0)
    )
    detection = example["detection"].interp(
        time=new_time,
        method="nearest",
        kwargs=dict(
            fill_value=0,
        ),
    )
    classification = example["class"].interp(
        time=new_time,
        method="nearest",
        kwargs=dict(
            fill_value=0,
        ),
    )
    size = example["size"].interp(
        time=new_time,
        method="nearest",
        kwargs=dict(
            fill_value=0,
        ),
    )
    return example.assign(
        {
            "time": new_time,
            "spectrogram": spectrogram,
            "detection": detection,
            "class": classification,
            "size": size,
        }
    )
def mask_axis(
    array: xr.DataArray,
    dim: str,
    start: float,
    end: float,
    mask_value: Union[float, Callable[[xr.DataArray], float]] = np.mean,
) -> xr.DataArray:
    if dim not in array.dims:
        raise ValueError(f"Axis {dim} not found in array")
    coord = array.coords[dim]
    condition = (coord < start) | (coord > end)
    if callable(mask_value):
        mask_value = mask_value(array)
    return array.where(condition, other=mask_value)
class TimeMaskAugmentationConfig(BaseAugmentationConfig):
    max_perc: float = 0.05
    max_masks: int = 3
def mask_time(
    example: xr.Dataset,
    max_perc: float = 0.05,
    max_mask: int = 3,
) -> xr.Dataset:
    """Mask a random section of the time axis."""
    num_masks = np.random.randint(1, max_mask + 1)
    start_time, end_time = arrays.get_dim_range(example, "time")  # type: ignore
    spectrogram = example["spectrogram"]
    for _ in range(num_masks):
        mask_size = np.random.uniform(0, max_perc) * (end_time - start_time)
        start = np.random.uniform(start_time, end_time - mask_size)
        end = start + mask_size
        spectrogram = mask_axis(spectrogram, "time", start, end)
    return example.assign(spectrogram=spectrogram)
class FrequencyMaskAugmentationConfig(BaseAugmentationConfig):
    max_perc: float = 0.10
    max_masks: int = 3
def mask_frequency(
    example: xr.Dataset,
    max_perc: float = 0.10,
    max_masks: int = 3,
) -> xr.Dataset:
    """Mask a random section of the frequency axis."""
    num_masks = np.random.randint(1, max_masks + 1)
    min_freq, max_freq = arrays.get_dim_range(example, "frequency")  # type: ignore
    spectrogram = example["spectrogram"]
    for _ in range(num_masks):
        mask_size = np.random.uniform(0, max_perc) * (max_freq - min_freq)
        start = np.random.uniform(min_freq, max_freq - mask_size)
        end = start + mask_size
        spectrogram = mask_axis(spectrogram, "frequency", start, end)
    return example.assign(spectrogram=spectrogram)
class AugmentationsConfig(BaseConfig):
    mix: MixAugmentationConfig = Field(default_factory=MixAugmentationConfig)
    echo: EchoAugmentationConfig = Field(
        default_factory=EchoAugmentationConfig
    )
    volume: VolumeAugmentationConfig = Field(
        default_factory=VolumeAugmentationConfig
    )
    warp: WarpAugmentationConfig = Field(
        default_factory=WarpAugmentationConfig
    )
    time_mask: TimeMaskAugmentationConfig = Field(
        default_factory=TimeMaskAugmentationConfig
    )
    frequency_mask: FrequencyMaskAugmentationConfig = Field(
        default_factory=FrequencyMaskAugmentationConfig
    )
def load_agumentation_config(
    path: data.PathLike, field: Optional[str] = None
) -> AugmentationsConfig:
    return load_config(path, schema=AugmentationsConfig, field=field)
def should_apply(config: BaseAugmentationConfig) -> bool:
    if not config.enable:
        return False
    return np.random.uniform() < config.probability
def augment_example(
    example: xr.Dataset,
    config: AugmentationsConfig,
    preprocessing_config: Optional[PreprocessingConfig] = None,
    others: Optional[Callable[[], xr.Dataset]] = None,
) -> xr.Dataset:
    if should_apply(config.mix) and (others is not None):
        other = others()
        example = mix_examples(
            example,
            other,
            min_weight=config.mix.min_weight,
            max_weight=config.mix.max_weight,
            config=preprocessing_config,
        )
    if should_apply(config.echo):
        example = add_echo(
            example,
            max_delay=config.echo.max_delay,
            min_weight=config.echo.min_weight,
            max_weight=config.echo.max_weight,
            config=preprocessing_config,
        )
    if should_apply(config.volume):
        example = scale_volume(
            example,
            max_scaling=config.volume.max_scaling,
            min_scaling=config.volume.min_scaling,
        )
    if should_apply(config.warp):
        example = warp_spectrogram(
            example,
            delta=config.warp.delta,
        )
    if should_apply(config.time_mask):
        example = mask_time(
            example,
            max_perc=config.time_mask.max_perc,
            max_mask=config.time_mask.max_masks,
        )
    if should_apply(config.frequency_mask):
        example = mask_frequency(
            example,
            max_perc=config.frequency_mask.max_perc,
            max_masks=config.frequency_mask.max_masks,
        )
    return example
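As a usage sketch, individual augmentations can be disabled or re-weighted through the config objects defined above; the values below are illustrative only:

config = AugmentationsConfig(
    mix=MixAugmentationConfig(enable=False),             # illustrative values
    time_mask=TimeMaskAugmentationConfig(probability=0.5, max_perc=0.1),
)
# augment_example(example, config) then never mixes clips and applies a time
# mask to roughly half of the examples it sees.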


@ -0,0 +1,31 @@
from typing import Optional
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
from batdetect2.train.losses import LossConfig
__all__ = [
"OptimizerConfig",
"TrainingConfig",
"load_train_config",
]
class OptimizerConfig(BaseConfig):
learning_rate: float = 1e-3
t_max: int = 100
class TrainingConfig(BaseConfig):
batch_size: int = 32
loss: LossConfig = Field(default_factory=LossConfig)
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
def load_train_config(
path: PathLike,
field: Optional[str] = None,
) -> TrainingConfig:
return load_config(path, schema=TrainingConfig, field=field)
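A small sketch of constructing the training configuration in code instead of loading it from a file; the values are illustrative:

config = TrainingConfig(
    batch_size=16,                                        # illustrative values
    optimizer=OptimizerConfig(learning_rate=3e-4, t_max=200),
)
assert config.loss is not None  # populated by the LossConfig default factory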


@ -1,12 +1,21 @@
import os
from pathlib import Path
from typing import NamedTuple, Optional, Sequence, Union
import numpy as np
import torch
import xarray as xr
from pydantic import Field
from soundevent import data
from torch.utils.data import Dataset
from batdetect2.configs import BaseConfig
from batdetect2.preprocess.tensors import adjust_width
from batdetect2.train.augmentations import (
    AugmentationsConfig,
    augment_example,
    select_subclip,
)
from batdetect2.train.preprocess import PreprocessingConfig
__all__ = [
@ -26,74 +35,113 @@ class TrainExample(NamedTuple):
    idx: torch.Tensor
class SubclipConfig(BaseConfig):
    duration: Optional[float] = None
    width: int = 512
    random: bool = False
class DatasetConfig(BaseConfig):
    subclip: SubclipConfig = Field(default_factory=SubclipConfig)
    augmentation: AugmentationsConfig = Field(
        default_factory=AugmentationsConfig
    )
class LabeledDataset(Dataset):
    def __init__(
        self,
        filenames: Sequence[PathLike],
        subclip: Optional[SubclipConfig] = None,
        augmentation: Optional[AugmentationsConfig] = None,
        preprocessing: Optional[PreprocessingConfig] = None,
    ):
        self.filenames = filenames
        self.subclip = subclip
        self.augmentation = augmentation
        self.preprocessing = preprocessing or PreprocessingConfig()
    def __len__(self):
        return len(self.filenames)
    def __getitem__(self, idx) -> TrainExample:
        dataset = self.get_dataset(idx)
        if self.subclip:
            dataset = select_subclip(
                dataset,
                duration=self.subclip.duration,
                width=self.subclip.width,
                random=self.subclip.random,
            )
        if self.augmentation:
            dataset = augment_example(
                dataset,
                self.augmentation,
                preprocessing_config=self.preprocessing,
                others=self.get_random_example,
            )
        return TrainExample(
            spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0),
            detection_heatmap=self.to_tensor(dataset["detection"]),
            class_heatmap=self.to_tensor(dataset["class"]),
            size_heatmap=self.to_tensor(dataset["size"]),
            idx=torch.tensor(idx),
        )
    @classmethod
    def from_directory(
        cls,
        directory: PathLike,
        extension: str = ".nc",
        subclip: Optional[SubclipConfig] = None,
        augmentation: Optional[AugmentationsConfig] = None,
        preprocessing: Optional[PreprocessingConfig] = None,
    ):
        return cls(
            get_files(directory, extension),
            subclip=subclip,
            augmentation=augmentation,
            preprocessing=preprocessing,
        )
    def get_random_example(self) -> xr.Dataset:
        idx = np.random.randint(0, len(self))
        dataset = self.get_dataset(idx)
        if self.subclip:
            dataset = select_subclip(
                dataset,
                duration=self.subclip.duration,
                width=self.subclip.width,
                random=self.subclip.random,
            )
        return dataset
    def get_dataset(self, idx) -> xr.Dataset:
        return xr.open_dataset(self.filenames[idx])
    def get_clip_annotation(self, idx) -> data.ClipAnnotation:
        return data.ClipAnnotation.model_validate_json(
            self.get_dataset(idx).attrs["clip_annotation"]
        )
    def to_tensor(
        self,
        array: xr.DataArray,
        dtype=np.float32,
    ) -> torch.Tensor:
        tensor = torch.tensor(array.values.astype(dtype))
        if not self.subclip:
            return tensor
        width = self.subclip.width
        return adjust_width(tensor, width)
def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
    return list(Path(directory).glob(f"*{extension}"))


@ -1,27 +1,43 @@
from collections.abc import Iterable
from typing import Callable, List, Optional, Sequence, Tuple
import numpy as np
import xarray as xr
from pydantic import Field
from scipy.ndimage import gaussian_filter
from soundevent import arrays, data, geometry
from soundevent.geometry.operations import Positions
from batdetect2.configs import BaseConfig, load_config
__all__ = [
    "HeatmapsConfig",
    "LabelConfig",
    "generate_heatmaps",
    "load_label_config",
]
class HeatmapsConfig(BaseConfig):
    position: Positions = "bottom-left"
    sigma: float = 3.0
    time_scale: float = 1000.0
    frequency_scale: float = 1 / 859.375
class LabelConfig(BaseConfig):
    heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)
def generate_heatmaps(
    sound_events: Sequence[data.SoundEventAnnotation],
    spec: xr.DataArray,
    class_names: List[str],
    encoder: Callable[[Iterable[data.Tag]], Optional[str]],
    target_sigma: float = 3.0,
    position: Positions = "bottom-left",
    time_scale: float = 1000.0,
    frequency_scale: float = 1 / 859.375,
    dtype=np.float32,
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
    shape = dict(zip(spec.dims, spec.shape))
@ -31,20 +47,13 @@ def generate_heatmaps(
"Spectrogram must have time and frequency dimensions." "Spectrogram must have time and frequency dimensions."
) )
time_duration = arrays.get_dim_width(spec, dim="time")
freq_bandwidth = arrays.get_dim_width(spec, dim="frequency")
# Compute the size factors
time_scale = 1 / time_duration
frequency_scale = 1 / freq_bandwidth
# Initialize heatmaps # Initialize heatmaps
detection_heatmap = xr.zeros_like(spec, dtype=dtype) detection_heatmap = xr.zeros_like(spec, dtype=dtype)
class_heatmap = xr.DataArray( class_heatmap = xr.DataArray(
data=np.zeros((class_mapper.num_classes, *spec.shape), dtype=dtype), data=np.zeros((len(class_names), *spec.shape), dtype=dtype),
dims=["category", *spec.dims], dims=["category", *spec.dims],
coords={ coords={
"category": class_mapper.class_labels, "category": [*class_names],
**spec.coords, **spec.coords,
}, },
) )
@ -57,9 +66,8 @@ def generate_heatmaps(
}, },
) )
for sound_event_annotation in clip_annotation.sound_events: for sound_event_annotation in sound_events:
geom = sound_event_annotation.sound_event.geometry geom = sound_event_annotation.sound_event.geometry
if geom is None: if geom is None:
continue continue
@ -67,23 +75,29 @@ def generate_heatmaps(
time, frequency = geometry.get_geometry_point(geom, position=position) time, frequency = geometry.get_geometry_point(geom, position=position)
# Set 1.0 at the position of the sound event in the detection heatmap # Set 1.0 at the position of the sound event in the detection heatmap
detection_heatmap = arrays.set_value_at_pos( try:
detection_heatmap, detection_heatmap = arrays.set_value_at_pos(
1.0, detection_heatmap,
time=time, 1.0,
frequency=frequency, time=time,
) frequency=frequency,
)
except KeyError:
# Skip the sound event if the position is outside the spectrogram
continue
# Set the size of the sound event at the position in the size heatmap # Set the size of the sound event at the position in the size heatmap
start_time, low_freq, end_time, high_freq = geometry.compute_bounds( start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
geom geom
) )
size = np.array( size = np.array(
[ [
(end_time - start_time) * time_scale, (end_time - start_time) * time_scale,
(high_freq - low_freq) * frequency_scale, (high_freq - low_freq) * frequency_scale,
] ]
) )
size_heatmap = arrays.set_value_at_pos( size_heatmap = arrays.set_value_at_pos(
size_heatmap, size_heatmap,
size, size,
@ -92,14 +106,12 @@ def generate_heatmaps(
) )
# Get the class name of the sound event # Get the class name of the sound event
class_name = class_mapper.transform(sound_event_annotation) class_name = encoder(sound_event_annotation.tags)
if class_name is None: if class_name is None:
# If the label is None skip the sound event # If the label is None skip the sound event
continue continue
# Set 1.0 at the position and category of the sound event in the class
# heatmap
class_heatmap = arrays.set_value_at_pos( class_heatmap = arrays.set_value_at_pos(
class_heatmap, class_heatmap,
1.0, 1.0,
@ -130,3 +142,9 @@ def generate_heatmaps(
).fillna(0.0) ).fillna(0.0)
return detection_heatmap, class_heatmap, size_heatmap return detection_heatmap, class_heatmap, size_heatmap
def load_label_config(
path: data.PathLike, field: Optional[str] = None
) -> LabelConfig:
return load_config(path, schema=LabelConfig, field=field)
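As a sanity check on the new fixed scales (previously the scales were derived from the spectrogram extent), here is a small sketch of what the size target becomes for a call of a given duration and bandwidth under the HeatmapsConfig defaults shown above; the call duration and frequency values are illustrative only:

import numpy as np

# Defaults from HeatmapsConfig above.
time_scale = 1000.0            # seconds scaled up by 1000
frequency_scale = 1 / 859.375  # Hz scaled down by 859.375

# Illustrative call: 5 ms long, spanning 30 kHz to 50 kHz.
start_time, end_time = 0.100, 0.105
low_freq, high_freq = 30_000.0, 50_000.0

size = np.array(
    [
        (end_time - start_time) * time_scale,      # 5.0
        (high_freq - low_freq) * frequency_scale,  # ~23.27
    ]
)
print(size)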

View File

@ -0,0 +1,82 @@
from typing import Callable, NamedTuple, Optional
import torch
from soundevent import data
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from batdetect2.models.typing import DetectionModel
from batdetect2.train.dataset import LabeledDataset
class TrainInputs(NamedTuple):
spec: torch.Tensor
detection_heatmap: torch.Tensor
class_heatmap: torch.Tensor
size_heatmap: torch.Tensor
def train_loop(
model: DetectionModel,
train_dataset: LabeledDataset[TrainInputs],
validation_dataset: LabeledDataset[TrainInputs],
device: Optional[torch.device] = None,
num_epochs: int = 100,
learning_rate: float = 1e-4,
):
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=32)
model.to(device)
optimizer = Adam(model.parameters(), lr=learning_rate)
scheduler = CosineAnnealingLR(
optimizer,
num_epochs * len(train_loader),
)
for epoch in range(num_epochs):
train_loss = train_single_epoch(
model,
train_loader,
optimizer,
device,
scheduler,
)
def train_single_epoch(
model: DetectionModel,
train_loader: DataLoader,
optimizer: Adam,
device: torch.device,
scheduler: CosineAnnealingLR,
):
model.train()
train_loss = tu.AverageMeter()
for batch in train_loader:
optimizer.zero_grad()
spec = batch.spec.to(device)
detection_heatmap = batch.detection_heatmap.to(device)
class_heatmap = batch.class_heatmap.to(device)
size_heatmap = batch.size_heatmap.to(device)
outputs = model(spec)
loss = loss_fun(
outputs,
gt_det,
gt_size,
gt_class,
det_criterion,
params,
class_inv_freq,
)
train_loss.update(loss.item(), data.shape[0])
loss.backward()
optimizer.step()
scheduler.step()

View File

@ -16,7 +16,6 @@ def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):
def split_diff(ann_dir, wav_dir, load_extra=True):
    train_sets = []
    if load_extra:
        train_sets.append(
@ -144,7 +143,6 @@ def split_diff(ann_dir, wav_dir, load_extra=True):
def split_same(ann_dir, wav_dir, load_extra=True):
    train_sets = []
    if load_extra:
        train_sets.append(

View File

@ -69,7 +69,8 @@ def get_genus_mapping(class_names: List[str]) -> Tuple[List[str], List[int]]:
def standardize_low_freq(
-    data: List[types.FileAnnotation], class_of_interest: str,
+    data: List[types.FileAnnotation],
+    class_of_interest: str,
) -> List[types.FileAnnotation]:
    # address the issue of highly variable low frequency annotations
    # this often happens for contstant frequency calls

View File

@ -1,56 +0,0 @@
import pytorch_lightning as L
from torch import Tensor, optim
from batdetect2.models.typing import DetectionModel, ModelOutput
from batdetect2.train import losses
from batdetect2.train.dataset import TrainExample
__all__ = [
"LitDetectorModel",
]
class LitDetectorModel(L.LightningModule):
model: DetectionModel
def __init__(self, model: DetectionModel, learning_rate: float = 1e-3):
super().__init__()
self.model = model
self.learning_rate = learning_rate
def compute_loss(
self,
outputs: ModelOutput,
batch: TrainExample,
) -> Tensor:
detection_loss = losses.focal_loss(
outputs.detection_probs,
batch.detection_heatmap,
)
size_loss = losses.bbox_size_loss(
outputs.size_preds,
batch.size_heatmap,
)
valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
classification_loss = losses.focal_loss(
outputs.class_probs,
batch.class_heatmap,
valid_mask=valid_mask,
)
return detection_loss + size_loss + classification_loss
def training_step(self, batch: TrainExample, batch_idx: int): # type: ignore
outputs: ModelOutput = self.model(batch.spec)
loss = self.compute_loss(outputs, batch)
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
return [optimizer], [scheduler]

View File

@ -1,7 +1,23 @@
-from typing import Optional
+from typing import NamedTuple, Optional

import torch
import torch.nn.functional as F
+from pydantic import Field
+
+from batdetect2.configs import BaseConfig
+from batdetect2.models.typing import ModelOutput
+from batdetect2.train.dataset import TrainExample
+
+__all__ = [
+    "bbox_size_loss",
+    "compute_loss",
+    "focal_loss",
+    "mse_loss",
+]
+
+
+class SizeLossConfig(BaseConfig):
+    weight: float = 0.1


def bbox_size_loss(
@ -17,6 +33,11 @@ def bbox_size_loss(
    )


+class FocalLossConfig(BaseConfig):
+    beta: float = 4
+    alpha: float = 2
+
+
def focal_loss(
    pred: torch.Tensor,
    gt: torch.Tensor,
@ -44,7 +65,7 @@ def focal_loss(
    )

    if weights is not None:
-        pos_loss = pos_loss * weights
+        pos_loss = pos_loss * torch.tensor(weights)
        # neg_loss = neg_loss*weights

    if valid_mask is not None:
@ -75,3 +96,71 @@ def mse_loss(
    else:
        op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
    return op
class DetectionLossConfig(BaseConfig):
weight: float = 1.0
focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
class ClassificationLossConfig(BaseConfig):
weight: float = 2.0
focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
class_weights: Optional[list[float]] = None
class LossConfig(BaseConfig):
detection: DetectionLossConfig = Field(default_factory=DetectionLossConfig)
size: SizeLossConfig = Field(default_factory=SizeLossConfig)
classification: ClassificationLossConfig = Field(
default_factory=ClassificationLossConfig
)
class Losses(NamedTuple):
detection: torch.Tensor
size: torch.Tensor
classification: torch.Tensor
total: torch.Tensor
def compute_loss(
batch: TrainExample,
outputs: ModelOutput,
conf: LossConfig,
class_weights: Optional[torch.Tensor] = None,
) -> Losses:
detection_loss = focal_loss(
outputs.detection_probs,
batch.detection_heatmap,
beta=conf.detection.focal.beta,
alpha=conf.detection.focal.alpha,
)
size_loss = bbox_size_loss(
outputs.size_preds,
batch.size_heatmap,
)
valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
classification_loss = focal_loss(
outputs.class_probs,
batch.class_heatmap,
weights=class_weights,
valid_mask=valid_mask,
beta=conf.classification.focal.beta,
alpha=conf.classification.focal.alpha,
)
total = (
detection_loss * conf.detection.weight
+ size_loss * conf.size.weight
+ classification_loss * conf.classification.weight
)
return Losses(
detection=detection_loss,
size=size_loss,
classification=classification_loss,
total=total,
)
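For orientation, the default LossConfig above weights the three terms as detection x 1.0 + size x 0.1 + classification x 2.0. A minimal standalone sketch of that weighting with dummy tensors (it does not use the repo's ModelOutput / TrainExample types, and the per-term loss values are placeholders):

import torch

# Placeholder per-term losses, standing in for focal_loss / bbox_size_loss outputs.
detection_loss = torch.tensor(0.80)
size_loss = torch.tensor(1.50)
classification_loss = torch.tensor(0.40)

# Default weights from DetectionLossConfig, SizeLossConfig, ClassificationLossConfig.
total = detection_loss * 1.0 + size_loss * 0.1 + classification_loss * 2.0
print(total)  # tensor(1.7500)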

View File

@ -1,20 +1,28 @@
"""Module for preprocessing data for training.""" """Module for preprocessing data for training."""
import os import os
import warnings
from functools import partial from functools import partial
from multiprocessing import Pool
from pathlib import Path from pathlib import Path
from typing import Callable, Optional, Sequence, Union from typing import Callable, Optional, Sequence, Union
from tqdm.auto import tqdm
from multiprocessing import Pool
import xarray as xr import xarray as xr
from pydantic import Field
from soundevent import data from soundevent import data
from tqdm.auto import tqdm
from batdetect2.data.labels import TARGET_SIGMA, ClassMapper, generate_heatmaps from batdetect2.configs import BaseConfig
from batdetect2.data.preprocessing import ( from batdetect2.preprocess import (
preprocess_audio_clip,
PreprocessingConfig, PreprocessingConfig,
compute_spectrogram,
load_clip_audio,
)
from batdetect2.train.labels import LabelConfig, generate_heatmaps
from batdetect2.train.targets import (
TargetConfig,
build_encoder,
build_sound_event_filter,
get_class_names,
) )
PathLike = Union[Path, str, os.PathLike] PathLike = Union[Path, str, os.PathLike]
@ -22,31 +30,76 @@ FilenameFn = Callable[[data.ClipAnnotation], str]
__all__ = [ __all__ = [
"preprocess_annotations", "preprocess_annotations",
"preprocess_single_annotation",
"generate_train_example",
"TrainPreprocessingConfig",
] ]
class TrainPreprocessingConfig(BaseConfig):
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
target: TargetConfig = Field(default_factory=TargetConfig)
labels: LabelConfig = Field(default_factory=LabelConfig)
def generate_train_example( def generate_train_example(
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
class_mapper: ClassMapper, preprocessing_config: Optional[PreprocessingConfig] = None,
preprocessing_config: PreprocessingConfig = PreprocessingConfig(), target_config: Optional[TargetConfig] = None,
target_sigma: float = TARGET_SIGMA, label_config: Optional[LabelConfig] = None,
) -> xr.Dataset: ) -> xr.Dataset:
"""Generate a training example.""" """Generate a training example."""
spectrogram = preprocess_audio_clip( config = TrainPreprocessingConfig(
clip_annotation.clip, preprocessing=preprocessing_config or PreprocessingConfig(),
config=preprocessing_config, target=target_config or TargetConfig(),
labels=label_config or LabelConfig(),
) )
wave = load_clip_audio(
clip_annotation.clip,
config=config.preprocessing.audio,
)
spectrogram = compute_spectrogram(
wave,
config=config.preprocessing.spectrogram,
)
filter_fn = build_sound_event_filter(
include=config.target.include,
exclude=config.target.exclude,
)
selected_events = [
event for event in clip_annotation.sound_events if filter_fn(event)
]
encoder = build_encoder(
config.target.classes,
replacement_rules=config.target.replace,
)
class_names = get_class_names(config.target.classes)
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps( detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
clip_annotation, selected_events,
spectrogram, spectrogram,
class_mapper, class_names,
target_sigma=target_sigma, encoder,
target_sigma=config.labels.heatmaps.sigma,
position=config.labels.heatmaps.position,
time_scale=config.labels.heatmaps.time_scale,
frequency_scale=config.labels.heatmaps.frequency_scale,
) )
dataset = xr.Dataset( dataset = xr.Dataset(
{ {
# NOTE: Need to rename the time dimension to avoid conflicts with
# the spectrogram time dimension, otherwise xarray will interpolate
# the spectrogram and the heatmaps to the same temporal resolution
# as the waveform.
"audio": wave.rename({"time": "audio_time"}),
"spectrogram": spectrogram, "spectrogram": spectrogram,
"detection": detection_heatmap, "detection": detection_heatmap,
"class": class_heatmap, "class": class_heatmap,
@ -56,9 +109,12 @@ def generate_train_example(
return dataset.assign_attrs( return dataset.assign_attrs(
title=f"Training example for {clip_annotation.uuid}", title=f"Training example for {clip_annotation.uuid}",
preprocessing_configuration=preprocessing_config.model_dump_json(), config=config.model_dump_json(),
target_sigma=target_sigma, clip_annotation=clip_annotation.model_dump_json(
clip_annotation=clip_annotation.model_dump_json(), exclude_none=True,
exclude_defaults=True,
exclude_unset=True,
),
) )
@ -77,36 +133,54 @@ def save_to_file(
) )
def load_config(path: PathLike, **kwargs) -> PreprocessingConfig:
"""Load configuration from file."""
path = Path(path)
if not path.is_file():
warnings.warn(f"Config file not found: {path}. Using default config.")
return PreprocessingConfig(**kwargs)
try:
return PreprocessingConfig.model_validate_json(path.read_text())
except ValueError as e:
warnings.warn(
f"Failed to load config file: {e}. Using default config."
)
return PreprocessingConfig(**kwargs)
def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
    return f"{clip_annotation.uuid}.nc"
def preprocess_annotations(
clip_annotations: Sequence[data.ClipAnnotation],
output_dir: PathLike,
filename_fn: FilenameFn = _get_filename,
replace: bool = False,
preprocessing_config: Optional[PreprocessingConfig] = None,
target_config: Optional[TargetConfig] = None,
label_config: Optional[LabelConfig] = None,
max_workers: Optional[int] = None,
) -> None:
"""Preprocess annotations and save to disk."""
output_dir = Path(output_dir)
if not output_dir.is_dir():
output_dir.mkdir(parents=True)
with Pool(max_workers) as pool:
list(
tqdm(
pool.imap_unordered(
partial(
preprocess_single_annotation,
output_dir=output_dir,
filename_fn=filename_fn,
replace=replace,
preprocessing_config=preprocessing_config,
target_config=target_config,
label_config=label_config,
),
clip_annotations,
),
total=len(clip_annotations),
)
)
def preprocess_single_annotation(
    clip_annotation: data.ClipAnnotation,
    output_dir: PathLike,
-    config: PreprocessingConfig,
-    class_mapper: ClassMapper,
+    preprocessing_config: Optional[PreprocessingConfig] = None,
+    target_config: Optional[TargetConfig] = None,
+    label_config: Optional[LabelConfig] = None,
    filename_fn: FilenameFn = _get_filename,
    replace: bool = False,
-    target_sigma: float = TARGET_SIGMA,
) -> None:
    output_dir = Path(output_dir)
@ -119,50 +193,16 @@ def preprocess_single_annotation(
    if path.is_file() and replace:
        path.unlink()

-    sample = generate_train_example(
-        clip_annotation,
-        class_mapper,
-        preprocessing_config=config,
-        target_sigma=target_sigma,
-    )
+    try:
+        sample = generate_train_example(
+            clip_annotation,
+            preprocessing_config=preprocessing_config,
+            target_config=target_config,
+            label_config=label_config,
+        )
+    except Exception as error:
+        raise RuntimeError(
+            f"Failed to process annotation: {clip_annotation.uuid}"
+        ) from error

    save_to_file(sample, path)
def preprocess_annotations(
clip_annotations: Sequence[data.ClipAnnotation],
output_dir: PathLike,
class_mapper: ClassMapper,
target_sigma: float = TARGET_SIGMA,
filename_fn: FilenameFn = _get_filename,
replace: bool = False,
config: Optional[PreprocessingConfig] = None,
max_workers: Optional[int] = None,
) -> None:
"""Preprocess annotations and save to disk."""
output_dir = Path(output_dir)
if config is None:
config = PreprocessingConfig()
if not output_dir.is_dir():
output_dir.mkdir(parents=True)
with Pool(max_workers) as pool:
list(
tqdm(
pool.imap_unordered(
partial(
preprocess_single_annotation,
output_dir=output_dir,
config=config,
class_mapper=class_mapper,
filename_fn=filename_fn,
replace=replace,
target_sigma=target_sigma,
),
clip_annotations,
),
total=len(clip_annotations),
)
)
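A hedged usage sketch of the reworked entry point (it assumes you already hold a sequence of soundevent ClipAnnotation objects, for example loaded from an annotation project; the module path is the one used by the notebook imports in this changeset, and all configs left as None fall back to the defaults shown above):

from pathlib import Path

from batdetect2.train.preprocess import preprocess_annotations

# `clip_annotations` is assumed to be an existing Sequence[data.ClipAnnotation].
preprocess_annotations(
    clip_annotations,
    output_dir=Path("train_examples"),
    replace=False,
    max_workers=4,
)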

batdetect2/train/targets.py  (new file, 181 lines)
View File

@ -0,0 +1,181 @@
from collections.abc import Iterable
from functools import partial
from pathlib import Path
from typing import Callable, List, Optional, Set
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.terms import TagInfo, get_tag_from_info
__all__ = [
"TargetConfig",
"load_target_config",
"build_encoder",
"build_decoder",
"filter_sound_event",
]
class ReplaceConfig(BaseConfig):
"""Configuration for replacing tags."""
original: TagInfo
replacement: TagInfo
class TargetConfig(BaseConfig):
"""Configuration for target generation."""
classes: List[TagInfo] = Field(
default_factory=lambda: [
TagInfo(key="class", value=value) for value in DEFAULT_SPECIES_LIST
]
)
generic_class: Optional[TagInfo] = Field(
default_factory=lambda: TagInfo(key="class", value="Bat")
)
include: Optional[List[TagInfo]] = Field(
default_factory=lambda: [TagInfo(key="event", value="Echolocation")]
)
exclude: Optional[List[TagInfo]] = Field(
default_factory=lambda: [
TagInfo(key="class", value=""),
TagInfo(key="class", value=" "),
TagInfo(key="class", value="Unknown"),
]
)
replace: Optional[List[ReplaceConfig]] = None
def build_sound_event_filter(
include: Optional[List[TagInfo]] = None,
exclude: Optional[List[TagInfo]] = None,
) -> Callable[[data.SoundEventAnnotation], bool]:
include_tags = (
{get_tag_from_info(tag) for tag in include} if include else None
)
exclude_tags = (
{get_tag_from_info(tag) for tag in exclude} if exclude else None
)
return partial(
filter_sound_event,
include=include_tags,
exclude=exclude_tags,
)
def get_tag_label(tag_info: TagInfo) -> str:
return tag_info.label if tag_info.label else tag_info.value
def get_class_names(classes: List[TagInfo]) -> List[str]:
return sorted({get_tag_label(tag) for tag in classes})
def build_replacer(
rules: List[ReplaceConfig],
) -> Callable[[data.Tag], data.Tag]:
mapping = {
get_tag_from_info(rule.original): get_tag_from_info(rule.replacement)
for rule in rules
}
def replacer(tag: data.Tag) -> data.Tag:
return mapping.get(tag, tag)
return replacer
def build_encoder(
classes: List[TagInfo],
replacement_rules: Optional[List[ReplaceConfig]] = None,
) -> Callable[[Iterable[data.Tag]], Optional[str]]:
target_tags = set([get_tag_from_info(tag) for tag in classes])
tag_mapping = {
tag: get_tag_label(tag_info)
for tag, tag_info in zip(target_tags, classes)
}
replacer = (
build_replacer(replacement_rules) if replacement_rules else lambda x: x
)
def encoder(
tags: Iterable[data.Tag],
) -> Optional[str]:
sanitized_tags = {replacer(tag) for tag in tags}
intersection = sanitized_tags & target_tags
if not intersection:
return None
first = intersection.pop()
return tag_mapping[first]
return encoder
def build_decoder(
classes: List[TagInfo],
) -> Callable[[str], List[data.Tag]]:
target_tags = set([get_tag_from_info(tag) for tag in classes])
tag_mapping = {
get_tag_label(tag_info): tag
for tag, tag_info in zip(target_tags, classes)
}
def decoder(label: str) -> List[data.Tag]:
tag = tag_mapping.get(label)
return [tag] if tag else []
return decoder
def filter_sound_event(
sound_event_annotation: data.SoundEventAnnotation,
include: Optional[Set[data.Tag]] = None,
exclude: Optional[Set[data.Tag]] = None,
) -> bool:
tags = set(sound_event_annotation.tags)
if include is not None and not tags & include:
return False
if exclude is not None and tags & exclude:
return False
return True
def load_target_config(
path: Path, field: Optional[str] = None
) -> TargetConfig:
return load_config(path, schema=TargetConfig, field=field)
DEFAULT_SPECIES_LIST = [
"Barbastellus barbastellus",
"Eptesicus serotinus",
"Myotis alcathoe",
"Myotis bechsteinii",
"Myotis brandtii",
"Myotis daubentonii",
"Myotis mystacinus",
"Myotis nattereri",
"Nyctalus leisleri",
"Nyctalus noctula",
"Pipistrellus nathusii",
"Pipistrellus pipistrellus",
"Pipistrellus pygmaeus",
"Plecotus auritus",
"Plecotus austriacus",
"Rhinolophus ferrumequinum",
"Rhinolophus hipposideros",
]
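To make the encoder contract concrete, here is a self-contained toy version of the same idea, using plain (key, value) tuples instead of soundevent tags and TagInfo; it mirrors the replacement-then-intersection logic of build_encoder but is not the repo's API:

from typing import Dict, Iterable, Optional, Tuple

Tag = Tuple[str, str]  # (key, value), a stand-in for data.Tag

classes = [("class", "Myotis daubentonii"), ("class", "Pipistrellus pipistrellus")]
replacements: Dict[Tag, Tag] = {
    # e.g. fold a legacy spelling into the canonical class tag
    ("class", "Pip pip"): ("class", "Pipistrellus pipistrellus"),
}
target_tags = set(classes)
labels = {tag: tag[1] for tag in classes}  # label defaults to the tag value


def encode(tags: Iterable[Tag]) -> Optional[str]:
    sanitized = {replacements.get(tag, tag) for tag in tags}
    matched = sanitized & target_tags
    if not matched:
        return None  # no target class matched
    return labels[matched.pop()]


print(encode([("event", "Echolocation"), ("class", "Pip pip")]))
# -> "Pipistrellus pipistrellus"
print(encode([("event", "Echolocation"), ("class", "Unknown")]))
# -> None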

View File

@ -1,82 +1,68 @@
-   (removed: the previous `TrainInputs`, `train_loop`, and `train_single_epoch`
-   implementation, identical to the batdetect2/train train-loop listing shown above)
+from typing import Optional, Union
+
+from lightning import LightningModule
+from lightning.pytorch import Trainer
+from soundevent.data import PathLike
+from torch.utils.data import DataLoader
+
+from batdetect2.configs import BaseConfig, load_config
+from batdetect2.train.dataset import LabeledDataset
+
+__all__ = [
+    "train",
+    "TrainerConfig",
+    "load_trainer_config",
+]
+
+
+class TrainerConfig(BaseConfig):
+    accelerator: str = "auto"
+    accumulate_grad_batches: int = 1
+    deterministic: bool = True
+    check_val_every_n_epoch: int = 1
+    devices: Union[str, int] = "auto"
+    enable_checkpointing: bool = True
+    gradient_clip_val: Optional[float] = None
+    limit_train_batches: Optional[Union[int, float]] = None
+    limit_test_batches: Optional[Union[int, float]] = None
+    limit_val_batches: Optional[Union[int, float]] = None
+    log_every_n_steps: Optional[int] = None
+    max_epochs: Optional[int] = None
+    min_epochs: Optional[int] = 100
+    max_steps: Optional[int] = None
+    min_steps: Optional[int] = None
+    max_time: Optional[str] = None
+    precision: Optional[str] = None
+    reload_dataloaders_every_n_epochs: Optional[int] = None
+    val_check_interval: Optional[Union[int, float]] = None
+
+
+def load_trainer_config(path: PathLike, field: Optional[str] = None):
+    return load_config(path, schema=TrainerConfig, field=field)
+
+
+def train(
+    module: LightningModule,
+    train_dataset: LabeledDataset,
+    trainer_config: Optional[TrainerConfig] = None,
+    dev_run: bool = False,
+    overfit_batches: bool = False,
+    profiler: Optional[str] = None,
+):
+    trainer_config = trainer_config or TrainerConfig()
+    trainer = Trainer(
+        **trainer_config.model_dump(
+            exclude_unset=True,
+            exclude_none=True,
+        ),
+        fast_dev_run=dev_run,
+        overfit_batches=overfit_batches,
+        profiler=profiler,
+    )
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=module.config.train.batch_size,
+        shuffle=True,
+        num_workers=7,
+    )
+    trainer.fit(module, train_dataloaders=train_loader)
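A hedged sketch of calling the new train entry point. The import path for TrainerConfig/train is not shown in this hunk and is assumed to mirror the other batdetect2.train modules; the module construction is illustrative and assumes the LightningModule exposes config.train.batch_size, which the DataLoader code above reads:

from batdetect2.train.dataset import LabeledDataset, get_files
from batdetect2.train.modules import DetectorModel  # as imported in the notebook below
from batdetect2.train.train import TrainerConfig, train  # assumed module path

# Assumes the model's default config provides config.train.batch_size.
module = DetectorModel()
train_dataset = LabeledDataset(get_files("train_examples"))

trainer_config = TrainerConfig(max_epochs=200, log_every_n_steps=10)
train(module, train_dataset, trainer_config=trainer_config)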

View File

@ -1,13 +1,12 @@
"""Types used in the code base.""" """Types used in the code base."""
from typing import Any, List, NamedTuple, Optional from typing import Any, List, NamedTuple, Optional
import numpy as np import numpy as np
import torch import torch
try: from typing import TypedDict
from typing import TypedDict
except ImportError:
from typing_extensions import TypedDict
try: try:
@ -15,9 +14,8 @@ try:
except ImportError: except ImportError:
from typing_extensions import Protocol from typing_extensions import Protocol
try: try:
from typing import NotRequired from typing import NotRequired # type: ignore
except ImportError: except ImportError:
from typing_extensions import NotRequired from typing_extensions import NotRequired
@ -597,8 +595,7 @@ class FeatureExtractor(Protocol):
self, self,
prediction: Prediction, prediction: Prediction,
**kwargs: Any, **kwargs: Any,
) -> float: ) -> float: ...
...
class DatasetDict(TypedDict): class DatasetDict(TypedDict):

View File

@ -6,6 +6,8 @@ import librosa.core.spectrum
import numpy as np
import torch

+from batdetect2.detector import parameters
+
from . import wavfile

__all__ = [
@ -15,20 +17,44 @@ __all__ = [
]


-def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
-    nfft = np.floor(fft_win_length * sampling_rate)  # int() uses floor
-    noverlap = np.floor(fft_overlap * nfft)
-    return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)
+def time_to_x_coords(
+    time_in_file: float,
+    samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
+    window_duration: float = parameters.FFT_WIN_LENGTH_S,
+    window_overlap: float = parameters.FFT_OVERLAP,
+) -> float:
+    nfft = np.floor(window_duration * samplerate)  # int() uses floor
+    noverlap = np.floor(window_overlap * nfft)
+    return (time_in_file * samplerate - noverlap) / (nfft - noverlap)


-# NOTE this is also defined in post_process
-def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
-    nfft = np.floor(fft_win_length * sampling_rate)
-    noverlap = np.floor(fft_overlap * nfft)
-    return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
+def x_coords_to_time(
+    x_pos: int,
+    samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
+    window_duration: float = parameters.FFT_WIN_LENGTH_S,
+    window_overlap: float = parameters.FFT_OVERLAP,
+) -> float:
+    n_fft = np.floor(window_duration * samplerate)
+    n_overlap = np.floor(window_overlap * n_fft)
+    n_step = n_fft - n_overlap
+    return ((x_pos * n_step) + n_overlap) / samplerate
    # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5)  # 0.5 is for center of temporal window


+def x_coord_to_sample(
+    x_pos: int,
+    samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
+    window_duration: float = parameters.FFT_WIN_LENGTH_S,
+    window_overlap: float = parameters.FFT_OVERLAP,
+    resize_factor: float = parameters.RESIZE_FACTOR,
+) -> int:
+    n_fft = np.floor(window_duration * samplerate)
+    n_overlap = np.floor(window_overlap * n_fft)
+    n_step = n_fft - n_overlap
+    x_pos = int(x_pos / resize_factor)
+    return int((x_pos * n_step) + n_overlap)
+
+
def generate_spectrogram(
    audio,
    sampling_rate,
@ -64,7 +90,7 @@ def generate_spectrogram(
                np.abs(
                    np.hanning(
                        int(params["fft_win_length"] * sampling_rate)
-                    )
+                    ).astype(np.float32)
                )
                ** 2
            ).sum()
@ -74,7 +100,7 @@ def generate_spectrogram(
        # log_scaling = (1.0 / sampling_rate)*10e4
        spec = np.log1p(log_scaling * spec)
    elif params["spec_scale"] == "pcen":
-        spec = pcen(spec , sampling_rate)
+        spec = pcen(spec, sampling_rate)
    elif params["spec_scale"] == "none":
        pass
@ -194,55 +220,118 @@ def load_audio(
    return sampling_rate, audio_raw
def compute_spectrogram_width(
length: int,
samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
window_duration: float = parameters.FFT_WIN_LENGTH_S,
window_overlap: float = parameters.FFT_OVERLAP,
resize_factor: float = parameters.RESIZE_FACTOR,
) -> int:
n_fft = int(window_duration * samplerate)
n_overlap = int(window_overlap * n_fft)
n_step = n_fft - n_overlap
width = (length - n_overlap) // n_step
return int(width * resize_factor)
def pad_audio(
-    audio_raw,
-    fs,
-    ms,
-    overlap_perc,
-    resize_factor,
-    divide_factor,
-    fixed_width=None,
+    audio: np.ndarray,
+    samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
+    window_duration: float = parameters.FFT_WIN_LENGTH_S,
+    window_overlap: float = parameters.FFT_OVERLAP,
+    resize_factor: float = parameters.RESIZE_FACTOR,
+    divide_factor: int = parameters.SPEC_DIVIDE_FACTOR,
+    fixed_width: Optional[int] = None,
):
-    # Adds zeros to the end of the raw data so that the generated sepctrogram
-    # will be evenly divisible by `divide_factor`
-    # Also deals with very short audio clips and fixed_width during training
-    # This code could be clearer, clean up
-    nfft = int(ms * fs)
-    noverlap = int(overlap_perc * nfft)
-    step = nfft - noverlap
-    min_size = int(divide_factor * (1.0 / resize_factor))
-    spec_width = (audio_raw.shape[0] - noverlap) // step
-    spec_width_rs = spec_width * resize_factor
-
-    if fixed_width is not None and spec_width < fixed_width:
-        # too small
-        # used during training to ensure all the batches are the same size
-        diff = fixed_width * step + noverlap - audio_raw.shape[0]
-        audio_raw = np.hstack(
-            (audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
-        )
-    elif fixed_width is not None and spec_width > fixed_width:
-        # too big
-        # used during training to ensure all the batches are the same size
-        diff = fixed_width * step + noverlap - audio_raw.shape[0]
-        audio_raw = audio_raw[:diff]
-    elif (
-        spec_width_rs < min_size
-        or (np.floor(spec_width_rs) % divide_factor) != 0
-    ):
-        # need to be at least min_size
-        div_amt = np.ceil(spec_width_rs / float(divide_factor))
-        div_amt = np.maximum(1, div_amt)
-        target_size = int(div_amt * divide_factor * (1.0 / resize_factor))
-        diff = target_size * step + noverlap - audio_raw.shape[0]
-        audio_raw = np.hstack(
-            (audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
-        )
-
-    return audio_raw
+    """Pad audio to be evenly divisible by `divide_factor`.
+
+    This function pads the audio signal with zeros to ensure that the
+    generated spectrogram length will be evenly divisible by `divide_factor`.
+    This is important for the model to work correctly.
+
+    This `divide_factor` comes from the model architecture as it downscales
+    the spectrogram by this factor, so the input must be divisible by this
+    integer number.
+
+    Parameters
+    ----------
+    audio : np.ndarray
+        The audio signal.
+    samplerate : int
+        The sampling rate of the audio signal.
+    window_size : float
+        The window size in seconds used for the spectrogram computation.
+    window_overlap : float
+        The overlap between windows in the spectrogram computation.
+    resize_factor : float
+        This factor is used to resize the spectrogram after the STFT
+        computation. Default is 0.5 which means that the spectrogram will be
+        reduced by half. Important to take into account for the final size of
+        the spectrogram.
+    divide_factor : int
+        The factor by which the spectrogram will be divided.
+    fixed_width : int, optional
+        If provided, the audio will be padded or cut so that the resulting
+        spectrogram width will be equal to this value.
+
+    Returns
+    -------
+    np.ndarray
+        The padded audio signal.
+    """
+    spec_width = compute_spectrogram_width(
+        audio.shape[0],
+        samplerate=samplerate,
+        window_duration=window_duration,
+        window_overlap=window_overlap,
+        resize_factor=resize_factor,
+    )
+
+    if fixed_width:
+        target_samples = x_coord_to_sample(
+            fixed_width,
+            samplerate=samplerate,
+            window_duration=window_duration,
+            window_overlap=window_overlap,
+            resize_factor=resize_factor,
+        )
+
+        if spec_width < fixed_width:
+            # need to be at least min_size
+            diff = target_samples - audio.shape[0]
+            return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
+
+        if spec_width > fixed_width:
+            return audio[:target_samples]
+
+        return audio
+
+    min_width = int(divide_factor / resize_factor)
+
+    if spec_width < min_width:
+        target_samples = x_coord_to_sample(
+            min_width,
+            samplerate=samplerate,
+            window_duration=window_duration,
+            window_overlap=window_overlap,
+            resize_factor=resize_factor,
+        )
+        diff = target_samples - audio.shape[0]
+        return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
+
+    if (spec_width % divide_factor) == 0:
+        return audio
+
+    target_width = int(np.ceil(spec_width / divide_factor)) * divide_factor
+    target_samples = x_coord_to_sample(
+        target_width,
+        samplerate=samplerate,
+        window_duration=window_duration,
+        window_overlap=window_overlap,
+        resize_factor=resize_factor,
+    )
+    diff = target_samples - audio.shape[0]
+    return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))


def gen_mag_spectrogram(x, fs, ms, overlap_perc):
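To see why the padded length works out, here is a small numpy-only walk-through of the same arithmetic. The STFT parameters below are illustrative stand-ins; the real defaults live in batdetect2.detector.parameters and may differ:

import numpy as np

samplerate = 256_000     # illustrative values, not the package defaults
window_duration = 0.002  # 2 ms window  -> n_fft = 512
window_overlap = 0.75    # 75% overlap  -> n_step = 128
resize_factor = 0.5
divide_factor = 32

n_fft = int(window_duration * samplerate)  # 512
n_overlap = int(window_overlap * n_fft)    # 384
n_step = n_fft - n_overlap                 # 128

length = samplerate  # one second of audio
spec_width = int(((length - n_overlap) // n_step) * resize_factor)
print(spec_width)  # 998, not divisible by 32

# Same rounding pad_audio performs: grow to the next multiple of divide_factor.
target_width = int(np.ceil(spec_width / divide_factor)) * divide_factor    # 1024
target_samples = int((target_width / resize_factor) * n_step + n_overlap)  # 262528
padded = np.hstack((np.zeros(length), np.zeros(target_samples - length)))

new_width = int(((padded.shape[0] - n_overlap) // n_step) * resize_factor)
print(new_width, new_width % divide_factor)  # 1024 0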

View File

@ -8,6 +8,11 @@ import pandas as pd
import torch
import torch.nn.functional as F

+try:
+    from numpy.exceptions import AxisError
+except ImportError:
+    from numpy import AxisError  # type: ignore
+
import batdetect2.detector.compute_features as feats
import batdetect2.detector.post_process as pp
import batdetect2.utils.audio_utils as au
@ -80,6 +85,7 @@ def load_model(
    model_path: str = DEFAULT_MODEL_PATH,
    load_weights: bool = True,
    device: Union[torch.device, str, None] = None,
+    weights_only: bool = True,
) -> Tuple[DetectionModel, ModelParameters]:
    """Load model from file.
@ -100,7 +106,11 @@ def load_model(
    if not os.path.isfile(model_path):
        raise FileNotFoundError("Model file not found.")

-    net_params = torch.load(model_path, map_location=device)
+    net_params = torch.load(
+        model_path,
+        map_location=device,
+        weights_only=weights_only,
+    )

    params = net_params["params"]
@ -242,7 +252,7 @@ def format_single_result(
        )
        class_name = class_names[np.argmax(class_overall)]
        annotations = get_annotations_from_preds(predictions, class_names)
-    except (np.AxisError, ValueError):
+    except (AxisError, ValueError):
        # No detections
        class_overall = np.zeros(len(class_names))
        class_name = "None"
@ -399,7 +409,7 @@ def save_results_to_file(results, op_path: str) -> None:
def compute_spectrogram(
    audio: np.ndarray,
-    sampling_rate: float,
+    sampling_rate: int,
    params: SpectrogramParameters,
    device: torch.device,
) -> Tuple[float, torch.Tensor]:
@ -617,7 +627,7 @@ def process_spectrogram(
def _process_audio_array(
    audio: np.ndarray,
-    sampling_rate: float,
+    sampling_rate: int,
    model: DetectionModel,
    config: ProcessingConfiguration,
    device: torch.device,
@ -738,9 +748,7 @@ def process_file(
    # Get original sampling rate
    file_samp_rate = librosa.get_samplerate(audio_file)
-    orig_samp_rate = file_samp_rate * float(
-        config.get("time_expansion", 1.0) or 1.0
-    )
+    orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)

    # load audio file
    sampling_rate, audio_full = au.load_audio(
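The weights_only flag threaded through load_model above is plain PyTorch behaviour (the GH-30 deprecation fix). A minimal standalone sketch of the difference; the checkpoint file name is hypothetical:

import torch

checkpoint_path = "model_checkpoint.pth.tar"  # hypothetical model file

# Safe default: only tensors and basic containers are unpickled.
net_params = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

# Opting out restores the old, arbitrary-pickle behaviour for trusted files:
# net_params = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

print(net_params.keys())  # e.g. dict_keys(['params', 'state_dict'])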

View File

@ -417,7 +417,9 @@ def plot_confusion_matrix(
    cm_norm = cm.sum(1)
    valid_inds = np.where(cm_norm > 0)[0]
-    cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
+    cm[valid_inds, :] = (
+        cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
+    )
    cm[np.where(cm_norm == -0)[0], :] = np.nan

    if verbose:
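The wrapped expression is just row-normalisation guarded against empty rows; a standalone illustration with a toy confusion matrix:

import numpy as np

cm = np.array([[8, 2], [0, 0]], dtype=float)  # second class never occurs

cm_norm = cm.sum(1)                    # row totals: [10., 0.]
valid_inds = np.where(cm_norm > 0)[0]  # only row 0 can be normalised
cm[valid_inds, :] = (
    cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
)
cm[np.where(cm_norm == 0)[0], :] = np.nan

print(cm)  # [[0.8 0.2] [nan nan]]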

View File

@ -155,9 +155,9 @@ class InteractivePlotter:
            # draw bounding box around call
            self.ax[1].patches[0].remove()
-            spec_width_orig = self.spec_slices[self.current_id].shape[1] / (
-                1.0 + 2.0 * self.spec_pad
-            )
+            spec_width_orig = self.spec_slices[self.current_id].shape[
+                1
+            ] / (1.0 + 2.0 * self.spec_pad)
            xx = w_diff + self.spec_pad * spec_width_orig
            ww = spec_width_orig
            yy = self.call_info[self.current_id]["low_freq"] / 1000
@ -183,7 +183,9 @@ class InteractivePlotter:
                round(self.call_info[self.current_id]["start_time"], 3)
            )
            + ", prob="
-            + str(round(self.call_info[self.current_id]["det_prob"], 3))
+            + str(
+                round(self.call_info[self.current_id]["det_prob"], 3)
+            )
        )
        self.ax[0].set_xlabel(info_str)

View File

@ -8,6 +8,7 @@ Functions
`write`: Write a numpy array as a WAV file.
"""

from __future__ import absolute_import, division, print_function

import os
@ -156,7 +157,6 @@ def read(filename, mmap=False):
    fid = open(filename, "rb")
    try:
        # some files seem to have the size recorded in the header greater than
        # the actual file size.
        fid.seek(0, os.SEEK_END)

File diff suppressed because one or more lines are too long

notebooks/Migrations.ipynb  (new file, 1166 lines)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -6,11 +6,11 @@
"id": "cfb0b360-a204-4c27-a18f-3902e8758879", "id": "cfb0b360-a204-4c27-a18f-3902e8758879",
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-07-16T00:27:20.598611Z", "iopub.execute_input": "2024-11-19T17:33:02.699871Z",
"iopub.status.busy": "2024-07-16T00:27:20.596274Z", "iopub.status.busy": "2024-11-19T17:33:02.699590Z",
"iopub.status.idle": "2024-07-16T00:27:20.670888Z", "iopub.status.idle": "2024-11-19T17:33:02.710312Z",
"shell.execute_reply": "2024-07-16T00:27:20.668193Z", "shell.execute_reply": "2024-11-19T17:33:02.709798Z",
"shell.execute_reply.started": "2024-07-16T00:27:20.598423Z" "shell.execute_reply.started": "2024-11-19T17:33:02.699839Z"
} }
}, },
"outputs": [], "outputs": [],
@ -25,11 +25,11 @@
"id": "326c5432-94e6-4abf-a332-fe902559461b", "id": "326c5432-94e6-4abf-a332-fe902559461b",
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-07-16T00:27:20.676278Z", "iopub.execute_input": "2024-11-19T17:33:02.711324Z",
"iopub.status.busy": "2024-07-16T00:27:20.675545Z", "iopub.status.busy": "2024-11-19T17:33:02.711067Z",
"iopub.status.idle": "2024-07-16T00:27:25.872556Z", "iopub.status.idle": "2024-11-19T17:33:09.092380Z",
"shell.execute_reply": "2024-07-16T00:27:25.871725Z", "shell.execute_reply": "2024-11-19T17:33:09.091830Z",
"shell.execute_reply.started": "2024-07-16T00:27:20.676206Z" "shell.execute_reply.started": "2024-11-19T17:33:02.711304Z"
} }
}, },
"outputs": [ "outputs": [
@ -37,7 +37,7 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"/home/santiago/Software/bat_detectors/batdetect2/.venv/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", "/home/santiago/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n" " from .autonotebook import tqdm as notebook_tqdm\n"
] ]
} }
@ -45,26 +45,35 @@
"source": [ "source": [
"from pathlib import Path\n", "from pathlib import Path\n",
"from typing import List, Optional\n", "from typing import List, Optional\n",
"import torch\n",
"\n", "\n",
"import pytorch_lightning as pl\n", "import pytorch_lightning as pl\n",
"from soundevent import data\n", "from batdetect2.train.modules import DetectorModel\n",
"from torch.utils.data import DataLoader\n",
"\n",
"from batdetect2.data.labels import ClassMapper\n",
"from batdetect2.models.detectors import DetectorModel\n",
"from batdetect2.train.augmentations import (\n", "from batdetect2.train.augmentations import (\n",
" add_echo,\n", " add_echo,\n",
" select_random_subclip,\n", " select_random_subclip,\n",
" warp_spectrogram,\n", " warp_spectrogram,\n",
")\n", ")\n",
"from batdetect2.train.dataset import LabeledDataset, get_files\n", "from batdetect2.train.dataset import LabeledDataset, get_files\n",
"from batdetect2.train.preprocess import PreprocessingConfig" "from batdetect2.train.preprocess import PreprocessingConfig\n",
"from soundevent import data\n",
"import matplotlib.pyplot as plt\n",
"from soundevent.types import ClassMapper\n",
"from torch.utils.data import DataLoader"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "fa202af2-5c0d-4b5d-91a3-097ef5cd4272", "id": "9402a473-0b25-4123-9fa8-ad1f71a4237a",
"metadata": {}, "metadata": {
"execution": {
"iopub.execute_input": "2024-11-18T22:39:12.395329Z",
"iopub.status.busy": "2024-11-18T22:39:12.393444Z",
"iopub.status.idle": "2024-11-18T22:39:12.405938Z",
"shell.execute_reply": "2024-11-18T22:39:12.402980Z",
"shell.execute_reply.started": "2024-11-18T22:39:12.395236Z"
}
},
"source": [ "source": [
"## Training Datasets" "## Training Datasets"
] ]
@ -75,11 +84,11 @@
"id": "cfd97d83-8c2b-46c8-9eae-cea59f53bc61", "id": "cfd97d83-8c2b-46c8-9eae-cea59f53bc61",
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-07-16T00:27:25.874255Z", "iopub.execute_input": "2024-11-19T17:33:09.093487Z",
"iopub.status.busy": "2024-07-16T00:27:25.873473Z", "iopub.status.busy": "2024-11-19T17:33:09.092990Z",
"iopub.status.idle": "2024-07-16T00:27:25.912952Z", "iopub.status.idle": "2024-11-19T17:33:09.121636Z",
"shell.execute_reply": "2024-07-16T00:27:25.911844Z", "shell.execute_reply": "2024-11-19T17:33:09.121143Z",
"shell.execute_reply.started": "2024-07-16T00:27:25.874206Z" "shell.execute_reply.started": "2024-11-19T17:33:09.093459Z"
} }
}, },
"outputs": [], "outputs": [],
@ -93,11 +102,11 @@
"id": "d5131ae9-2efd-4758-b6e5-189a6d90789b", "id": "d5131ae9-2efd-4758-b6e5-189a6d90789b",
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-07-16T00:27:25.914456Z", "iopub.execute_input": "2024-11-19T17:33:09.122685Z",
"iopub.status.busy": "2024-07-16T00:27:25.914027Z", "iopub.status.busy": "2024-11-19T17:33:09.122270Z",
"iopub.status.idle": "2024-07-16T00:27:25.954939Z", "iopub.status.idle": "2024-11-19T17:33:09.151386Z",
"shell.execute_reply": "2024-07-16T00:27:25.953906Z", "shell.execute_reply": "2024-11-19T17:33:09.150788Z",
"shell.execute_reply.started": "2024-07-16T00:27:25.914410Z" "shell.execute_reply.started": "2024-11-19T17:33:09.122661Z"
} }
}, },
"outputs": [], "outputs": [],
@ -111,11 +120,11 @@
"id": "bc733d3d-7829-4e90-896d-a0dc76b33288", "id": "bc733d3d-7829-4e90-896d-a0dc76b33288",
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-07-16T00:27:25.956758Z", "iopub.execute_input": "2024-11-19T17:33:09.152327Z",
"iopub.status.busy": "2024-07-16T00:27:25.956260Z", "iopub.status.busy": "2024-11-19T17:33:09.152060Z",
"iopub.status.idle": "2024-07-16T00:27:25.997664Z", "iopub.status.idle": "2024-11-19T17:33:09.184041Z",
"shell.execute_reply": "2024-07-16T00:27:25.996074Z", "shell.execute_reply": "2024-11-19T17:33:09.183372Z",
"shell.execute_reply.started": "2024-07-16T00:27:25.956705Z" "shell.execute_reply.started": "2024-11-19T17:33:09.152305Z"
} }
}, },
"outputs": [], "outputs": [],
@ -129,11 +138,11 @@
"id": "dfbb94ab-7b12-4689-9c15-4dc34cd17cb2", "id": "dfbb94ab-7b12-4689-9c15-4dc34cd17cb2",
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-07-16T00:27:26.003195Z", "iopub.execute_input": "2024-11-19T17:33:09.186393Z",
"iopub.status.busy": "2024-07-16T00:27:26.002783Z", "iopub.status.busy": "2024-11-19T17:33:09.186117Z",
"iopub.status.idle": "2024-07-16T00:27:26.054400Z", "iopub.status.idle": "2024-11-19T17:33:09.220175Z",
"shell.execute_reply": "2024-07-16T00:27:26.053294Z", "shell.execute_reply": "2024-11-19T17:33:09.219322Z",
"shell.execute_reply.started": "2024-07-16T00:27:26.003158Z" "shell.execute_reply.started": "2024-11-19T17:33:09.186375Z"
} }
}, },
"outputs": [], "outputs": [],
@ -152,11 +161,11 @@
"id": "e2eedaa9-6be3-481a-8786-7618515d98f8", "id": "e2eedaa9-6be3-481a-8786-7618515d98f8",
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-07-16T00:27:26.056060Z", "iopub.execute_input": "2024-11-19T17:33:09.221653Z",
"iopub.status.busy": "2024-07-16T00:27:26.055706Z", "iopub.status.busy": "2024-11-19T17:33:09.221242Z",
"iopub.status.idle": "2024-07-16T00:27:26.103227Z", "iopub.status.idle": "2024-11-19T17:33:09.260977Z",
"shell.execute_reply": "2024-07-16T00:27:26.102190Z", "shell.execute_reply": "2024-11-19T17:33:09.260375Z",
"shell.execute_reply.started": "2024-07-16T00:27:26.056025Z" "shell.execute_reply.started": "2024-11-19T17:33:09.221616Z"
} }
}, },
"outputs": [], "outputs": [],
@ -168,7 +177,6 @@
" \"Myotis mystacinus\",\n", " \"Myotis mystacinus\",\n",
" \"Pipistrellus pipistrellus\",\n", " \"Pipistrellus pipistrellus\",\n",
" \"Rhinolophus ferrumequinum\",\n", " \"Rhinolophus ferrumequinum\",\n",
" \"social\",\n",
" ]\n", " ]\n",
"\n", "\n",
" def encode(self, x: data.SoundEventAnnotation) -> Optional[str]:\n", " def encode(self, x: data.SoundEventAnnotation) -> Optional[str]:\n",
@ -197,11 +205,11 @@
"id": "1ff6072c-511e-42fe-a74f-282f269b80f0", "id": "1ff6072c-511e-42fe-a74f-282f269b80f0",
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-07-16T00:27:26.104877Z", "iopub.execute_input": "2024-11-19T17:33:09.262337Z",
"iopub.status.busy": "2024-07-16T00:27:26.104538Z", "iopub.status.busy": "2024-11-19T17:33:09.261775Z",
"iopub.status.idle": "2024-07-16T00:27:26.159676Z", "iopub.status.idle": "2024-11-19T17:33:09.309793Z",
"shell.execute_reply": "2024-07-16T00:27:26.157914Z", "shell.execute_reply": "2024-11-19T17:33:09.309216Z",
"shell.execute_reply.started": "2024-07-16T00:27:26.104843Z" "shell.execute_reply.started": "2024-11-19T17:33:09.262307Z"
} }
}, },
"outputs": [], "outputs": [],
@ -215,11 +223,11 @@
"id": "3a763ee6-15bc-4105-a409-f06e0ad21a06", "id": "3a763ee6-15bc-4105-a409-f06e0ad21a06",
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-07-16T00:27:26.162346Z", "iopub.execute_input": "2024-11-19T17:33:09.310695Z",
"iopub.status.busy": "2024-07-16T00:27:26.161885Z", "iopub.status.busy": "2024-11-19T17:33:09.310438Z",
"iopub.status.idle": "2024-07-16T00:27:26.374668Z", "iopub.status.idle": "2024-11-19T17:33:09.366636Z",
"shell.execute_reply": "2024-07-16T00:27:26.373691Z", "shell.execute_reply": "2024-11-19T17:33:09.366059Z",
"shell.execute_reply.started": "2024-07-16T00:27:26.162305Z" "shell.execute_reply.started": "2024-11-19T17:33:09.310669Z"
} }
}, },
"outputs": [ "outputs": [
@ -229,7 +237,6 @@
"text": [ "text": [
"GPU available: False, used: False\n", "GPU available: False, used: False\n",
"TPU available: False, using: 0 TPU cores\n", "TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n" "HPU available: False, using: 0 HPUs\n"
] ]
} }
@ -248,11 +255,11 @@
"id": "0b86d49d-3314-4257-94f5-f964855be385", "id": "0b86d49d-3314-4257-94f5-f964855be385",
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-07-16T00:27:26.375918Z", "iopub.execute_input": "2024-11-19T17:33:09.367499Z",
"iopub.status.busy": "2024-07-16T00:27:26.375632Z", "iopub.status.busy": "2024-11-19T17:33:09.367242Z",
"iopub.status.idle": "2024-07-16T00:27:28.829650Z", "iopub.status.idle": "2024-11-19T17:33:10.811300Z",
"shell.execute_reply": "2024-07-16T00:27:28.828219Z", "shell.execute_reply": "2024-11-19T17:33:10.809823Z",
"shell.execute_reply.started": "2024-07-16T00:27:26.375889Z" "shell.execute_reply.started": "2024-11-19T17:33:09.367473Z"
} }
}, },
"outputs": [ "outputs": [
@ -261,37 +268,67 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"\n", "\n",
" | Name | Type | Params\n", " | Name | Type | Params | Mode \n",
"------------------------------------------------\n", "--------------------------------------------------------\n",
"0 | feature_extractor | Net2DFast | 119 K \n", "0 | feature_extractor | Net2DFast | 119 K | train\n",
"1 | classifier | Conv2d | 54 \n", "1 | classifier | Conv2d | 54 | train\n",
"2 | bbox | Conv2d | 18 \n", "2 | bbox | Conv2d | 18 | train\n",
"------------------------------------------------\n", "--------------------------------------------------------\n",
"119 K Trainable params\n", "119 K Trainable params\n",
"448 Non-trainable params\n", "448 Non-trainable params\n",
"119 K Total params\n", "119 K Total params\n",
"0.480 Total estimated model params size (MB)\n" "0.480 Total estimated model params size (MB)\n",
"32 Modules in train mode\n",
"0 Modules in eval mode\n"
] ]
}, },
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.59it/s, v_num=13]" "Epoch 0: 0%| | 0/1 [00:00<?, ?it/s]class heatmap shape torch.Size([3, 4, 128, 512])\n",
"class props shape torch.Size([3, 5, 128, 512])\n"
] ]
}, },
    {
-    "name": "stderr",
-    "output_type": "stream",
-    "text": [
-     "`Trainer.fit` stopped: `max_epochs=2` reached.\n"
-    ]
-   },
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     "Epoch 1: 100%|██████████| 1/1 [00:00<00:00, 1.54it/s, v_num=13]\n"
-    ]
+    "ename": "RuntimeError",
+    "evalue": "The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1",
+    "output_type": "error",
+    "traceback": [
+     (ANSI-escaped traceback frames: the error is raised from `trainer.fit(detector, train_dataloaders=train_dataloader)` in Cell In[10] and propagates through pytorch_lightning's `Trainer.fit`, `_fit_impl`, `_run`, `_run_stage`, and `fit_loop.run`, continuing below)
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:363\u001b[0m, in \u001b[0;36m_FitLoop.advance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_epoch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_fetcher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 363\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py:140\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.run\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone:\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 140\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_fetcher\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end(data_fetcher)\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py:250\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.advance\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_batch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 248\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mlightning_module\u001b[38;5;241m.\u001b[39mautomatic_optimization:\n\u001b[1;32m 249\u001b[0m \u001b[38;5;66;03m# in automatic optimization, there can only be one optimizer\u001b[39;00m\n\u001b[0;32m--> 250\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautomatic_optimization\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 252\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmanual_optimization\u001b[38;5;241m.\u001b[39mrun(kwargs)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:190\u001b[0m, in \u001b[0;36m_AutomaticOptimization.run\u001b[0;34m(self, optimizer, batch_idx, kwargs)\u001b[0m\n\u001b[1;32m 183\u001b[0m closure()\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;66;03m# BACKWARD PASS\u001b[39;00m\n\u001b[1;32m 187\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;66;03m# gradient update with accumulated gradients\u001b[39;00m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 190\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 192\u001b[0m result \u001b[38;5;241m=\u001b[39m closure\u001b[38;5;241m.\u001b[39mconsume_result()\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:268\u001b[0m, in \u001b[0;36m_AutomaticOptimization._optimizer_step\u001b[0;34m(self, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_ready()\n\u001b[1;32m 267\u001b[0m \u001b[38;5;66;03m# model hook\u001b[39;00m\n\u001b[0;32m--> 268\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_lightning_module_hook\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moptimizer_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_step_and_backward_closure\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m should_accumulate:\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_completed()\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:167\u001b[0m, in \u001b[0;36m_call_lightning_module_hook\u001b[0;34m(trainer, hook_name, pl_module, *args, **kwargs)\u001b[0m\n\u001b[1;32m 164\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m hook_name\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[LightningModule]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpl_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 167\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 170\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/core/module.py:1306\u001b[0m, in \u001b[0;36mLightningModule.optimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_closure)\u001b[0m\n\u001b[1;32m 1275\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21moptimizer_step\u001b[39m(\n\u001b[1;32m 1276\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 1277\u001b[0m epoch: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1280\u001b[0m optimizer_closure: Optional[Callable[[], Any]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1281\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1282\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls\u001b[39;00m\n\u001b[1;32m 1283\u001b[0m \u001b[38;5;124;03m the optimizer.\u001b[39;00m\n\u001b[1;32m 1284\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1304\u001b[0m \n\u001b[1;32m 1305\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1306\u001b[0m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimizer_closure\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/core/optimizer.py:153\u001b[0m, in \u001b[0;36mLightningOptimizer.step\u001b[0;34m(self, closure, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MisconfigurationException(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWhen `optimizer.step(closure)` is called, the closure should be callable\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 152\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_strategy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 153\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_strategy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_on_after_step()\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m step_output\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py:238\u001b[0m, in \u001b[0;36mStrategy.optimizer_step\u001b[0;34m(self, optimizer, closure, model, **kwargs)\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;66;03m# TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed\u001b[39;00m\n\u001b[1;32m 237\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, pl\u001b[38;5;241m.\u001b[39mLightningModule)\n\u001b[0;32m--> 238\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprecision_plugin\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py:122\u001b[0m, in \u001b[0;36mPrecision.optimizer_step\u001b[0;34m(self, optimizer, model, closure, **kwargs)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Hook to run the optimizer step.\"\"\"\u001b[39;00m\n\u001b[1;32m 121\u001b[0m closure \u001b[38;5;241m=\u001b[39m partial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wrap_closure, model, optimizer, closure)\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:130\u001b[0m, in \u001b[0;36mLRScheduler.__init__.<locals>.patch_track_step_called.<locals>.wrap_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 128\u001b[0m opt \u001b[38;5;241m=\u001b[39m opt_ref()\n\u001b[1;32m 129\u001b[0m opt\u001b[38;5;241m.\u001b[39m_opt_called \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m \u001b[38;5;66;03m# type: ignore[union-attr]\u001b[39;00m\n\u001b[0;32m--> 130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__get__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mopt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;18;43m__class__\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:484\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 480\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 481\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 482\u001b[0m )\n\u001b[0;32m--> 484\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 485\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m 487\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:89\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.<locals>._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 87\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefaults[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdifferentiable\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 88\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n\u001b[0;32m---> 89\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 91\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/adam.py:205\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m closure \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39menable_grad():\n\u001b[0;32m--> 205\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m group \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparam_groups:\n\u001b[1;32m 208\u001b[0m params_with_grad: List[Tensor] \u001b[38;5;241m=\u001b[39m []\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py:108\u001b[0m, in \u001b[0;36mPrecision._wrap_closure\u001b[0;34m(self, model, optimizer, closure)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrap_closure\u001b[39m(\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 97\u001b[0m model: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpl.LightningModule\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 98\u001b[0m optimizer: Steppable,\n\u001b[1;32m 99\u001b[0m closure: Callable[[], Any],\n\u001b[1;32m 100\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m 101\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``\u001b[39;00m\n\u001b[1;32m 102\u001b[0m \u001b[38;5;124;03m hook is called.\u001b[39;00m\n\u001b[1;32m 103\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 106\u001b[0m \n\u001b[1;32m 107\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 108\u001b[0m closure_result \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_after_closure(model, optimizer)\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m closure_result\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:144\u001b[0m, in \u001b[0;36mClosure.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Optional[Tensor]:\n\u001b[0;32m--> 144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result\u001b[38;5;241m.\u001b[39mloss\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:129\u001b[0m, in \u001b[0;36mClosure.closure\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241m.\u001b[39menable_grad()\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mclosure\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ClosureResult:\n\u001b[0;32m--> 129\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_step_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m step_output\u001b[38;5;241m.\u001b[39mclosure_loss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwarning_cache\u001b[38;5;241m.\u001b[39mwarn(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`training_step` returned `None`. If this was on purpose, ignore this warning...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:317\u001b[0m, in \u001b[0;36m_AutomaticOptimization._training_step\u001b[0;34m(self, kwargs)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Performs the actual train step with the tied hooks.\u001b[39;00m\n\u001b[1;32m 307\u001b[0m \n\u001b[1;32m 308\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 313\u001b[0m \n\u001b[1;32m 314\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 315\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\n\u001b[0;32m--> 317\u001b[0m training_step_output \u001b[38;5;241m=\u001b[39m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtraining_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mpost_training_step() \u001b[38;5;66;03m# unused hook - call anyway for backward compatibility\u001b[39;00m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m training_step_output \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mworld_size \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:319\u001b[0m, in \u001b[0;36m_call_strategy_hook\u001b[0;34m(trainer, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 319\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 322\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.training_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module:\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_redirection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining_step\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/modules.py:167\u001b[0m, in \u001b[0;36mDetectorModel.training_step\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtraining_step\u001b[39m(\u001b[38;5;28mself\u001b[39m, batch: TrainExample):\n\u001b[1;32m 166\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward(batch\u001b[38;5;241m.\u001b[39mspec)\n\u001b[0;32m--> 167\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/modules.py:150\u001b[0m, in \u001b[0;36mDetectorModel.compute_loss\u001b[0;34m(self, outputs, batch)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mclass props shape\u001b[39m\u001b[38;5;124m\"\u001b[39m, outputs\u001b[38;5;241m.\u001b[39mclass_probs\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 149\u001b[0m valid_mask \u001b[38;5;241m=\u001b[39m batch\u001b[38;5;241m.\u001b[39mclass_heatmap\u001b[38;5;241m.\u001b[39many(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, keepdim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[0;32m--> 150\u001b[0m classification_loss \u001b[38;5;241m=\u001b[39m \u001b[43mlosses\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal_loss\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 151\u001b[0m \u001b[43m \u001b[49m\u001b[43moutputs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_probs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 152\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_heatmap\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 153\u001b[0m \u001b[43m \u001b[49m\u001b[43mweights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalid_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalid_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclassification\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbeta\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 156\u001b[0m \u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclassification\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43malpha\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 157\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 159\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\n\u001b[1;32m 160\u001b[0m detection_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39mdetection\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 161\u001b[0m \u001b[38;5;241m+\u001b[39m size_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39msize\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 162\u001b[0m \u001b[38;5;241m+\u001b[39m classification_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39mclassification\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 163\u001b[0m )\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/losses.py:38\u001b[0m, in \u001b[0;36mfocal_loss\u001b[0;34m(pred, gt, weights, valid_mask, eps, beta, alpha)\u001b[0m\n\u001b[1;32m 35\u001b[0m pos_inds \u001b[38;5;241m=\u001b[39m gt\u001b[38;5;241m.\u001b[39meq(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[1;32m 36\u001b[0m neg_inds \u001b[38;5;241m=\u001b[39m gt\u001b[38;5;241m.\u001b[39mlt(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[0;32m---> 38\u001b[0m pos_loss \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpred\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43meps\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpow\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpred\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpos_inds\u001b[49m\n\u001b[1;32m 39\u001b[0m neg_loss \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 40\u001b[0m torch\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m pred \u001b[38;5;241m+\u001b[39m eps)\n\u001b[1;32m 41\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpow(pred, alpha)\n\u001b[1;32m 42\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpow(\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m gt, beta)\n\u001b[1;32m 43\u001b[0m \u001b[38;5;241m*\u001b[39m neg_inds\n\u001b[1;32m 44\u001b[0m )\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m weights \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1"
] ]
} }
], ],
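For context on the error above, here is a minimal, hypothetical sketch (not the repository's code) of why the line flagged at losses.py:38 raises: the elementwise product there requires pred and gt to share the class dimension, and in this run the model emits 5 class channels while the target heatmap carries 4. All shapes below are made up for illustration.

import torch

# Hypothetical shapes: 5 predicted class channels vs. 4 target channels.
pred = torch.rand(1, 5, 128, 512)
gt = torch.zeros(1, 4, 128, 512)

pos_inds = gt.eq(1).float()
try:
    # Same elementwise product, shape-wise, as the line flagged in the traceback.
    loss = torch.log(pred + 1e-5) * torch.pow(1 - pred, 2) * pos_inds
except RuntimeError as err:
    print(err)  # The size of tensor a (5) must match the size of tensor b (4) ...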
@ -301,15 +338,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": null,
"id": "2f6924db-e520-49a1-bbe8-6c4956e46314", "id": "2f6924db-e520-49a1-bbe8-6c4956e46314",
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-07-16T00:27:28.832222Z", "iopub.status.busy": "2024-11-19T17:33:10.811729Z",
"iopub.status.busy": "2024-07-16T00:27:28.831642Z", "iopub.status.idle": "2024-11-19T17:33:10.811955Z",
"iopub.status.idle": "2024-07-16T00:27:29.000595Z", "shell.execute_reply": "2024-11-19T17:33:10.811858Z",
"shell.execute_reply": "2024-07-16T00:27:28.998078Z", "shell.execute_reply.started": "2024-11-19T17:33:10.811849Z"
"shell.execute_reply.started": "2024-07-16T00:27:28.832157Z"
} }
}, },
"outputs": [], "outputs": [],
@ -319,44 +355,54 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": null,
"id": "23943e13-6875-49b8-9f18-2ba6528aa673", "id": "23943e13-6875-49b8-9f18-2ba6528aa673",
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-07-16T00:27:29.004279Z", "iopub.status.busy": "2024-11-19T17:33:10.812924Z",
"iopub.status.busy": "2024-07-16T00:27:29.003486Z", "iopub.status.idle": "2024-11-19T17:33:10.813260Z",
"iopub.status.idle": "2024-07-16T00:27:29.595626Z", "shell.execute_reply": "2024-11-19T17:33:10.813104Z",
"shell.execute_reply": "2024-07-16T00:27:29.594734Z", "shell.execute_reply.started": "2024-11-19T17:33:10.813087Z"
"shell.execute_reply.started": "2024-07-16T00:27:29.004200Z"
} }
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"predictions = detector.compute_clip_predictions(clip_annotation.clip)" "spec = detector.compute_spectrogram(clip_annotation.clip)\n",
"outputs = detector(torch.tensor(spec.values).unsqueeze(0).unsqueeze(0))"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": null,
"id": "dd1fe346-0873-4b14-ae1b-92ef1f4f27a5",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.814343Z",
"iopub.status.idle": "2024-11-19T17:33:10.814806Z",
"shell.execute_reply": "2024-11-19T17:33:10.814628Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.814611Z"
}
},
"outputs": [],
"source": [
"_, ax= plt.subplots(figsize=(15, 5))\n",
"spec.plot(ax=ax, add_colorbar=False)\n",
"ax.pcolormesh(spec.time, spec.frequency, outputs.detection_probs.detach().squeeze())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eadd36ef-a04a-4665-b703-cec84cf1673b", "id": "eadd36ef-a04a-4665-b703-cec84cf1673b",
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-07-16T00:28:47.178783Z", "iopub.status.busy": "2024-11-19T17:33:10.815603Z",
"iopub.status.busy": "2024-07-16T00:28:47.178143Z", "iopub.status.idle": "2024-11-19T17:33:10.816065Z",
"iopub.status.idle": "2024-07-16T00:28:47.246613Z", "shell.execute_reply": "2024-11-19T17:33:10.815894Z",
"shell.execute_reply": "2024-07-16T00:28:47.245496Z", "shell.execute_reply.started": "2024-11-19T17:33:10.815877Z"
"shell.execute_reply.started": "2024-07-16T00:28:47.178729Z"
} }
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"Num predicted soundevents: 50\n"
]
}
],
"source": [ "source": [
"print(f\"Num predicted soundevents: {len(predictions.sound_events)}\")" "print(f\"Num predicted soundevents: {len(predictions.sound_events)}\")"
] ]
@ -364,7 +410,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "d3883c04-d91a-4d1d-b677-196c0179dde1", "id": "e4e54f3e-6ddc-4fe5-8ce0-b527ff6f18ae",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": []
@ -386,7 +432,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.18" "version": "3.12.5"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -1,103 +1,97 @@
[tool]
rye = { dev-dependencies = [
"ipykernel>=6.29.4",
"setuptools>=69.5.1",
"pytest>=8.1.1",
] }
[tool.pdm]
[tool.pdm.dev-dependencies]
dev = [
"pytest>=7.2.2",
]
[project] [project]
name = "batdetect2" name = "batdetect2"
version = "1.0.8" version = "1.1.1"
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings." description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
authors = [ authors = [
{ "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" }, { "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
{ "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" } { "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" },
] ]
dependencies = [ dependencies = [
"librosa>=0.10.1", "click>=8.1.7",
"matplotlib>=3.7.1", "librosa>=0.10.1",
"numpy>=1.23.5", "matplotlib>=3.7.1",
"pandas>=1.5.3", "numpy>=1.23.5",
"scikit-learn>=1.2.2", "pandas>=1.5.3",
"scipy>=1.10.1", "scikit-learn>=1.2.2",
"torch>=1.13.1", "scipy>=1.10.1",
"torchaudio", "torch>=1.13.1,<2.5.0",
"torchvision", "torchaudio>=1.13.1,<2.5.0",
"soundevent[audio,geometry,plot]>=2.0.1", "torchvision>=0.14.0",
"click>=8.1.7", "soundevent[audio,geometry,plot]>=2.3",
"netcdf4>=1.6.5", "click>=8.1.7",
"tqdm>=4.66.2", "netcdf4>=1.6.5",
"pytorch-lightning>=2.2.2", "tqdm>=4.66.2",
"cf-xarray>=0.9.0", "pytorch-lightning>=2.2.2",
"onnx>=1.16.0", "cf-xarray>=0.9.0",
"lightning[extra]>=2.2.2", "onnx>=1.16.0",
"tensorboard>=2.16.2", "lightning[extra]>=2.2.2",
"tensorboard>=2.16.2",
"omegaconf>=2.3.0",
"pyyaml>=6.0.2",
"hydra-core>=1.3.2",
"numba>=0.60",
] ]
requires-python = ">=3.9" requires-python = ">=3.9,<3.13"
readme = "README.md" readme = "README.md"
license = { text = "CC-by-nc-4" } license = { text = "CC-by-nc-4" }
classifiers = [ classifiers = [
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"Intended Audience :: Science/Research", "Intended Audience :: Science/Research",
"Natural Language :: English", "Natural Language :: English",
"Operating System :: OS Independent", "Operating System :: OS Independent",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10",
"Topic :: Scientific/Engineering :: Artificial Intelligence", "Programming Language :: Python :: 3.11",
"Topic :: Software Development :: Libraries :: Python Modules", "Programming Language :: Python :: 3.12",
"Topic :: Multimedia :: Sound/Audio :: Analysis", "Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Multimedia :: Sound/Audio :: Analysis",
] ]
keywords = [ keywords = [
"bat", "bat",
"echolocation", "echolocation",
"deep learning", "deep learning",
"audio", "audio",
"machine learning", "machine learning",
"classification", "classification",
"detection", "detection",
] ]
[build-system] [build-system]
requires = ["pdm-pep517>=1.0.0"] requires = ["hatchling"]
build-backend = "pdm.pep517.api" build-backend = "hatchling.build"
[project.scripts] [project.scripts]
batdetect2 = "batdetect2.cli:cli" batdetect2 = "batdetect2.cli:cli"
[tool.black] [tool.uv]
line-length = 79 dev-dependencies = [
"debugpy>=1.8.8",
[tool.isort] "hypothesis>=6.118.7",
profile = "black" "pytest>=7.2.2",
line_length = 79 "ruff>=0.7.3",
"ipykernel>=6.29.4",
"setuptools>=69.5.1",
"basedpyright>=1.28.4",
]
[tool.ruff] [tool.ruff]
line-length = 79 line-length = 79
target-version = "py39"
[[tool.mypy.overrides]] [tool.ruff.format]
module = [ docstring-code-format = true
"librosa", docstring-code-line-length = 79
"pandas",
]
ignore_missing_imports = true
[tool.pylsp-mypy] [tool.ruff.lint]
enabled = false select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
live_mode = true
strict = true
[tool.pydocstyle] [tool.ruff.lint.pydocstyle]
convention = "numpy" convention = "numpy"
[tool.pyright] [tool.pyright]
include = [ include = ["batdetect2", "tests"]
"bat_detect",
"tests",
]
venvPath = "." venvPath = "."
venv = ".venv" venv = ".venv"
pythonVersion = "3.9"
pythonPlatform = "All"

View File

@ -16,7 +16,6 @@ import batdetect2.train.train_utils as tu
import batdetect2.utils.audio_utils as au import batdetect2.utils.audio_utils as au
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"audio_path", type=str, help="Input directory for audio" "audio_path", type=str, help="Input directory for audio"
@ -65,7 +64,9 @@ if __name__ == "__main__":
else: else:
# load uk data - special case # load uk data - special case
print("\nLoading:", args["uk_split"], "\n") print("\nLoading:", args["uk_split"], "\n")
dataset_name = "uk_" + args["uk_split"] # should be uk_diff, or uk_same dataset_name = (
"uk_" + args["uk_split"]
) # should be uk_diff, or uk_same
datasets, _ = ts.get_train_test_data( datasets, _ = ts.get_train_test_data(
args["ann_file"], args["ann_file"],
args["audio_path"], args["audio_path"],

View File

@ -33,7 +33,6 @@ def filter_anns(anns, start_time, stop_time):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("audio_file", type=str, help="Path to audio file") parser.add_argument("audio_file", type=str, help="Path to audio file")
parser.add_argument("model_path", type=str, help="Path to BatDetect model") parser.add_argument("model_path", type=str, help="Path to BatDetect model")
@ -143,7 +142,9 @@ if __name__ == "__main__":
# run model and filter detections so only keep ones in relevant time range # run model and filter detections so only keep ones in relevant time range
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
results = du.process_file(args_cmd["audio_file"], model, run_config, device) results = du.process_file(
args_cmd["audio_file"], model, run_config, device
)
pred_anns = filter_anns( pred_anns = filter_anns(
results["pred_dict"]["annotation"], results["pred_dict"]["annotation"],
args_cmd["start_time"], args_cmd["start_time"],

View File

@ -25,7 +25,9 @@ import batdetect2.utils.plot_utils as viz
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("audio_file", type=str, help="Path to input audio file") parser.add_argument(
"audio_file", type=str, help="Path to input audio file"
)
parser.add_argument( parser.add_argument(
"model_path", type=str, help="Path to trained BatDetect model" "model_path", type=str, help="Path to trained BatDetect model"
) )

View File

@ -198,7 +198,6 @@ def save_summary_image(
) )
ii = 0 ii = 0
for row in ax: for row in ax:
if type(row) != np.ndarray: if type(row) != np.ndarray:
row = np.array([row]) row = np.array([row])
@ -215,7 +214,9 @@ def save_summary_image(
) )
col.grid(color="w", alpha=0.3, linewidth=0.3) col.grid(color="w", alpha=0.3, linewidth=0.3)
col.set_xticks([]) col.set_xticks([])
col.title.set_text(str(ii + 1) + " " + species_names[order[ii]]) col.title.set_text(
str(ii + 1) + " " + species_names[order[ii]]
)
col.tick_params(axis="both", which="major", labelsize=7) col.tick_params(axis="both", which="major", labelsize=7)
ii += 1 ii += 1

tests/conftest.py (new file, 109 lines)
View File

@ -0,0 +1,109 @@
import uuid
from pathlib import Path
from typing import Callable, List, Optional
import numpy as np
import pytest
import soundfile as sf
from soundevent import data
@pytest.fixture
def example_data_dir() -> Path:
pkg_dir = Path(__file__).parent.parent
example_data_dir = pkg_dir / "example_data"
assert example_data_dir.exists()
return example_data_dir
@pytest.fixture
def example_audio_dir(example_data_dir: Path) -> Path:
example_audio_dir = example_data_dir / "audio"
assert example_audio_dir.exists()
return example_audio_dir
@pytest.fixture
def example_anns_dir(example_data_dir: Path) -> Path:
example_anns_dir = example_data_dir / "anns"
assert example_anns_dir.exists()
return example_anns_dir
@pytest.fixture
def example_audio_files(example_audio_dir: Path) -> List[Path]:
audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]"))
assert len(audio_files) == 3
return audio_files
@pytest.fixture
def data_dir() -> Path:
dir = Path(__file__).parent / "data"
assert dir.exists()
return dir
@pytest.fixture
def contrib_dir(data_dir) -> Path:
dir = data_dir / "contrib"
assert dir.exists()
return dir
@pytest.fixture
def wav_factory(tmp_path: Path):
def _wav_factory(
path: Optional[Path] = None,
duration: float = 0.3,
channels: int = 1,
samplerate: int = 441_000,
bit_depth: int = 16,
) -> Path:
path = path or tmp_path / f"{uuid.uuid4()}.wav"
frames = int(samplerate * duration)
shape = (frames, channels)
subtype = f"PCM_{bit_depth}"
if bit_depth == 16:
dtype = np.int16
elif bit_depth == 32:
dtype = np.int32
else:
raise ValueError(f"Unsupported bit depth: {bit_depth}")
wav = np.random.uniform(
low=np.iinfo(dtype).min,
high=np.iinfo(dtype).max,
size=shape,
).astype(dtype)
sf.write(str(path), wav, samplerate, subtype=subtype)
return path
return _wav_factory
@pytest.fixture
def recording_factory(wav_factory: Callable[..., Path]):
def _recording_factory(
tags: Optional[list[data.Tag]] = None,
path: Optional[Path] = None,
recording_id: Optional[uuid.UUID] = None,
duration: float = 1,
channels: int = 1,
samplerate: int = 256_000,
time_expansion: float = 1,
) -> data.Recording:
path = path or wav_factory(
duration=duration,
channels=channels,
samplerate=samplerate,
)
return data.Recording.from_file(
path=path,
uuid=recording_id or uuid.uuid4(),
time_expansion=time_expansion,
tags=tags or [],
)
return _recording_factory
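A short, hypothetical usage sketch (not part of this diff) showing how a test can compose these fixtures: recording_factory writes a temporary WAV through wav_factory and wraps it in a soundevent Recording.

# Hypothetical test module, e.g. tests/test_fixture_usage.py (illustration only).
from soundevent import data


def test_recording_factory_builds_recording(recording_factory):
    # The factory writes a temporary WAV file and loads it as a Recording.
    recording = recording_factory(duration=0.5, samplerate=256_000)
    assert isinstance(recording, data.Recording)
    assert recording.samplerate == 256_000
    assert recording.duration > 0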

7 binary files changed (contents not shown).

View File

@ -1,14 +1,13 @@
"""Test bat detect module API.""" """Test bat detect module API."""
from pathlib import Path
import os import os
from glob import glob from glob import glob
from pathlib import Path
import numpy as np import numpy as np
import soundfile as sf
import torch import torch
from torch import nn from torch import nn
import soundfile as sf
from batdetect2 import api from batdetect2 import api
@ -267,7 +266,6 @@ def test_process_file_with_spec_slices():
assert len(results["spec_slices"]) == len(detections) assert len(results["spec_slices"]) == len(detections)
def test_process_file_with_empty_predictions_does_not_fail( def test_process_file_with_empty_predictions_does_not_fail(
tmp_path: Path, tmp_path: Path,
): ):

Some files were not shown because too many files have changed in this diff.