Merge pull request #4 from macaodha/dev

Merging the Python API
This commit is contained in:
Oisin Mac Aodha 2023-04-07 15:06:10 +01:00 committed by GitHub
commit 7c80441d60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
43 changed files with 8014 additions and 2321 deletions

4
.git-blame-ignore-revs Normal file
View File

@ -0,0 +1,4 @@
# Format code with Black and isort
3c17a2337166245de8df778fe174aad997e14e8f
9cb6b20949c7c31ee21ed2b800e8b691f1be32a7
53100f51e083cf4d900ed325ae0543cc754a26cc

104
.gitignore vendored
View File

@ -1,10 +1,108 @@
*.pyc # Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# IPython
profile_default/
ipython_config.py
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Rope project settings
.ropeproject/
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Model artifacts
*.png *.png
*.jpg *.jpg
*.wav *.wav
*.tar *.tar
*.json *.json
*.ipynb_checkpoints/
experiments/*
plots/* plots/*
# Batdetect Models [Include]
!bat_detect/models/*.pth.tar
# Model experiments
experiments/*
# Jupiter notebooks
.virtual_documents
.ipynb_checkpoints
*.ipynb
!batdetect2_notebook.ipynb

5
.pylintrc Normal file
View File

@ -0,0 +1,5 @@
[TYPECHECK]
# List of members which are set dynamically and missed by Pylint inference
# system, and so shouldn't trigger E1101 when accessed.
generated-members=torch.*

129
README.md
View File

@ -1,61 +1,122 @@
# BatDetect2 # BatDetect2
<img align="left" width="64" height="64" src="ims/bat_icon.png"> <img style="display: block-inline;" width="64" height="64" src="ims/bat_icon.png"> Code for detecting and classifying bat echolocation calls in high frequency audio recordings.
Code for detecting and classifying bat echolocation calls in high frequency audio recordings. ## Getting started
### Python Environment
We recommend using an isolated Python environment to avoid dependency issues. Choose one
of the following options:
* Install the Anaconda Python 3.10 distribution for your operating system from [here](https://www.continuum.io/downloads). Create a new environment and activate it:
```bash
conda create -y --name batdetect2 python==3.10
conda activate batdetect2
```
* If you already have Python installed (version >= 3.8,< 3.11) and prefer using virtual environments then:
```bash
python -m venv .venv
source .venv/bin/activate
```
### Installing BatDetect2
You can use pip to install `batdetect2`:
```bash
pip install batdetect2
```
Alternatively, download this code from the repository (by clicking on the green button on top right) and unzip it.
Once unziped, run this from extracted folder.
```bash
pip install .
```
Make sure you have the environment activated before installing `batdetect2`.
### Getting started ## Try the model
1) Install the Anaconda Python 3.10 distribution for your operating system from [here](https://www.continuum.io/downloads). 1) You can try a demo of the model (for UK species) on [huggingface](https://huggingface.co/spaces/macaodha/batdetect2).
2) Download this code from the repository (by clicking on the green button on top right) and unzip it.
3) Create a new environment and install the required packages: 2) Alternatively, click [here](https://colab.research.google.com/github/macaodha/batdetect2/blob/master/batdetect2_notebook.ipynb) to run the model using Google Colab. You can also run this notebook locally.
`conda env create -f environment.yml`
`conda activate batdetect2`
### Try the model ## Running the model on your own data
1) You can try a demo of the model (for UK species) on [huggingface](https://huggingface.co/spaces/macaodha/batdetect2).
2) Alternatively, click [here](https://colab.research.google.com/github/macaodha/batdetect2/blob/master/batdetect2_notebook.ipynb) to run the model using Google Colab. You can also run this notebook locally. After following the above steps to install the code you can run the model on your own data.
### Running the model on your own data ### Using the command line
After following the above steps to install the code you can run the model on your own data by opening the command line where the code is located and typing:
`python run_batdetect.py AUDIO_DIR ANN_DIR DETECTION_THRESHOLD` You can run the model by opening the command line and typing:
e.g. ```bash
`python run_batdetect.py example_data/audio/ example_data/anns/ 0.3` batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD
```
e.g.
```bash
batdetect2 detect example_data/audio/ example_data/anns/ 0.3
```
`AUDIO_DIR` is the path on your computer to the audio wav files of interest.
`ANN_DIR` is the path on your computer where the model predictions will be saved. The model will output both `.csv` and `.json` results for each audio file.
`DETECTION_THRESHOLD` is a number between 0 and 1 specifying the cut-off threshold applied to the calls. A smaller number will result in more calls detected, but with the chance of introducing more mistakes.
There are also optional arguments, e.g. you can request that the model outputs features (i.e. estimated call parameters) such as duration, max_frequency, etc. by setting the flag `--spec_features`. These will be saved as `*_spec_features.csv` files:
`batdetect2 detect example_data/audio/ example_data/anns/ 0.3 --spec_features`
You can also specify which model to use by setting the `--model_path` argument. If not specified, it will default to using a model trained on UK data e.g.
`batdetect2 detect example_data/audio/ example_data/anns/ 0.3 --model_path models/Net2DFast_UK_same.pth.tar`
`AUDIO_DIR` is the path on your computer to the audio wav files of interest. ### Using the Python API
`ANN_DIR` is the path on your computer where the model predictions will be saved. The model will output both `.csv` and `.json` results for each audio file.
`DETECTION_THRESHOLD` is a number between 0 and 1 specifying the cut-off threshold applied to the calls. A smaller number will result in more calls detected, but with the chance of introducing more mistakes.
There are also optional arguments, e.g. you can request that the model outputs features (i.e. estimated call parameters) such as duration, max_frequency, etc. by setting the flag `--spec_features`. These will be saved as `*_spec_features.csv` files: If you prefer to process your data within a Python script then you can use the `batdetect2` Python API.
`python run_batdetect.py example_data/audio/ example_data/anns/ 0.3 --spec_features`
You can also specify which model to use by setting the `--model_path` argument. If not specified, it will default to using a model trained on UK data e.g. ```python
`python run_batdetect.py example_data/audio/ example_data/anns/ 0.3 --model_path models/Net2DFast_UK_same.pth.tar` from batdetect2 import api
AUDIO_FILE = "example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav"
# Process a whole file
results = api.process_file(AUDIO_FILE)
# Or, load audio and compute spectrograms
audio = api.load_audio(AUDIO_FILE)
spec = api.generate_spectrogram(audio)
# And process the audio or the spectrogram with the model
detections, features, spec = api.process_audio(audio)
detections, features = api.process_spectrogram(spec)
# Do something else ...
```
You can integrate the detections or the extracted features to your custom analysis pipeline.
### Training the model on your own data ## Training the model on your own data
Take a look at the steps outlined in fintuning readme [here](bat_detect/finetune/readme.md) for a description of how to train your own model. Take a look at the steps outlined in fintuning readme [here](bat_detect/finetune/readme.md) for a description of how to train your own model.
### Data and annotations ## Data and annotations
The raw audio data and annotations used to train the models in the paper will be added soon. The raw audio data and annotations used to train the models in the paper will be added soon.
The audio interface used to annotate audio data for training and evaluation is available [here](https://github.com/macaodha/batdetect2_GUI). The audio interface used to annotate audio data for training and evaluation is available [here](https://github.com/macaodha/batdetect2_GUI).
### Warning ## Warning
The models developed and shared as part of this repository should be used with caution. The models developed and shared as part of this repository should be used with caution.
While they have been evaluated on held out audio data, great care should be taken when using the model outputs for any form of biodiversity assessment. While they have been evaluated on held out audio data, great care should be taken when using the model outputs for any form of biodiversity assessment.
Your data may differ, and as a result it is very strongly recommended that you validate the model first using data with known species to ensure that the outputs can be trusted. Your data may differ, and as a result it is very strongly recommended that you validate the model first using data with known species to ensure that the outputs can be trusted.
### FAQ ## FAQ
For more information please consult our [FAQ](faq.md). For more information please consult our [FAQ](faq.md).
### Reference ## Reference
If you find our work useful in your research please consider citing our paper which you can find [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1): If you find our work useful in your research please consider citing our paper which you can find [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1):
``` ```
@article{batdetect2_2022, @article{batdetect2_2022,
@ -66,8 +127,8 @@ If you find our work useful in your research please consider citing our paper wh
} }
``` ```
### Acknowledgements ## Acknowledgements
Thanks to all the contributors who spent time collecting and annotating audio data. Thanks to all the contributors who spent time collecting and annotating audio data.
### TODOs ### TODOs

170
app.py
View File

@ -1,84 +1,126 @@
import gradio as gr import gradio as gr
import os
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pandas as pd
import numpy as np import numpy as np
import pandas as pd
import bat_detect.utils.detector_utils as du
import bat_detect.utils.audio_utils as au import bat_detect.utils.audio_utils as au
import bat_detect.utils.detector_utils as du
import bat_detect.utils.plot_utils as viz import bat_detect.utils.plot_utils as viz
# setup the arguments # setup the arguments
args = {} args = {}
args = du.get_default_bd_args() args = du.get_default_run_config()
args['detection_threshold'] = 0.3 args["detection_threshold"] = 0.3
args['time_expansion_factor'] = 1 args["time_expansion_factor"] = 1
args['model_path'] = 'models/Net2DFast_UK_same.pth.tar' args["model_path"] = "models/Net2DFast_UK_same.pth.tar"
max_duration = 2.0 max_duration = 2.0
# load the model # load the model
model, params = du.load_model(args['model_path']) model, params = du.load_model(args["model_path"])
df = gr.Dataframe( df = gr.Dataframe(
headers=["species", "time", "detection_prob", "species_prob"], headers=["species", "time", "detection_prob", "species_prob"],
datatype=["str", "str", "str", "str"], datatype=["str", "str", "str", "str"],
row_count=1, row_count=1,
col_count=(4, "fixed"), col_count=(4, "fixed"),
label='Predictions' label="Predictions",
) )
examples = [['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3], examples = [
['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3], ["example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav", 0.3],
['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3]] ["example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav", 0.3],
["example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav", 0.3],
]
def make_prediction(file_name=None, detection_threshold=0.3): def make_prediction(file_name=None, detection_threshold=0.3):
if file_name is not None: if file_name is not None:
audio_file = file_name audio_file = file_name
else: else:
return "You must provide an input audio file." return "You must provide an input audio file."
if detection_threshold is not None and detection_threshold != '': if detection_threshold is not None and detection_threshold != "":
args['detection_threshold'] = float(detection_threshold) args["detection_threshold"] = float(detection_threshold)
run_config = {
**params,
**args,
"max_duration": max_duration,
}
# process the file to generate predictions # process the file to generate predictions
results = du.process_file(audio_file, model, params, args, max_duration=max_duration) results = du.process_file(
audio_file,
anns = [ann for ann in results['pred_dict']['annotation']] model,
clss = [aa['class'] for aa in anns] run_config,
st_time = [aa['start_time'] for aa in anns] )
cls_prob = [aa['class_prob'] for aa in anns]
det_prob = [aa['det_prob'] for aa in anns] anns = [ann for ann in results["pred_dict"]["annotation"]]
data = {'species': clss, 'time': st_time, 'detection_prob': det_prob, 'species_prob': cls_prob} clss = [aa["class"] for aa in anns]
st_time = [aa["start_time"] for aa in anns]
cls_prob = [aa["class_prob"] for aa in anns]
det_prob = [aa["det_prob"] for aa in anns]
data = {
"species": clss,
"time": st_time,
"detection_prob": det_prob,
"species_prob": cls_prob,
}
df = pd.DataFrame(data=data) df = pd.DataFrame(data=data)
im = generate_results_image(audio_file, anns) im = generate_results_image(audio_file, anns)
return [df, im] return [df, im]
def generate_results_image(audio_file, anns): def generate_results_image(audio_file, anns):
# load audio # load audio
sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'], sampling_rate, audio = au.load_audio(
params['target_samp_rate'], params['scale_raw_audio'], max_duration=max_duration) audio_file,
args["time_expansion_factor"],
params["target_samp_rate"],
params["scale_raw_audio"],
max_duration=max_duration,
)
duration = audio.shape[0] / sampling_rate duration = audio.shape[0] / sampling_rate
# generate spec # generate spec
spec, spec_viz = au.generate_spectrogram(audio, sampling_rate, params, True, False) spec, spec_viz = au.generate_spectrogram(
audio, sampling_rate, params, True, False
)
# create fig # create fig
plt.close('all') plt.close("all")
fig = plt.figure(1, figsize=(spec.shape[1]/100, spec.shape[0]/100), dpi=100, frameon=False) fig = plt.figure(
spec_duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap']) 1,
viz.create_box_image(spec, fig, anns, 0, spec_duration, spec_duration, params, spec.max()*1.1, False, True) figsize=(spec.shape[1] / 100, spec.shape[0] / 100),
plt.ylabel('Freq - kHz') dpi=100,
plt.xlabel('Time - secs') frameon=False,
)
spec_duration = au.x_coords_to_time(
spec.shape[1],
sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
)
viz.create_box_image(
spec,
fig,
anns,
0,
spec_duration,
spec_duration,
params,
spec.max() * 1.1,
False,
True,
)
plt.ylabel("Freq - kHz")
plt.xlabel("Time - secs")
plt.tight_layout() plt.tight_layout()
# convert fig to image # convert fig to image
fig.canvas.draw() fig.canvas.draw()
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
@ -88,21 +130,23 @@ def generate_results_image(audio_file, anns):
return im return im
descr_txt = "Demo of BatDetect2 deep learning-based bat echolocation call detection. " \ descr_txt = (
"<br>This model is only trained on bat species from the UK. If the input " \ "Demo of BatDetect2 deep learning-based bat echolocation call detection. "
"file is longer than 2 seconds, only the first 2 seconds will be processed." \ "<br>This model is only trained on bat species from the UK. If the input "
"<br>Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)." "file is longer than 2 seconds, only the first 2 seconds will be processed."
"<br>Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."
)
gr.Interface( gr.Interface(
fn = make_prediction, fn=make_prediction,
inputs = [gr.Audio(source="upload", type="filepath", optional=True), inputs=[
gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])], gr.Audio(source="upload", type="filepath", optional=True),
outputs = [df, gr.Image(label="Visualisation")], gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
theme = "huggingface", ],
title = "BatDetect2 Demo", outputs=[df, gr.Image(label="Visualisation")],
description = descr_txt, theme="huggingface",
examples = examples, title="BatDetect2 Demo",
allow_flagging = 'never', description=descr_txt,
examples=examples,
allow_flagging="never",
).launch() ).launch()

397
bat_detect/api.py Normal file
View File

@ -0,0 +1,397 @@
"""Python API for bat_detect.
This module provides a Python API for bat_detect. It can be used to
process audio files or spectrograms with the default model or a custom
model.
Example
-------
You can use the default model to process audio files. To process a single
file, use the `process_file` function.
>>> import bat_detect.api as api
>>> # Process audio file
>>> results = api.process_file("audio_file.wav")
To process multiple files, use the `list_audio_files` function to get a list
of audio files in a directory. Then use the `process_file` function to
process each file.
>>> import bat_detect.api as api
>>> # Get list of audio files
>>> audio_files = api.list_audio_files("audio_directory")
>>> # Process audio files
>>> results = [api.process_file(f) for f in audio_files]
The `process_file` function will slice the recording into 3 second chunks
and process each chunk separately, in case the recording is longer. The
results will be combined into a dictionary with the following keys:
- `pred_dict`: All the predictions from the model in the format
expected by the annotation tool.
- `cnn_feats`: Optional. A list of `numpy` arrays containing the CNN features
for each detection. The CNN features are the output of the CNN before
the final classification layer. You can use these features to train
your own classifier, or to do other processing on the detections.
They are in the same order as the detections in
`results['pred_dict']['annotation']`. Will only be returned if the
`cnn_feats` parameter in the config is set to `True`.
- `spec_slices`: Optional. A list of `numpy` arrays containing the spectrogram
for each of the processed chunks. Will only be returned if the
`spec_slices` parameter in the config is set to `True`.
Alternatively, you can use the `process_audio` function to process an audio
array directly, or `process_spectrogram` to process spectrograms. This
allows you to do other preprocessing steps before running the model for
predictions.
>>> import bat_detect.api as api
>>> # Load audio
>>> audio = api.load_audio("audio_file.wav")
>>> # Process the audio array
>>> detections, features, spec = api.process_audio(audio)
>>> # Or compute and process the spectrogram
>>> spec = api.generate_spectrogram(audio)
>>> detections, features = api.process_spectrogram(spec)
Here `detections` is the list of detected calls, `features` is the list of
CNN features for each detection, and `spec` is the spectrogram of the
processed audio. Each detection is a dictionary similary to the
following:
{
'start_time': 0.0,
'end_time': 0.1,
'low_freq': 10000,
'high_freq': 20000,
'class': 'Myotis myotis',
'class_prob': 0.9,
'det_prob': 0.9,
'individual': 0,
'event': 'Echolocation'
}
If you wish to interact directly with the model, you can use the `model`
attribute to get the default model.
>>> import bat_detect.api as api
>>> # Get the default model
>>> model = api.model
>>> # Process the spectrogram
>>> outputs = model(spec)
However, you will need to do the postprocessing yourself. The
model outputs are a collection of raw tensors. The `postprocess`
function can be used to convert the model outputs into a list of
detections and a list of CNN features.
>>> import bat_detect.api as api
>>> # Get the default model
>>> model = api.model
>>> # Process the spectrogram
>>> outputs = model(spec)
>>> # Postprocess the outputs
>>> detections, features = api.postprocess(outputs)
If you wish to use a custom model or change the default parameters, please
consult the API documentation in the code.
"""
import warnings
from typing import List, Optional, Tuple
import numpy as np
import torch
import bat_detect.utils.audio_utils as au
import bat_detect.utils.detector_utils as du
from bat_detect.detector.parameters import (
DEFAULT_MODEL_PATH,
DEFAULT_PROCESSING_CONFIGURATIONS,
DEFAULT_SPECTROGRAM_PARAMETERS,
TARGET_SAMPLERATE_HZ,
)
from bat_detect.types import (
Annotation,
DetectionModel,
ModelOutput,
ProcessingConfiguration,
SpectrogramParameters,
)
from bat_detect.utils.detector_utils import list_audio_files, load_model
# Remove warnings from torch
warnings.filterwarnings("ignore", category=UserWarning, module="torch")
__all__ = [
"config",
"generate_spectrogram",
"get_config",
"list_audio_files",
"load_audio",
"load_model",
"model",
"postprocess",
"process_audio",
"process_file",
"process_spectrogram",
]
# Use GPU if available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Default model
MODEL, PARAMS = load_model(DEFAULT_MODEL_PATH, device=DEVICE)
def get_config(**kwargs) -> ProcessingConfiguration:
"""Get default processing configuration.
Can be used to override default parameters by passing keyword arguments.
"""
return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs} # type: ignore
# Default processing configuration
CONFIG = get_config(**PARAMS)
def load_audio(
path: str,
time_exp_fact: float = 1,
target_samp_rate: int = TARGET_SAMPLERATE_HZ,
scale: bool = False,
max_duration: Optional[float] = None,
) -> np.ndarray:
"""Load audio from file.
All audio will be resampled to the target sample rate. If the audio is
longer than max_duration, it will be truncated to max_duration.
Parameters
----------
path : str
Path to audio file.
time_exp_fact : float, optional
Time expansion factor, by default 1
target_samp_rate : int, optional
Target sample rate, by default 256000
scale : bool, optional
Scale audio to [-1, 1], by default False
max_duration : float, optional
Maximum duration of audio in seconds, by default None
Returns
-------
np.ndarray
Audio data.
"""
_, audio = au.load_audio(
path,
time_exp_fact,
target_samp_rate,
scale,
max_duration,
)
return audio
def generate_spectrogram(
audio: np.ndarray,
samp_rate: int = TARGET_SAMPLERATE_HZ,
config: Optional[SpectrogramParameters] = None,
device: torch.device = DEVICE,
) -> torch.Tensor:
"""Generate spectrogram from audio array.
Parameters
----------
audio : np.ndarray
Audio data.
samp_rate : int, optional
Sample rate. Defaults to 256000 which is the target sample rate of
the default model. Only change if you loaded the audio with a
different sample rate.
config : SpectrogramParameters, optional
Spectrogram parameters, by default None (uses default parameters).
Returns
-------
torch.Tensor
Spectrogram.
"""
if config is None:
config = DEFAULT_SPECTROGRAM_PARAMETERS
_, spec, _ = du.compute_spectrogram(
audio,
samp_rate,
config,
return_np=False,
device=device,
)
return spec
def process_file(
audio_file: str,
model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None,
device: torch.device = DEVICE,
) -> du.RunResults:
"""Process audio file with model.
Parameters
----------
audio_file : str
Path to audio file.
model : DetectionModel, optional
Detection model. Uses default model if not specified.
config : Optional[ProcessingConfiguration], optional
Processing configuration, by default None (uses default parameters).
device : torch.device, optional
Device to use, by default tries to use GPU if available.
"""
if config is None:
config = CONFIG
return du.process_file(
audio_file,
model,
config,
device,
)
def process_spectrogram(
spec: torch.Tensor,
samp_rate: int = TARGET_SAMPLERATE_HZ,
model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None,
) -> Tuple[List[Annotation], List[np.ndarray]]:
"""Process spectrogram with model.
Parameters
----------
spec : torch.Tensor
Spectrogram.
samp_rate : int, optional
Sample rate of the audio from which the spectrogram was generated.
Defaults to 256000 which is the target sample rate of the default
model. Only change if you generated the spectrogram with a different
sample rate.
model : DetectionModel, optional
Detection model. Uses default model if not specified.
config : Optional[ProcessingConfiguration], optional
Processing configuration, by default None (uses default parameters).
Returns
-------
DetectionResult
"""
if config is None:
config = CONFIG
return du.process_spectrogram(
spec,
samp_rate,
model,
config,
)
def process_audio(
audio: np.ndarray,
samp_rate: int = TARGET_SAMPLERATE_HZ,
model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None,
device: torch.device = DEVICE,
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
"""Process audio array with model.
Parameters
----------
audio : np.ndarray
Audio data.
samp_rate : int, optional
Sample rate, by default 256000. Only change if you loaded the audio
with a different sample rate.
model : DetectionModel, optional
Detection model. Uses default model if not specified.
config : Optional[ProcessingConfiguration], optional
Processing configuration, by default None (uses default parameters).
device : torch.device, optional
Device to use, by default tries to use GPU if available.
Returns
-------
annotations : List[Annotation]
List of predicted annotations.
features: List[np.ndarray]
List of extracted features for each annotation.
spec : torch.Tensor
Spectrogram of the audio used for prediction.
"""
if config is None:
config = CONFIG
return du.process_audio_array(
audio,
samp_rate,
model,
config,
device,
)
def postprocess(
outputs: ModelOutput,
samp_rate: int = TARGET_SAMPLERATE_HZ,
config: Optional[ProcessingConfiguration] = None,
) -> Tuple[List[Annotation], np.ndarray]:
"""Postprocess model outputs.
Convert model tensor outputs to predicted bounding boxes and
extracted features.
Will run non-maximum suppression and remove overlapping annotations.
Parameters
----------
outputs : ModelOutput
Model raw outputs.
samp_rate : int, Optional
Sample rate of the audio from which the spectrogram was generated.
Defaults to 256000 which is the target sample rate of the default
model. Only change if you generated outputs from a spectrogram with
sample rate.
config : Optional[ProcessingConfiguration], Optional
Processing configuration, by default None (uses default parameters).
Returns
-------
annotations : List[Annotation]
List of predicted annotations.
features: np.ndarray
An array of extracted features for each annotation. The shape of the
array is (n_annotations, n_features).
"""
if config is None:
config = CONFIG
return du.postprocess_model_outputs(
outputs,
samp_rate,
config,
)
model: DetectionModel = MODEL
"""Base detection model."""
config: ProcessingConfiguration = CONFIG
"""Default processing configuration."""

137
bat_detect/cli.py Normal file
View File

@ -0,0 +1,137 @@
"""BatDetect2 command line interface."""
import os
import click
from bat_detect import api
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH
from bat_detect.utils.detector_utils import save_results_to_file
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
INFO_STR = """
BatDetect2 - Detection and Classification
Assumes audio files are mono, not stereo.
Spaces in the input paths will throw an error. Wrap in quotes.
Input files should be short in duration e.g. < 30 seconds.
"""
@click.group()
def cli():
"""BatDetect2 - Bat Call Detection and Classification."""
click.echo(INFO_STR)
@cli.command()
@click.argument(
"audio_dir",
type=click.Path(exists=True),
)
@click.argument(
"ann_dir",
type=click.Path(exists=False),
)
@click.argument(
"detection_threshold",
type=float,
)
@click.option(
"--cnn_features",
is_flag=True,
default=False,
help="Extracts CNN call features",
)
@click.option(
"--spec_features",
is_flag=True,
default=False,
help="Extracts low level call features",
)
@click.option(
"--time_expansion_factor",
type=int,
default=1,
help="The time expansion factor used for all files (default is 1)",
)
@click.option(
"--quiet",
is_flag=True,
default=False,
help="Minimize output printing",
)
@click.option(
"--save_preds_if_empty",
is_flag=True,
default=False,
help="Save empty annotation file if no detections made.",
)
@click.option(
"--model_path",
type=str,
default=DEFAULT_MODEL_PATH,
help="Path to trained BatDetect2 model",
)
def detect(
audio_dir: str,
ann_dir: str,
detection_threshold: float,
**args,
):
"""Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.
DETECTION_THRESHOLD is the detection threshold. All predictions with a
score below this threshold will be discarded. Values between 0 and 1.
Assumes audio files are mono, not stereo.
Spaces in the input paths will throw an error. Wrap in quotes.
Input files should be short in duration e.g. < 30 seconds.
"""
click.echo(f"Loading model: {args['model_path']}")
model, params = api.load_model(args["model_path"])
click.echo(f"\nInput directory: {audio_dir}")
files = api.list_audio_files(audio_dir)
click.echo(f"Number of audio files: {len(files)}")
click.echo(f"\nSaving results to: {ann_dir}")
config = api.get_config(
**{
**params,
**args,
"spec_slices": False,
"chunk_size": 2,
"detection_threshold": detection_threshold,
}
)
# process files
error_files = []
for audio_file in files:
try:
results = api.process_file(audio_file, model, config=config)
if args["save_preds_if_empty"] or (
len(results["pred_dict"]["annotation"]) > 0
):
results_path = audio_file.replace(audio_dir, ann_dir)
save_results_to_file(results, results_path)
except (RuntimeError, ValueError, LookupError) as err:
error_files.append(audio_file)
click.secho(f"Error processing file!: {err}", fg="red")
raise err
click.echo(f"\nResults saved to: {ann_dir}")
if len(error_files) > 0:
click.secho("\nUnable to process the follow files:", fg="red")
for err in error_files:
click.echo(f" {err}")
if __name__ == "__main__":
cli()

View File

@ -2,8 +2,10 @@ import numpy as np
def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq): def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq):
spec_ind = spec_height-spec_ind spec_ind = spec_height - spec_ind
return round((spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2) return round(
(spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
)
def extract_spec_slices(spec, pred_nms, params): def extract_spec_slices(spec, pred_nms, params):
@ -11,28 +13,40 @@ def extract_spec_slices(spec, pred_nms, params):
Extracts spectrogram slices from spectrogram based on detected call locations. Extracts spectrogram slices from spectrogram based on detected call locations.
""" """
x_pos = pred_nms['x_pos'] x_pos = pred_nms["x_pos"]
y_pos = pred_nms['y_pos'] y_pos = pred_nms["y_pos"]
bb_width = pred_nms['bb_width'] bb_width = pred_nms["bb_width"]
bb_height = pred_nms['bb_height'] bb_height = pred_nms["bb_height"]
slices = [] slices = []
# add 20% padding either side of call # add 20% padding either side of call
pad = bb_width*0.2 pad = bb_width * 0.2
x_pos_pad = x_pos - pad x_pos_pad = x_pos - pad
bb_width_pad = bb_width + 2*pad bb_width_pad = bb_width + 2 * pad
for ff in range(len(pred_nms['det_probs'])): for ff in range(len(pred_nms["det_probs"])):
x_start = int(np.maximum(0, x_pos_pad[ff])) x_start = int(np.maximum(0, x_pos_pad[ff]))
x_end = int(np.minimum(spec.shape[1]-1, np.round(x_pos_pad[ff] + bb_width_pad[ff]))) x_end = int(
np.minimum(
spec.shape[1] - 1, np.round(x_pos_pad[ff] + bb_width_pad[ff])
)
)
slices.append(spec[:, x_start:x_end].astype(np.float16)) slices.append(spec[:, x_start:x_end].astype(np.float16))
return slices return slices
def get_feature_names(): def get_feature_names():
feature_names = ['duration', 'low_freq_bb', 'high_freq_bb', 'bandwidth', feature_names = [
'max_power_bb', 'max_power', 'max_power_first', "duration",
'max_power_second', 'call_interval'] "low_freq_bb",
"high_freq_bb",
"bandwidth",
"max_power_bb",
"max_power",
"max_power_first",
"max_power_second",
"call_interval",
]
return feature_names return feature_names
@ -45,40 +59,76 @@ def get_feats(spec, pred_nms, params):
https://github.com/YvesBas/Tadarida-D/blob/master/Manual_Tadarida-D.odt https://github.com/YvesBas/Tadarida-D/blob/master/Manual_Tadarida-D.odt
""" """
x_pos = pred_nms['x_pos'] x_pos = pred_nms["x_pos"]
y_pos = pred_nms['y_pos'] y_pos = pred_nms["y_pos"]
bb_width = pred_nms['bb_width'] bb_width = pred_nms["bb_width"]
bb_height = pred_nms['bb_height'] bb_height = pred_nms["bb_height"]
feature_names = get_feature_names() feature_names = get_feature_names()
num_detections = len(pred_nms['det_probs']) num_detections = len(pred_nms["det_probs"])
features = np.ones((num_detections, len(feature_names)), dtype=np.float32)*-1 features = (
np.ones((num_detections, len(feature_names)), dtype=np.float32) * -1
)
for ff in range(num_detections): for ff in range(num_detections):
x_start = int(np.maximum(0, x_pos[ff])) x_start = int(np.maximum(0, x_pos[ff]))
x_end = int(np.minimum(spec.shape[1]-1, np.round(x_pos[ff] + bb_width[ff]))) x_end = int(
np.minimum(spec.shape[1] - 1, np.round(x_pos[ff] + bb_width[ff]))
)
# y low is the lowest freq but it will have a higher value due to array starting at 0 at top # y low is the lowest freq but it will have a higher value due to array starting at 0 at top
y_low = int(np.minimum(spec.shape[0]-1, y_pos[ff])) y_low = int(np.minimum(spec.shape[0] - 1, y_pos[ff]))
y_high = int(np.maximum(0, np.round(y_pos[ff] - bb_height[ff]))) y_high = int(np.maximum(0, np.round(y_pos[ff] - bb_height[ff])))
spec_slice = spec[:, x_start:x_end] spec_slice = spec[:, x_start:x_end]
if spec_slice.shape[1] > 1: if spec_slice.shape[1] > 1:
features[ff, 0] = round(pred_nms['end_times'][ff] - pred_nms['start_times'][ff], 5) features[ff, 0] = round(
features[ff, 1] = int(pred_nms['low_freqs'][ff]) pred_nms["end_times"][ff] - pred_nms["start_times"][ff], 5
features[ff, 2] = int(pred_nms['high_freqs'][ff]) )
features[ff, 3] = int(pred_nms['high_freqs'][ff] - pred_nms['low_freqs'][ff]) features[ff, 1] = int(pred_nms["low_freqs"][ff])
features[ff, 4] = int(convert_int_to_freq(y_high+spec_slice[y_high:y_low, :].sum(1).argmax(), features[ff, 2] = int(pred_nms["high_freqs"][ff])
spec.shape[0], params['min_freq'], params['max_freq'])) features[ff, 3] = int(
features[ff, 5] = int(convert_int_to_freq(spec_slice.sum(1).argmax(), pred_nms["high_freqs"][ff] - pred_nms["low_freqs"][ff]
spec.shape[0], params['min_freq'], params['max_freq'])) )
hlf_val = spec_slice.shape[1]//2 features[ff, 4] = int(
convert_int_to_freq(
y_high + spec_slice[y_high:y_low, :].sum(1).argmax(),
spec.shape[0],
params["min_freq"],
params["max_freq"],
)
)
features[ff, 5] = int(
convert_int_to_freq(
spec_slice.sum(1).argmax(),
spec.shape[0],
params["min_freq"],
params["max_freq"],
)
)
hlf_val = spec_slice.shape[1] // 2
features[ff, 6] = int(convert_int_to_freq(spec_slice[:, :hlf_val].sum(1).argmax(), features[ff, 6] = int(
spec.shape[0], params['min_freq'], params['max_freq'])) convert_int_to_freq(
features[ff, 7] = int(convert_int_to_freq(spec_slice[:, hlf_val:].sum(1).argmax(), spec_slice[:, :hlf_val].sum(1).argmax(),
spec.shape[0], params['min_freq'], params['max_freq'])) spec.shape[0],
params["min_freq"],
params["max_freq"],
)
)
features[ff, 7] = int(
convert_int_to_freq(
spec_slice[:, hlf_val:].sum(1).argmax(),
spec.shape[0],
params["min_freq"],
params["max_freq"],
)
)
if ff > 0: if ff > 0:
features[ff, 8] = round(pred_nms['start_times'][ff] - pred_nms['start_times'][ff-1], 5) features[ff, 8] = round(
pred_nms["start_times"][ff]
- pred_nms["start_times"][ff - 1],
5,
)
return features return features

View File

@ -1,8 +1,14 @@
import torch.nn as nn
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np from torch import nn
import math
__all__ = [
"SelfAttention",
"ConvBlockDownCoordF",
"ConvBlockDownStandard",
"ConvBlockUpF",
"ConvBlockUpStandard",
]
class SelfAttention(nn.Module): class SelfAttention(nn.Module):
@ -10,38 +16,61 @@ class SelfAttention(nn.Module):
super(SelfAttention, self).__init__() super(SelfAttention, self).__init__()
# Note, does not encode position information (absolute or realtive) # Note, does not encode position information (absolute or realtive)
self.temperature = 1.0 self.temperature = 1.0
self.att_dim = att_dim self.att_dim = att_dim
self.key_fun = nn.Linear(ip_dim, att_dim) self.key_fun = nn.Linear(ip_dim, att_dim)
self.val_fun = nn.Linear(ip_dim, att_dim) self.val_fun = nn.Linear(ip_dim, att_dim)
self.que_fun = nn.Linear(ip_dim, att_dim) self.que_fun = nn.Linear(ip_dim, att_dim)
self.pro_fun = nn.Linear(att_dim, ip_dim) self.pro_fun = nn.Linear(att_dim, ip_dim)
def forward(self, x): def forward(self, x):
x = x.squeeze(2).permute(0,2,1) x = x.squeeze(2).permute(0, 2, 1)
kk = torch.matmul(x, self.key_fun.weight.T) + self.key_fun.bias.unsqueeze(0).unsqueeze(0) kk = torch.matmul(
qq = torch.matmul(x, self.que_fun.weight.T) + self.que_fun.bias.unsqueeze(0).unsqueeze(0) x, self.key_fun.weight.T
vv = torch.matmul(x, self.val_fun.weight.T) + self.val_fun.bias.unsqueeze(0).unsqueeze(0) ) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
qq = torch.matmul(
x, self.que_fun.weight.T
) + self.que_fun.bias.unsqueeze(0).unsqueeze(0)
vv = torch.matmul(
x, self.val_fun.weight.T
) + self.val_fun.bias.unsqueeze(0).unsqueeze(0)
kk_qq = torch.bmm(kk, qq.permute(0,2,1)) / (self.temperature*self.att_dim) kk_qq = torch.bmm(kk, qq.permute(0, 2, 1)) / (
att_weights = F.softmax(kk_qq, 1) # each col of each attention matrix sums to 1 self.temperature * self.att_dim
att = torch.bmm(vv.permute(0,2,1), att_weights) )
att_weights = F.softmax(
kk_qq, 1
) # each col of each attention matrix sums to 1
att = torch.bmm(vv.permute(0, 2, 1), att_weights)
op = torch.matmul(att.permute(0,2,1), self.pro_fun.weight.T) + self.pro_fun.bias.unsqueeze(0).unsqueeze(0) op = torch.matmul(
op = op.permute(0,2,1).unsqueeze(2) att.permute(0, 2, 1), self.pro_fun.weight.T
) + self.pro_fun.bias.unsqueeze(0).unsqueeze(0)
op = op.permute(0, 2, 1).unsqueeze(2)
return op return op
class ConvBlockDownCoordF(nn.Module): class ConvBlockDownCoordF(nn.Module):
def __init__(self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, stride=1): def __init__(
self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, stride=1
):
super(ConvBlockDownCoordF, self).__init__() super(ConvBlockDownCoordF, self).__init__()
self.coords = nn.Parameter(torch.linspace(-1, 1, ip_height)[None, None, ..., None], requires_grad=False) self.coords = nn.Parameter(
self.conv = nn.Conv2d(in_chn+1, out_chn, kernel_size=k_size, padding=pad_size, stride=stride) torch.linspace(-1, 1, ip_height)[None, None, ..., None],
requires_grad=False,
)
self.conv = nn.Conv2d(
in_chn + 1,
out_chn,
kernel_size=k_size,
padding=pad_size,
stride=stride,
)
self.conv_bn = nn.BatchNorm2d(out_chn) self.conv_bn = nn.BatchNorm2d(out_chn)
def forward(self, x): def forward(self, x):
freq_info = self.coords.repeat(x.shape[0],1,1,x.shape[3]) freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
x = torch.cat((x, freq_info), 1) x = torch.cat((x, freq_info), 1)
x = F.max_pool2d(self.conv(x), 2, 2) x = F.max_pool2d(self.conv(x), 2, 2)
x = F.relu(self.conv_bn(x), inplace=True) x = F.relu(self.conv_bn(x), inplace=True)
@ -49,9 +78,17 @@ class ConvBlockDownCoordF(nn.Module):
class ConvBlockDownStandard(nn.Module): class ConvBlockDownStandard(nn.Module):
def __init__(self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, stride=1): def __init__(
self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, stride=1
):
super(ConvBlockDownStandard, self).__init__() super(ConvBlockDownStandard, self).__init__()
self.conv = nn.Conv2d(in_chn, out_chn, kernel_size=k_size, padding=pad_size, stride=stride) self.conv = nn.Conv2d(
in_chn,
out_chn,
kernel_size=k_size,
padding=pad_size,
stride=stride,
)
self.conv_bn = nn.BatchNorm2d(out_chn) self.conv_bn = nn.BatchNorm2d(out_chn)
def forward(self, x): def forward(self, x):
@ -61,17 +98,41 @@ class ConvBlockDownStandard(nn.Module):
class ConvBlockUpF(nn.Module): class ConvBlockUpF(nn.Module):
def __init__(self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, up_mode='bilinear', up_scale=(2,2)): def __init__(
self,
in_chn,
out_chn,
ip_height,
k_size=3,
pad_size=1,
up_mode="bilinear",
up_scale=(2, 2),
):
super(ConvBlockUpF, self).__init__() super(ConvBlockUpF, self).__init__()
self.up_scale = up_scale self.up_scale = up_scale
self.up_mode = up_mode self.up_mode = up_mode
self.coords = nn.Parameter(torch.linspace(-1, 1, ip_height*up_scale[0])[None, None, ..., None], requires_grad=False) self.coords = nn.Parameter(
self.conv = nn.Conv2d(in_chn+1, out_chn, kernel_size=k_size, padding=pad_size) torch.linspace(-1, 1, ip_height * up_scale[0])[
None, None, ..., None
],
requires_grad=False,
)
self.conv = nn.Conv2d(
in_chn + 1, out_chn, kernel_size=k_size, padding=pad_size
)
self.conv_bn = nn.BatchNorm2d(out_chn) self.conv_bn = nn.BatchNorm2d(out_chn)
def forward(self, x): def forward(self, x):
op = F.interpolate(x, size=(x.shape[-2]*self.up_scale[0], x.shape[-1]*self.up_scale[1]), mode=self.up_mode, align_corners=False) op = F.interpolate(
freq_info = self.coords.repeat(op.shape[0],1,1,op.shape[3]) x,
size=(
x.shape[-2] * self.up_scale[0],
x.shape[-1] * self.up_scale[1],
),
mode=self.up_mode,
align_corners=False,
)
freq_info = self.coords.repeat(op.shape[0], 1, 1, op.shape[3])
op = torch.cat((op, freq_info), 1) op = torch.cat((op, freq_info), 1)
op = self.conv(op) op = self.conv(op)
op = F.relu(self.conv_bn(op), inplace=True) op = F.relu(self.conv_bn(op), inplace=True)
@ -79,15 +140,34 @@ class ConvBlockUpF(nn.Module):
class ConvBlockUpStandard(nn.Module): class ConvBlockUpStandard(nn.Module):
def __init__(self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, up_mode='bilinear', up_scale=(2,2)): def __init__(
self,
in_chn,
out_chn,
ip_height=None,
k_size=3,
pad_size=1,
up_mode="bilinear",
up_scale=(2, 2),
):
super(ConvBlockUpStandard, self).__init__() super(ConvBlockUpStandard, self).__init__()
self.up_scale = up_scale self.up_scale = up_scale
self.up_mode = up_mode self.up_mode = up_mode
self.conv = nn.Conv2d(in_chn, out_chn, kernel_size=k_size, padding=pad_size) self.conv = nn.Conv2d(
in_chn, out_chn, kernel_size=k_size, padding=pad_size
)
self.conv_bn = nn.BatchNorm2d(out_chn) self.conv_bn = nn.BatchNorm2d(out_chn)
def forward(self, x): def forward(self, x):
op = F.interpolate(x, size=(x.shape[-2]*self.up_scale[0], x.shape[-1]*self.up_scale[1]), mode=self.up_mode, align_corners=False) op = F.interpolate(
x,
size=(
x.shape[-2] * self.up_scale[0],
x.shape[-1] * self.up_scale[1],
),
mode=self.up_mode,
align_corners=False,
)
op = self.conv(op) op = self.conv(op)
op = F.relu(self.conv_bn(op), inplace=True) op = F.relu(self.conv_bn(op), inplace=True)
return op return op

View File

@ -1,54 +1,109 @@
import torch.nn as nn
import torch import torch
import torch.nn.functional as F
import numpy as np
from .model_helpers import *
import torchvision
import torch.fft import torch.fft
import torch.nn.functional as F
from torch import nn from torch import nn
from bat_detect.detector.model_helpers import (
ConvBlockDownCoordF,
ConvBlockDownStandard,
ConvBlockUpF,
ConvBlockUpStandard,
SelfAttention,
)
from bat_detect.types import ModelOutput
__all__ = [
"Net2DFast",
"Net2DFastNoAttn",
"Net2DFastNoCoordConv",
]
class Net2DFast(nn.Module): class Net2DFast(nn.Module):
def __init__(self, num_filts, num_classes=0, emb_dim=0, ip_height=128, resize_factor=0.5): def __init__(
super(Net2DFast, self).__init__() self,
num_filts,
num_classes=0,
emb_dim=0,
ip_height=128,
resize_factor=0.5,
):
super().__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.emb_dim = emb_dim self.emb_dim = emb_dim
self.num_filts = num_filts self.num_filts = num_filts
self.resize_factor = resize_factor self.resize_factor = resize_factor
self.ip_height_rs = ip_height self.ip_height_rs = ip_height
self.bneck_height = self.ip_height_rs//32 self.bneck_height = self.ip_height_rs // 32
# encoder # encoder
self.conv_dn_0 = ConvBlockDownCoordF(1, num_filts//4, self.ip_height_rs, k_size=3, pad_size=1, stride=1) self.conv_dn_0 = ConvBlockDownCoordF(
self.conv_dn_1 = ConvBlockDownCoordF(num_filts//4, num_filts//2, self.ip_height_rs//2, k_size=3, pad_size=1, stride=1) 1,
self.conv_dn_2 = ConvBlockDownCoordF(num_filts//2, num_filts, self.ip_height_rs//4, k_size=3, pad_size=1, stride=1) num_filts // 4,
self.conv_dn_3 = nn.Conv2d(num_filts, num_filts*2, 3, padding=1) self.ip_height_rs,
self.conv_dn_3_bn = nn.BatchNorm2d(num_filts*2) k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_1 = ConvBlockDownCoordF(
num_filts // 4,
num_filts // 2,
self.ip_height_rs // 2,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_2 = ConvBlockDownCoordF(
num_filts // 2,
num_filts,
self.ip_height_rs // 4,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1)
self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2)
# bottleneck # bottleneck
self.conv_1d = nn.Conv2d(num_filts*2, num_filts*2, (self.ip_height_rs//8,1), padding=0) self.conv_1d = nn.Conv2d(
self.conv_1d_bn = nn.BatchNorm2d(num_filts*2) num_filts * 2,
self.att = SelfAttention(num_filts*2, num_filts*2) num_filts * 2,
(self.ip_height_rs // 8, 1),
padding=0,
)
self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2)
self.att = SelfAttention(num_filts * 2, num_filts * 2)
# decoder # decoder
self.conv_up_2 = ConvBlockUpF(num_filts*2, num_filts//2, self.ip_height_rs//8) self.conv_up_2 = ConvBlockUpF(
self.conv_up_3 = ConvBlockUpF(num_filts//2, num_filts//4, self.ip_height_rs//4) num_filts * 2, num_filts // 2, self.ip_height_rs // 8
self.conv_up_4 = ConvBlockUpF(num_filts//4, num_filts//4, self.ip_height_rs//2) )
self.conv_up_3 = ConvBlockUpF(
num_filts // 2, num_filts // 4, self.ip_height_rs // 4
)
self.conv_up_4 = ConvBlockUpF(
num_filts // 4, num_filts // 4, self.ip_height_rs // 2
)
# output # output
# +1 to include background class for class output # +1 to include background class for class output
self.conv_op = nn.Conv2d(num_filts//4, num_filts//4, kernel_size=3, padding=1) self.conv_op = nn.Conv2d(
self.conv_op_bn = nn.BatchNorm2d(num_filts//4) num_filts // 4, num_filts // 4, kernel_size=3, padding=1
self.conv_size_op = nn.Conv2d(num_filts//4, 2, kernel_size=1, padding=0) )
self.conv_classes_op = nn.Conv2d(num_filts//4, self.num_classes+1, kernel_size=1, padding=0) self.conv_op_bn = nn.BatchNorm2d(num_filts // 4)
self.conv_size_op = nn.Conv2d(
num_filts // 4, 2, kernel_size=1, padding=0
)
self.conv_classes_op = nn.Conv2d(
num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
)
if self.emb_dim > 0: if self.emb_dim > 0:
self.conv_emb = nn.Conv2d(num_filts, self.emb_dim, kernel_size=1, padding=0) self.conv_emb = nn.Conv2d(
num_filts, self.emb_dim, kernel_size=1, padding=0
)
def forward(self, ip, return_feats=False) -> ModelOutput:
def forward(self, ip, return_feats=False):
# encoder # encoder
x1 = self.conv_dn_0(ip) x1 = self.conv_dn_0(ip)
@ -59,134 +114,218 @@ class Net2DFast(nn.Module):
# bottleneck # bottleneck
x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True) x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
x = self.att(x) x = self.att(x)
x = x.repeat([1,1,self.bneck_height*4,1]) x = x.repeat([1, 1, self.bneck_height * 4, 1])
# decoder # decoder
x = self.conv_up_2(x+x3) x = self.conv_up_2(x + x3)
x = self.conv_up_3(x+x2) x = self.conv_up_3(x + x2)
x = self.conv_up_4(x+x1) x = self.conv_up_4(x + x1)
# output # output
x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True) x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
cls = self.conv_classes_op(x) cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1) comb = torch.softmax(cls, 1)
op = {} return ModelOutput(
op['pred_det'] = comb[:,:-1, :, :].sum(1).unsqueeze(1) pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True) pred_size=F.relu(self.conv_size_op(x), inplace=True),
op['pred_class'] = comb pred_class=comb,
op['pred_class_un_norm'] = cls pred_class_un_norm=cls,
if self.emb_dim > 0: features=x,
op['pred_emb'] = self.conv_emb(x) )
if return_feats:
op['features'] = x
return op
class Net2DFastNoAttn(nn.Module): class Net2DFastNoAttn(nn.Module):
def __init__(self, num_filts, num_classes=0, emb_dim=0, ip_height=128, resize_factor=0.5): def __init__(
super(Net2DFastNoAttn, self).__init__() self,
num_filts,
num_classes=0,
emb_dim=0,
ip_height=128,
resize_factor=0.5,
):
super().__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.emb_dim = emb_dim self.emb_dim = emb_dim
self.num_filts = num_filts self.num_filts = num_filts
self.resize_factor = resize_factor self.resize_factor = resize_factor
self.ip_height_rs = ip_height self.ip_height_rs = ip_height
self.bneck_height = self.ip_height_rs//32 self.bneck_height = self.ip_height_rs // 32
self.conv_dn_0 = ConvBlockDownCoordF(1, num_filts//4, self.ip_height_rs, k_size=3, pad_size=1, stride=1) self.conv_dn_0 = ConvBlockDownCoordF(
self.conv_dn_1 = ConvBlockDownCoordF(num_filts//4, num_filts//2, self.ip_height_rs//2, k_size=3, pad_size=1, stride=1) 1,
self.conv_dn_2 = ConvBlockDownCoordF(num_filts//2, num_filts, self.ip_height_rs//4, k_size=3, pad_size=1, stride=1) num_filts // 4,
self.conv_dn_3 = nn.Conv2d(num_filts, num_filts*2, 3, padding=1) self.ip_height_rs,
self.conv_dn_3_bn = nn.BatchNorm2d(num_filts*2) k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_1 = ConvBlockDownCoordF(
num_filts // 4,
num_filts // 2,
self.ip_height_rs // 2,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_2 = ConvBlockDownCoordF(
num_filts // 2,
num_filts,
self.ip_height_rs // 4,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1)
self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2)
self.conv_1d = nn.Conv2d(num_filts*2, num_filts*2, (self.ip_height_rs//8,1), padding=0) self.conv_1d = nn.Conv2d(
self.conv_1d_bn = nn.BatchNorm2d(num_filts*2) num_filts * 2,
num_filts * 2,
(self.ip_height_rs // 8, 1),
padding=0,
)
self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2)
self.conv_up_2 = ConvBlockUpF(
self.conv_up_2 = ConvBlockUpF(num_filts*2, num_filts//2, self.ip_height_rs//8) num_filts * 2, num_filts // 2, self.ip_height_rs // 8
self.conv_up_3 = ConvBlockUpF(num_filts//2, num_filts//4, self.ip_height_rs//4) )
self.conv_up_4 = ConvBlockUpF(num_filts//4, num_filts//4, self.ip_height_rs//2) self.conv_up_3 = ConvBlockUpF(
num_filts // 2, num_filts // 4, self.ip_height_rs // 4
)
self.conv_up_4 = ConvBlockUpF(
num_filts // 4, num_filts // 4, self.ip_height_rs // 2
)
# output # output
# +1 to include background class for class output # +1 to include background class for class output
self.conv_op = nn.Conv2d(num_filts//4, num_filts//4, kernel_size=3, padding=1) self.conv_op = nn.Conv2d(
self.conv_op_bn = nn.BatchNorm2d(num_filts//4) num_filts // 4, num_filts // 4, kernel_size=3, padding=1
self.conv_size_op = nn.Conv2d(num_filts//4, 2, kernel_size=1, padding=0) )
self.conv_classes_op = nn.Conv2d(num_filts//4, self.num_classes+1, kernel_size=1, padding=0) self.conv_op_bn = nn.BatchNorm2d(num_filts // 4)
self.conv_size_op = nn.Conv2d(
num_filts // 4, 2, kernel_size=1, padding=0
)
self.conv_classes_op = nn.Conv2d(
num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
)
if self.emb_dim > 0: if self.emb_dim > 0:
self.conv_emb = nn.Conv2d(num_filts, self.emb_dim, kernel_size=1, padding=0) self.conv_emb = nn.Conv2d(
num_filts, self.emb_dim, kernel_size=1, padding=0
def forward(self, ip, return_feats=False): )
def forward(self, ip, return_feats=False) -> ModelOutput:
x1 = self.conv_dn_0(ip) x1 = self.conv_dn_0(ip)
x2 = self.conv_dn_1(x1) x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2) x3 = self.conv_dn_2(x2)
x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True) x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True) x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
x = x.repeat([1,1,self.bneck_height*4,1]) x = x.repeat([1, 1, self.bneck_height * 4, 1])
x = self.conv_up_2(x+x3) x = self.conv_up_2(x + x3)
x = self.conv_up_3(x+x2) x = self.conv_up_3(x + x2)
x = self.conv_up_4(x+x1) x = self.conv_up_4(x + x1)
x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True) x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
cls = self.conv_classes_op(x) cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1) comb = torch.softmax(cls, 1)
op = {} return ModelOutput(
op['pred_det'] = comb[:,:-1, :, :].sum(1).unsqueeze(1) pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True) pred_size=F.relu(self.conv_size_op(x), inplace=True),
op['pred_class'] = comb pred_class=comb,
op['pred_class_un_norm'] = cls pred_class_un_norm=cls,
if self.emb_dim > 0: features=x,
op['pred_emb'] = self.conv_emb(x) )
if return_feats:
op['features'] = x
return op
class Net2DFastNoCoordConv(nn.Module): class Net2DFastNoCoordConv(nn.Module):
def __init__(self, num_filts, num_classes=0, emb_dim=0, ip_height=128, resize_factor=0.5): def __init__(
super(Net2DFastNoCoordConv, self).__init__() self,
num_filts,
num_classes=0,
emb_dim=0,
ip_height=128,
resize_factor=0.5,
):
super().__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.emb_dim = emb_dim self.emb_dim = emb_dim
self.num_filts = num_filts self.num_filts = num_filts
self.resize_factor = resize_factor self.resize_factor = resize_factor
self.ip_height_rs = ip_height self.ip_height_rs = ip_height
self.bneck_height = self.ip_height_rs//32 self.bneck_height = self.ip_height_rs // 32
self.conv_dn_0 = ConvBlockDownStandard(1, num_filts//4, self.ip_height_rs, k_size=3, pad_size=1, stride=1) self.conv_dn_0 = ConvBlockDownStandard(
self.conv_dn_1 = ConvBlockDownStandard(num_filts//4, num_filts//2, self.ip_height_rs//2, k_size=3, pad_size=1, stride=1) 1,
self.conv_dn_2 = ConvBlockDownStandard(num_filts//2, num_filts, self.ip_height_rs//4, k_size=3, pad_size=1, stride=1) num_filts // 4,
self.conv_dn_3 = nn.Conv2d(num_filts, num_filts*2, 3, padding=1) self.ip_height_rs,
self.conv_dn_3_bn = nn.BatchNorm2d(num_filts*2) k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_1 = ConvBlockDownStandard(
num_filts // 4,
num_filts // 2,
self.ip_height_rs // 2,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_2 = ConvBlockDownStandard(
num_filts // 2,
num_filts,
self.ip_height_rs // 4,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1)
self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2)
self.conv_1d = nn.Conv2d(num_filts*2, num_filts*2, (self.ip_height_rs//8,1), padding=0) self.conv_1d = nn.Conv2d(
self.conv_1d_bn = nn.BatchNorm2d(num_filts*2) num_filts * 2,
num_filts * 2,
(self.ip_height_rs // 8, 1),
padding=0,
)
self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2)
self.att = SelfAttention(num_filts*2, num_filts*2) self.att = SelfAttention(num_filts * 2, num_filts * 2)
self.conv_up_2 = ConvBlockUpStandard(num_filts*2, num_filts//2, self.ip_height_rs//8) self.conv_up_2 = ConvBlockUpStandard(
self.conv_up_3 = ConvBlockUpStandard(num_filts//2, num_filts//4, self.ip_height_rs//4) num_filts * 2, num_filts // 2, self.ip_height_rs // 8
self.conv_up_4 = ConvBlockUpStandard(num_filts//4, num_filts//4, self.ip_height_rs//2) )
self.conv_up_3 = ConvBlockUpStandard(
num_filts // 2, num_filts // 4, self.ip_height_rs // 4
)
self.conv_up_4 = ConvBlockUpStandard(
num_filts // 4, num_filts // 4, self.ip_height_rs // 2
)
# output # output
# +1 to include background class for class output # +1 to include background class for class output
self.conv_op = nn.Conv2d(num_filts//4, num_filts//4, kernel_size=3, padding=1) self.conv_op = nn.Conv2d(
self.conv_op_bn = nn.BatchNorm2d(num_filts//4) num_filts // 4, num_filts // 4, kernel_size=3, padding=1
self.conv_size_op = nn.Conv2d(num_filts//4, 2, kernel_size=1, padding=0) )
self.conv_classes_op = nn.Conv2d(num_filts//4, self.num_classes+1, kernel_size=1, padding=0) self.conv_op_bn = nn.BatchNorm2d(num_filts // 4)
self.conv_size_op = nn.Conv2d(
num_filts // 4, 2, kernel_size=1, padding=0
)
self.conv_classes_op = nn.Conv2d(
num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
)
if self.emb_dim > 0: if self.emb_dim > 0:
self.conv_emb = nn.Conv2d(num_filts, self.emb_dim, kernel_size=1, padding=0) self.conv_emb = nn.Conv2d(
num_filts, self.emb_dim, kernel_size=1, padding=0
)
def forward(self, ip, return_feats=False): def forward(self, ip, return_feats=False) -> ModelOutput:
x1 = self.conv_dn_0(ip) x1 = self.conv_dn_0(ip)
x2 = self.conv_dn_1(x1) x2 = self.conv_dn_1(x1)
@ -195,24 +334,21 @@ class Net2DFastNoCoordConv(nn.Module):
x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True) x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
x = self.att(x) x = self.att(x)
x = x.repeat([1,1,self.bneck_height*4,1]) x = x.repeat([1, 1, self.bneck_height * 4, 1])
x = self.conv_up_2(x+x3) x = self.conv_up_2(x + x3)
x = self.conv_up_3(x+x2) x = self.conv_up_3(x + x2)
x = self.conv_up_4(x+x1) x = self.conv_up_4(x + x1)
x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True) x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
cls = self.conv_classes_op(x) cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1) comb = torch.softmax(cls, 1)
op = {} return ModelOutput(
op['pred_det'] = comb[:,:-1, :, :].sum(1).unsqueeze(1) pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True) pred_size=F.relu(self.conv_size_op(x), inplace=True),
op['pred_class'] = comb pred_class=comb,
op['pred_class_un_norm'] = cls pred_class_un_norm=cls,
if self.emb_dim > 0: pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
op['pred_emb'] = self.conv_emb(x) features=x,
if return_feats: )
op['features'] = x
return op

View File

@ -1,108 +1,235 @@
import numpy as np
import os
import datetime import datetime
import os
from bat_detect.types import (
ProcessingConfiguration,
SpectrogramParameters,
)
TARGET_SAMPLERATE_HZ = 256000
FFT_WIN_LENGTH_S = 512 / 256000.0
FFT_OVERLAP = 0.75
MAX_FREQ_HZ = 120000
MIN_FREQ_HZ = 10000
RESIZE_FACTOR = 0.5
SPEC_DIVIDE_FACTOR = 32
SPEC_HEIGHT = 256
SCALE_RAW_AUDIO = False
DETECTION_THRESHOLD = 0.01
NMS_KERNEL_SIZE = 9
NMS_TOP_K_PER_SEC = 200
SPEC_SCALE = "pcen"
DENOISE_SPEC_AVG = True
MAX_SCALE_SPEC = False
DEFAULT_MODEL_PATH = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"models",
"Net2DFast_UK_same.pth.tar",
)
DEFAULT_SPECTROGRAM_PARAMETERS: SpectrogramParameters = {
"fft_win_length": FFT_WIN_LENGTH_S,
"fft_overlap": FFT_OVERLAP,
"spec_height": SPEC_HEIGHT,
"resize_factor": RESIZE_FACTOR,
"spec_divide_factor": SPEC_DIVIDE_FACTOR,
"max_freq": MAX_FREQ_HZ,
"min_freq": MIN_FREQ_HZ,
"spec_scale": SPEC_SCALE,
"denoise_spec_avg": DENOISE_SPEC_AVG,
"max_scale_spec": MAX_SCALE_SPEC,
}
DEFAULT_PROCESSING_CONFIGURATIONS: ProcessingConfiguration = {
"detection_threshold": DETECTION_THRESHOLD,
"spec_slices": False,
"chunk_size": 3,
"spec_features": False,
"cnn_features": False,
"quiet": True,
"target_samp_rate": TARGET_SAMPLERATE_HZ,
"fft_win_length": FFT_WIN_LENGTH_S,
"fft_overlap": FFT_OVERLAP,
"resize_factor": RESIZE_FACTOR,
"spec_divide_factor": SPEC_DIVIDE_FACTOR,
"spec_height": SPEC_HEIGHT,
"scale_raw_audio": SCALE_RAW_AUDIO,
"class_names": [],
"time_expansion": 1,
"top_n": 3,
"return_raw_preds": False,
"max_duration": None,
"nms_kernel_size": NMS_KERNEL_SIZE,
"max_freq": MAX_FREQ_HZ,
"min_freq": MIN_FREQ_HZ,
"nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
"spec_scale": SPEC_SCALE,
"denoise_spec_avg": DENOISE_SPEC_AVG,
"max_scale_spec": MAX_SCALE_SPEC,
}
def mk_dir(path): def mk_dir(path):
if not os.path.isdir(path): if not os.path.isdir(path):
os.makedirs(path) os.makedirs(path)
def get_params(make_dirs=False, exps_dir='../../experiments/'): def get_params(make_dirs=False, exps_dir="../../experiments/"):
params = {} params = {}
params['model_name'] = 'Net2DFast' # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN params[
params['num_filters'] = 128 "model_name"
] = "Net2DFast" # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN
params["num_filters"] = 128
now_str = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S") now_str = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
model_name = now_str + '.pth.tar' model_name = now_str + ".pth.tar"
params['experiment'] = os.path.join(exps_dir, now_str, '') params["experiment"] = os.path.join(exps_dir, now_str, "")
params['model_file_name'] = os.path.join(params['experiment'], model_name) params["model_file_name"] = os.path.join(params["experiment"], model_name)
params['op_im_dir'] = os.path.join(params['experiment'], 'op_ims', '') params["op_im_dir"] = os.path.join(params["experiment"], "op_ims", "")
params['op_im_dir_test'] = os.path.join(params['experiment'], 'op_ims_test', '') params["op_im_dir_test"] = os.path.join(
#params['notes'] = '' # can save notes about an experiment here params["experiment"], "op_ims_test", ""
)
# params['notes'] = '' # can save notes about an experiment here
# spec parameters # spec parameters
params['target_samp_rate'] = 256000 # resamples all audio so that it is at this rate params[
params['fft_win_length'] = 512 / 256000.0 # in milliseconds, amount of time per stft time step "target_samp_rate"
params['fft_overlap'] = 0.75 # stft window overlap ] = TARGET_SAMPLERATE_HZ # resamples all audio so that it is at this rate
params[
"fft_win_length"
] = FFT_WIN_LENGTH_S # in milliseconds, amount of time per stft time step
params["fft_overlap"] = FFT_OVERLAP # stft window overlap
params['max_freq'] = 120000 # in Hz, everything above this will be discarded params[
params['min_freq'] = 10000 # in Hz, everything below this will be discarded "max_freq"
] = MAX_FREQ_HZ # in Hz, everything above this will be discarded
params[
"min_freq"
] = MIN_FREQ_HZ # in Hz, everything below this will be discarded
params['resize_factor'] = 0.5 # resize so the spectrogram at the input of the network params[
params['spec_height'] = 256 # units are number of frequency bins (before resizing is performed) "resize_factor"
params['spec_train_width'] = 512 # units are number of time steps (before resizing is performed) ] = RESIZE_FACTOR # resize so the spectrogram at the input of the network
params['spec_divide_factor'] = 32 # spectrogram should be divisible by this amount in width and height params[
"spec_height"
] = SPEC_HEIGHT # units are number of frequency bins (before resizing is performed)
params[
"spec_train_width"
] = 512 # units are number of time steps (before resizing is performed)
params[
"spec_divide_factor"
] = SPEC_DIVIDE_FACTOR # spectrogram should be divisible by this amount in width and height
# spec processing params # spec processing params
params['denoise_spec_avg'] = True # removes the mean for each frequency band params[
params['scale_raw_audio'] = False # scales the raw audio to [-1, 1] "denoise_spec_avg"
params['max_scale_spec'] = False # scales the spectrogram so that it is max 1 ] = DENOISE_SPEC_AVG # removes the mean for each frequency band
params['spec_scale'] = 'pcen' # 'log', 'pcen', 'none' params[
"scale_raw_audio"
] = SCALE_RAW_AUDIO # scales the raw audio to [-1, 1]
params[
"max_scale_spec"
] = MAX_SCALE_SPEC # scales the spectrogram so that it is max 1
params["spec_scale"] = SPEC_SCALE # 'log', 'pcen', 'none'
# detection params # detection params
params['detection_overlap'] = 0.01 # has to be within this number of ms to count as detection params[
params['ignore_start_end'] = 0.01 # if start of GT calls are within this time from the start/end of file ignore "detection_overlap"
params['detection_threshold'] = 0.01 # the smaller this is the better the recall will be ] = 0.01 # has to be within this number of ms to count as detection
params['nms_kernel_size'] = 9 params[
params['nms_top_k_per_sec'] = 200 # keep top K highest predictions per second of audio "ignore_start_end"
params['target_sigma'] = 2.0 ] = 0.01 # if start of GT calls are within this time from the start/end of file ignore
params[
"detection_threshold"
] = DETECTION_THRESHOLD # the smaller this is the better the recall will be
params[
"nms_kernel_size"
] = NMS_KERNEL_SIZE # size of the kernel for non-max suppression
params[
"nms_top_k_per_sec"
] = NMS_TOP_K_PER_SEC # keep top K highest predictions per second of audio
params["target_sigma"] = 2.0
# augmentation params # augmentation params
params['aug_prob'] = 0.20 # augmentations will be performed with this probability params[
params['augment_at_train'] = True "aug_prob"
params['augment_at_train_combine'] = True ] = 0.20 # augmentations will be performed with this probability
params['echo_max_delay'] = 0.005 # simulate echo by adding copy of raw audio params["augment_at_train"] = True
params['stretch_squeeze_delta'] = 0.04 # stretch or squeeze spec params["augment_at_train_combine"] = True
params['mask_max_time_perc'] = 0.05 # max mask size - here percentage, not ideal params[
params['mask_max_freq_perc'] = 0.10 # max mask size - here percentage, not ideal "echo_max_delay"
params['spec_amp_scaling'] = 2.0 # multiply the "volume" by 0:X times current amount ] = 0.005 # simulate echo by adding copy of raw audio
params['aug_sampling_rates'] = [220500, 256000, 300000, 312500, 384000, 441000, 500000] params["stretch_squeeze_delta"] = 0.04 # stretch or squeeze spec
params[
"mask_max_time_perc"
] = 0.05 # max mask size - here percentage, not ideal
params[
"mask_max_freq_perc"
] = 0.10 # max mask size - here percentage, not ideal
params[
"spec_amp_scaling"
] = 2.0 # multiply the "volume" by 0:X times current amount
params["aug_sampling_rates"] = [
220500,
256000,
300000,
312500,
384000,
441000,
500000,
]
# loss params # loss params
params['train_loss'] = 'focal' # mse or focal params["train_loss"] = "focal" # mse or focal
params['det_loss_weight'] = 1.0 # weight for the detection part of the loss params["det_loss_weight"] = 1.0 # weight for the detection part of the loss
params['size_loss_weight'] = 0.1 # weight for the bbox size loss params["size_loss_weight"] = 0.1 # weight for the bbox size loss
params['class_loss_weight'] = 2.0 # weight for the classification loss params["class_loss_weight"] = 2.0 # weight for the classification loss
params['individual_loss_weight'] = 0.0 # not used params["individual_loss_weight"] = 0.0 # not used
if params['individual_loss_weight'] == 0.0: if params["individual_loss_weight"] == 0.0:
params['emb_dim'] = 0 # number of dimensions used for individual id embedding params[
"emb_dim"
] = 0 # number of dimensions used for individual id embedding
else: else:
params['emb_dim'] = 3 params["emb_dim"] = 3
# train params # train params
params['lr'] = 0.001 params["lr"] = 0.001
params['batch_size'] = 8 params["batch_size"] = 8
params['num_workers'] = 4 params["num_workers"] = 4
params['num_epochs'] = 200 params["num_epochs"] = 200
params['num_eval_epochs'] = 5 # run evaluation every X epochs params["num_eval_epochs"] = 5 # run evaluation every X epochs
params['device'] = 'cuda' params["device"] = "cuda"
params['save_test_image_during_train'] = False params["save_test_image_during_train"] = False
params['save_test_image_after_train'] = True params["save_test_image_after_train"] = True
params['convert_to_genus'] = False params["convert_to_genus"] = False
params['genus_mapping'] = [] params["genus_mapping"] = []
params['class_names'] = [] params["class_names"] = []
params['classes_to_ignore'] = ['', ' ', 'Unknown', 'Not Bat'] params["classes_to_ignore"] = ["", " ", "Unknown", "Not Bat"]
params['generic_class'] = ['Bat'] params["generic_class"] = ["Bat"]
params['events_of_interest'] = ['Echolocation'] # will ignore all other types of events e.g. social calls params["events_of_interest"] = [
"Echolocation"
] # will ignore all other types of events e.g. social calls
# the classes in this list are standardized during training so that the same low and high freq are used # the classes in this list are standardized during training so that the same low and high freq are used
params['standardize_classs_names'] = [] params["standardize_classs_names"] = []
# create directories # create directories
if make_dirs: if make_dirs:
print('Model name : ' + params['model_name']) print("Model name : " + params["model_name"])
print('Model file : ' + params['model_file_name']) print("Model file : " + params["model_file_name"])
print('Experiment : ' + params['experiment']) print("Experiment : " + params["experiment"])
mk_dir(params['experiment']) mk_dir(params["experiment"])
if params['save_test_image_during_train']: if params["save_test_image_during_train"]:
mk_dir(params['op_im_dir']) mk_dir(params["op_im_dir"])
if params['save_test_image_after_train']: if params["save_test_image_after_train"]:
mk_dir(params['op_im_dir_test']) mk_dir(params["op_im_dir_test"])
mk_dir(os.path.dirname(params['model_file_name'])) mk_dir(os.path.dirname(params["model_file_name"]))
return params return params

View File

@ -1,88 +1,168 @@
import torch """Post-processing of the output of the model."""
import torch.nn as nn from typing import List, Tuple, Union
import torch.nn.functional as F
import numpy as np import numpy as np
np.seterr(divide='ignore', invalid='ignore') import torch
from torch import nn
from bat_detect.detector.models import ModelOutput
from bat_detect.types import NonMaximumSuppressionConfig, PredictionResults
np.seterr(divide="ignore", invalid="ignore")
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap): def x_coords_to_time(
nfft = int(fft_win_length*sampling_rate) x_pos: float,
noverlap = int(fft_overlap*nfft) sampling_rate: int,
return ((x_pos*(nfft - noverlap)) + noverlap) / sampling_rate fft_win_length: float,
#return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window fft_overlap: float,
) -> float:
"""Convert x coordinates of spectrogram to time in seconds.
Args:
x_pos: X position of the detection in pixels.
sampling_rate: Sampling rate of the audio in Hz.
fft_win_length: Length of the FFT window in seconds.
fft_overlap: Overlap of the FFT windows in seconds.
Returns:
Time in seconds.
"""
nfft = int(fft_win_length * sampling_rate)
noverlap = int(fft_overlap * nfft)
return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
def overall_class_pred(det_prob, class_prob): def overall_class_pred(det_prob, class_prob):
weighted_pred = (class_prob*det_prob).sum(1) weighted_pred = (class_prob * det_prob).sum(1)
return weighted_pred / weighted_pred.sum() return weighted_pred / weighted_pred.sum()
def run_nms(outputs, params, sampling_rate): def run_nms(
outputs: ModelOutput,
params: NonMaximumSuppressionConfig,
sampling_rate: np.ndarray,
) -> Tuple[List[PredictionResults], List[np.ndarray]]:
"""Run non-maximum suppression on the output of the model.
pred_det = outputs['pred_det'] # probability of box Model outputs processed are expected to have a batch dimension.
pred_size = outputs['pred_size'] # box size Each element of the batch is processed independently. The
result is a pair of lists, one for the predictions and one for
the features. Each element of the lists corresponds to one
element of the batch.
"""
pred_det, pred_size, pred_class, _, features = outputs
pred_det_nms = non_max_suppression(pred_det, params['nms_kernel_size']) pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"])
freq_rescale = (params['max_freq'] - params['min_freq']) /pred_det.shape[-2] freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[
-2
]
# NOTE there will be small differences depending on which sampling rate is chosen # NOTE: there will be small differences depending on which sampling rate
# as we are choosing the same sampling rate for the entire batch # is chosen as we are choosing the same sampling rate for the entire batch
duration = x_coords_to_time(pred_det.shape[-1], sampling_rate[0].item(), duration = x_coords_to_time(
params['fft_win_length'], params['fft_overlap']) pred_det.shape[-1],
top_k = int(duration * params['nms_top_k_per_sec']) int(sampling_rate[0].item()),
params["fft_win_length"],
params["fft_overlap"],
)
top_k = int(duration * params["nms_top_k_per_sec"])
scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k) scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)
# loop over batch to save outputs # loop over batch to save outputs
preds = [] preds: List[PredictionResults] = []
feats = [] feats: List[np.ndarray] = []
for ii in range(pred_det_nms.shape[0]): for num_detection in range(pred_det_nms.shape[0]):
# get valid indices # get valid indices
inds_ord = torch.argsort(x_pos[ii, :]) inds_ord = torch.argsort(x_pos[num_detection, :])
valid_inds = scores[ii, inds_ord] > params['detection_threshold'] valid_inds = (
scores[num_detection, inds_ord] > params["detection_threshold"]
)
valid_inds = inds_ord[valid_inds] valid_inds = inds_ord[valid_inds]
# create result dictionary # create result dictionary
pred = {} pred = {}
pred['det_probs'] = scores[ii, valid_inds] pred["det_probs"] = scores[num_detection, valid_inds]
pred['x_pos'] = x_pos[ii, valid_inds] pred["x_pos"] = x_pos[num_detection, valid_inds]
pred['y_pos'] = y_pos[ii, valid_inds] pred["y_pos"] = y_pos[num_detection, valid_inds]
pred['bb_width'] = pred_size[ii, 0, pred['y_pos'], pred['x_pos']] pred["bb_width"] = pred_size[
pred['bb_height'] = pred_size[ii, 1, pred['y_pos'], pred['x_pos']] num_detection,
pred['start_times'] = x_coords_to_time(pred['x_pos'].float() / params['resize_factor'], 0,
sampling_rate[ii].item(), params['fft_win_length'], params['fft_overlap']) pred["y_pos"],
pred['end_times'] = x_coords_to_time((pred['x_pos'].float()+pred['bb_width']) / params['resize_factor'], pred["x_pos"],
sampling_rate[ii].item(), params['fft_win_length'], params['fft_overlap']) ]
pred['low_freqs'] = (pred_size[ii].shape[1] - pred['y_pos'].float())*freq_rescale + params['min_freq'] pred["bb_height"] = pred_size[
pred['high_freqs'] = pred['low_freqs'] + pred['bb_height']*freq_rescale num_detection,
1,
pred["y_pos"],
pred["x_pos"],
]
pred["start_times"] = x_coords_to_time(
pred["x_pos"].float() / params["resize_factor"],
int(sampling_rate[num_detection].item()),
params["fft_win_length"],
params["fft_overlap"],
)
pred["end_times"] = x_coords_to_time(
(pred["x_pos"].float() + pred["bb_width"])
/ params["resize_factor"],
int(sampling_rate[num_detection].item()),
params["fft_win_length"],
params["fft_overlap"],
)
pred["low_freqs"] = (
pred_size[num_detection].shape[1] - pred["y_pos"].float()
) * freq_rescale + params["min_freq"]
pred["high_freqs"] = (
pred["low_freqs"] + pred["bb_height"] * freq_rescale
)
# extract the per class votes # extract the per class votes
if 'pred_class' in outputs: if pred_class is not None:
pred['class_probs'] = outputs['pred_class'][ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]] pred["class_probs"] = pred_class[
num_detection,
:,
y_pos[num_detection, valid_inds],
x_pos[num_detection, valid_inds],
]
# extract the model features # extract the model features
if 'features' in outputs: if features is not None:
feat = outputs['features'][ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]].transpose(0, 1) feat = features[
feat = feat.cpu().numpy().astype(np.float32) num_detection,
:,
y_pos[num_detection, valid_inds],
x_pos[num_detection, valid_inds],
].transpose(0, 1)
feat = feat.detach().numpy().astype(np.float32)
feats.append(feat) feats.append(feat)
# convert to numpy # convert to numpy
for kk in pred.keys(): for key, value in pred.items():
pred[kk] = pred[kk].cpu().numpy().astype(np.float32) pred[key] = value.detach().numpy().astype(np.float32)
preds.append(pred)
preds.append(pred) # type: ignore
return preds, feats return preds, feats
def non_max_suppression(heat, kernel_size): def non_max_suppression(
heat: torch.Tensor,
kernel_size: Union[int, Tuple[int, int]],
):
# kernel can be an int or list/tuple # kernel can be an int or list/tuple
if type(kernel_size) is int: if isinstance(kernel_size, int):
kernel_size_h = kernel_size kernel_size_h = kernel_size
kernel_size_w = kernel_size kernel_size_w = kernel_size
else:
kernel_size_h, kernel_size_w = kernel_size
pad_h = (kernel_size_h - 1) // 2 pad_h = (kernel_size_h - 1) // 2
pad_w = (kernel_size_w - 1) // 2 pad_w = (kernel_size_w - 1) // 2
hmax = nn.functional.max_pool2d(heat, (kernel_size_h, kernel_size_w), stride=1, padding=(pad_h, pad_w)) hmax = nn.functional.max_pool2d(
heat, (kernel_size_h, kernel_size_w), stride=1, padding=(pad_h, pad_w)
)
keep = (hmax == heat).float() keep = (hmax == heat).float()
return heat * keep return heat * keep
@ -94,7 +174,7 @@ def get_topk_scores(scores, K):
topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K) topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
topk_inds = topk_inds % (height * width) topk_inds = topk_inds % (height * width)
topk_ys = torch.div(topk_inds, width, rounding_mode='floor').long() topk_ys = torch.div(topk_inds, width, rounding_mode="floor").long()
topk_xs = (topk_inds % width).long() topk_xs = (topk_inds % width).long()
return topk_scores, topk_ys, topk_xs return topk_scores, topk_ys, topk_xs

View File

View File

@ -2,67 +2,74 @@
Evaluates trained model on test set and generates plots. Evaluates trained model on test set and generates plots.
""" """
import numpy as np import argparse
import sys
import os
import copy import copy
import json import json
import os
import numpy as np
import pandas as pd import pandas as pd
from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomForestClassifier
import argparse
sys.path.append('../../') from bat_detect.detector import parameters
import bat_detect.utils.detector_utils as du
import bat_detect.train.train_utils as tu
import bat_detect.detector.parameters as parameters
import bat_detect.train.evaluate as evl import bat_detect.train.evaluate as evl
import bat_detect.train.train_utils as tu
import bat_detect.utils.detector_utils as du
import bat_detect.utils.plot_utils as pu import bat_detect.utils.plot_utils as pu
def get_blank_annotation(ip_str): def get_blank_annotation(ip_str):
res = {} res = {}
res['class_name'] = '' res["class_name"] = ""
res['duration'] = -1 res["duration"] = -1
res['id'] = ''# fileName res["id"] = "" # fileName
res['issues'] = False res["issues"] = False
res['notes'] = ip_str res["notes"] = ip_str
res['time_exp'] = 1 res["time_exp"] = 1
res['annotated'] = False res["annotated"] = False
res['annotation'] = [] res["annotation"] = []
ann = {} ann = {}
ann['class'] = '' ann["class"] = ""
ann['event'] = 'Echolocation' ann["event"] = "Echolocation"
ann['individual'] = -1 ann["individual"] = -1
ann['start_time'] = -1 ann["start_time"] = -1
ann['end_time'] = -1 ann["end_time"] = -1
ann['low_freq'] = -1 ann["low_freq"] = -1
ann['high_freq'] = -1 ann["high_freq"] = -1
ann['confidence'] = -1 ann["confidence"] = -1
return copy.deepcopy(res), copy.deepcopy(ann) return copy.deepcopy(res), copy.deepcopy(ann)
def create_genus_mapping(gt_test, preds, class_names): def create_genus_mapping(gt_test, preds, class_names):
# rolls the per class predictions and ground truth back up to genus level # rolls the per class predictions and ground truth back up to genus level
class_names_genus, cls_to_genus = np.unique([cc.split(' ')[0] for cc in class_names], return_inverse=True) class_names_genus, cls_to_genus = np.unique(
genus_to_cls_map = [np.where(np.array(cls_to_genus) == cc)[0] for cc in range(len(class_names_genus))] [cc.split(" ")[0] for cc in class_names], return_inverse=True
)
genus_to_cls_map = [
np.where(np.array(cls_to_genus) == cc)[0]
for cc in range(len(class_names_genus))
]
gt_test_g = [] gt_test_g = []
for gg in gt_test: for gg in gt_test:
gg_g = copy.deepcopy(gg) gg_g = copy.deepcopy(gg)
inds = np.where(gg_g['class_ids']!=-1)[0] inds = np.where(gg_g["class_ids"] != -1)[0]
gg_g['class_ids'][inds] = cls_to_genus[gg_g['class_ids'][inds]] gg_g["class_ids"][inds] = cls_to_genus[gg_g["class_ids"][inds]]
gt_test_g.append(gg_g) gt_test_g.append(gg_g)
# note, will have entries geater than one as we are summing across the respective classes # note, will have entries geater than one as we are summing across the respective classes
preds_g = [] preds_g = []
for pp in preds: for pp in preds:
pp_g = copy.deepcopy(pp) pp_g = copy.deepcopy(pp)
pp_g['class_probs'] = np.zeros((len(class_names_genus), pp_g['class_probs'].shape[1]), dtype=np.float32) pp_g["class_probs"] = np.zeros(
(len(class_names_genus), pp_g["class_probs"].shape[1]),
dtype=np.float32,
)
for cc, inds in enumerate(genus_to_cls_map): for cc, inds in enumerate(genus_to_cls_map):
pp_g['class_probs'][cc, :] = pp['class_probs'][inds, :].sum(0) pp_g["class_probs"][cc, :] = pp["class_probs"][inds, :].sum(0)
preds_g.append(pp_g) preds_g.append(pp_g)
return class_names_genus, preds_g, gt_test_g return class_names_genus, preds_g, gt_test_g
@ -70,56 +77,70 @@ def create_genus_mapping(gt_test, preds, class_names):
def load_tadarida_pred(ip_dir, dataset, file_of_interest): def load_tadarida_pred(ip_dir, dataset, file_of_interest):
res, ann = get_blank_annotation('Generated by Tadarida') res, ann = get_blank_annotation("Generated by Tadarida")
# create the annotations in the correct format # create the annotations in the correct format
da_c = pd.read_csv(ip_dir + dataset + '/' + file_of_interest.replace('.wav', '.ta').replace('.WAV', '.ta'), sep='\t') da_c = pd.read_csv(
ip_dir
+ dataset
+ "/"
+ file_of_interest.replace(".wav", ".ta").replace(".WAV", ".ta"),
sep="\t",
)
res_c = copy.deepcopy(res) res_c = copy.deepcopy(res)
res_c['id'] = file_of_interest res_c["id"] = file_of_interest
res_c['dataset'] = dataset res_c["dataset"] = dataset
res_c['feats'] = da_c.iloc[:, 6:].values.astype(np.float32) res_c["feats"] = da_c.iloc[:, 6:].values.astype(np.float32)
if da_c.shape[0] > 0: if da_c.shape[0] > 0:
res_c['class_name'] = '' res_c["class_name"] = ""
res_c['class_prob'] = 0.0 res_c["class_prob"] = 0.0
for aa in range(da_c.shape[0]): for aa in range(da_c.shape[0]):
ann_c = copy.deepcopy(ann) ann_c = copy.deepcopy(ann)
ann_c['class'] = 'Not Bat' # will assign to class later ann_c["class"] = "Not Bat" # will assign to class later
ann_c['start_time'] = np.round(da_c.iloc[aa]['StTime']/1000.0 ,5) ann_c["start_time"] = np.round(da_c.iloc[aa]["StTime"] / 1000.0, 5)
ann_c['end_time'] = np.round((da_c.iloc[aa]['StTime'] + da_c.iloc[aa]['Dur'])/1000.0, 5) ann_c["end_time"] = np.round(
ann_c['low_freq'] = np.round(da_c.iloc[aa]['Fmin'] * 1000.0, 2) (da_c.iloc[aa]["StTime"] + da_c.iloc[aa]["Dur"]) / 1000.0, 5
ann_c['high_freq'] = np.round(da_c.iloc[aa]['Fmax'] * 1000.0, 2) )
ann_c['det_prob'] = 0.0 ann_c["low_freq"] = np.round(da_c.iloc[aa]["Fmin"] * 1000.0, 2)
res_c['annotation'].append(ann_c) ann_c["high_freq"] = np.round(da_c.iloc[aa]["Fmax"] * 1000.0, 2)
ann_c["det_prob"] = 0.0
res_c["annotation"].append(ann_c)
return res_c return res_c
def load_sonobat_meta(ip_dir, datasets, region_classifier, class_names, only_accepted_species=True): def load_sonobat_meta(
ip_dir,
datasets,
region_classifier,
class_names,
only_accepted_species=True,
):
sp_dict = {} sp_dict = {}
for ss in class_names: for ss in class_names:
sp_key = ss.split(' ')[0][:3] + ss.split(' ')[1][:3] sp_key = ss.split(" ")[0][:3] + ss.split(" ")[1][:3]
sp_dict[sp_key] = ss sp_dict[sp_key] = ss
sp_dict['x'] = '' # not bat sp_dict["x"] = "" # not bat
sp_dict['Bat'] = 'Bat' sp_dict["Bat"] = "Bat"
sonobat_meta = {} sonobat_meta = {}
for tt in datasets: for tt in datasets:
dataset = tt['dataset_name'] dataset = tt["dataset_name"]
sb_ip_dir = ip_dir + dataset + '/' + region_classifier + '/' sb_ip_dir = ip_dir + dataset + "/" + region_classifier + "/"
# load the call level predictions # load the call level predictions
ip_file_p = sb_ip_dir + dataset + '_Parameters_v4.5.0.txt' ip_file_p = sb_ip_dir + dataset + "_Parameters_v4.5.0.txt"
#ip_file_p = sb_ip_dir + 'audio_SonoBatch_v30.0 beta.txt' # ip_file_p = sb_ip_dir + 'audio_SonoBatch_v30.0 beta.txt'
da = pd.read_csv(ip_file_p, sep='\t') da = pd.read_csv(ip_file_p, sep="\t")
# load the file level predictions # load the file level predictions
ip_file_b = sb_ip_dir + dataset + '_SonoBatch_v4.5.0.txt' ip_file_b = sb_ip_dir + dataset + "_SonoBatch_v4.5.0.txt"
#ip_file_b = sb_ip_dir + 'audio_CumulativeParameters_v30.0 beta.txt' # ip_file_b = sb_ip_dir + 'audio_CumulativeParameters_v30.0 beta.txt'
with open(ip_file_b) as f: with open(ip_file_b) as f:
lines = f.readlines() lines = f.readlines()
@ -129,7 +150,7 @@ def load_sonobat_meta(ip_dir, datasets, region_classifier, class_names, only_acc
file_res = {} file_res = {}
for ll in lines: for ll in lines:
# note this does not seem to parse the file very well # note this does not seem to parse the file very well
ll_data = ll.split('\t') ll_data = ll.split("\t")
# there are sometimes many different species names per file # there are sometimes many different species names per file
if only_accepted_species: if only_accepted_species:
@ -137,20 +158,24 @@ def load_sonobat_meta(ip_dir, datasets, region_classifier, class_names, only_acc
ind = 4 ind = 4
else: else:
# choosing ""~Spp" if "SppAccp" does not exist # choosing ""~Spp" if "SppAccp" does not exist
if ll_data[4] != 'x': if ll_data[4] != "x":
ind = 4 # choosing "SppAccp", along with "Prob" here ind = 4 # choosing "SppAccp", along with "Prob" here
else: else:
ind = 8 # choosing "~Spp", along with "~Prob" here ind = 8 # choosing "~Spp", along with "~Prob" here
sp_name_1 = sp_dict[ll_data[ind]] sp_name_1 = sp_dict[ll_data[ind]]
prob_1 = ll_data[ind+1] prob_1 = ll_data[ind + 1]
if prob_1 == 'x': if prob_1 == "x":
prob_1 = 0.0 prob_1 = 0.0
file_res[ll_data[1]] = {'id':ll_data[1], 'species_1':sp_name_1, 'prob_1':prob_1} file_res[ll_data[1]] = {
"id": ll_data[1],
"species_1": sp_name_1,
"prob_1": prob_1,
}
sonobat_meta[dataset] = {} sonobat_meta[dataset] = {}
sonobat_meta[dataset]['file_res'] = file_res sonobat_meta[dataset]["file_res"] = file_res
sonobat_meta[dataset]['call_info'] = da sonobat_meta[dataset]["call_info"] = da
return sonobat_meta return sonobat_meta
@ -158,34 +183,38 @@ def load_sonobat_meta(ip_dir, datasets, region_classifier, class_names, only_acc
def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None): def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
# create the annotations in the correct format # create the annotations in the correct format
res, ann = get_blank_annotation('Generated by Sonobat') res, ann = get_blank_annotation("Generated by Sonobat")
res_c = copy.deepcopy(res) res_c = copy.deepcopy(res)
res_c['id'] = id res_c["id"] = id
res_c['dataset'] = dataset res_c["dataset"] = dataset
da = sb_meta[dataset]['call_info'] da = sb_meta[dataset]["call_info"]
da_c = da[da['Filename'] == id] da_c = da[da["Filename"] == id]
file_res = sb_meta[dataset]['file_res'] file_res = sb_meta[dataset]["file_res"]
res_c['feats'] = np.zeros((0,0)) res_c["feats"] = np.zeros((0, 0))
if da_c.shape[0] > 0: if da_c.shape[0] > 0:
res_c['class_name'] = file_res[id]['species_1'] res_c["class_name"] = file_res[id]["species_1"]
res_c['class_prob'] = file_res[id]['prob_1'] res_c["class_prob"] = file_res[id]["prob_1"]
res_c['feats'] = da_c.iloc[:, 3:105].values.astype(np.float32) res_c["feats"] = da_c.iloc[:, 3:105].values.astype(np.float32)
for aa in range(da_c.shape[0]): for aa in range(da_c.shape[0]):
ann_c = copy.deepcopy(ann) ann_c = copy.deepcopy(ann)
if set_class_name is None: if set_class_name is None:
ann_c['class'] = file_res[id]['species_1'] ann_c["class"] = file_res[id]["species_1"]
else: else:
ann_c['class'] = set_class_name ann_c["class"] = set_class_name
ann_c['start_time'] = np.round(da_c.iloc[aa]['TimeInFile'] / 1000.0 ,5) ann_c["start_time"] = np.round(
ann_c['end_time'] = np.round(ann_c['start_time'] + da_c.iloc[aa]['CallDuration']/1000.0, 5) da_c.iloc[aa]["TimeInFile"] / 1000.0, 5
ann_c['low_freq'] = np.round(da_c.iloc[aa]['LowFreq'] * 1000.0, 2) )
ann_c['high_freq'] = np.round(da_c.iloc[aa]['HiFreq'] * 1000.0, 2) ann_c["end_time"] = np.round(
ann_c['det_prob'] = np.round(da_c.iloc[aa]['Quality'], 3) ann_c["start_time"] + da_c.iloc[aa]["CallDuration"] / 1000.0, 5
res_c['annotation'].append(ann_c) )
ann_c["low_freq"] = np.round(da_c.iloc[aa]["LowFreq"] * 1000.0, 2)
ann_c["high_freq"] = np.round(da_c.iloc[aa]["HiFreq"] * 1000.0, 2)
ann_c["det_prob"] = np.round(da_c.iloc[aa]["Quality"], 3)
res_c["annotation"].append(ann_c)
return res_c return res_c
@ -193,8 +222,18 @@ def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
def bb_overlap(bb_g_in, bb_p_in): def bb_overlap(bb_g_in, bb_p_in):
freq_scale = 10000000.0 # ensure that both axis are roughly the same range freq_scale = 10000000.0 # ensure that both axis are roughly the same range
bb_g = [bb_g_in['start_time'], bb_g_in['low_freq']/freq_scale, bb_g_in['end_time'], bb_g_in['high_freq']/freq_scale] bb_g = [
bb_p = [bb_p_in['start_time'], bb_p_in['low_freq']/freq_scale, bb_p_in['end_time'], bb_p_in['high_freq']/freq_scale] bb_g_in["start_time"],
bb_g_in["low_freq"] / freq_scale,
bb_g_in["end_time"],
bb_g_in["high_freq"] / freq_scale,
]
bb_p = [
bb_p_in["start_time"],
bb_p_in["low_freq"] / freq_scale,
bb_p_in["end_time"],
bb_p_in["high_freq"] / freq_scale,
]
xA = max(bb_g[0], bb_p[0]) xA = max(bb_g[0], bb_p[0])
yA = max(bb_g[1], bb_p[1]) yA = max(bb_g[1], bb_p[1])
@ -220,13 +259,15 @@ def bb_overlap(bb_g_in, bb_p_in):
def assign_to_gt(gt, pred, iou_thresh): def assign_to_gt(gt, pred, iou_thresh):
# this will edit pred in place # this will edit pred in place
num_preds = len(pred['annotation']) num_preds = len(pred["annotation"])
num_gts = len(gt['annotation']) num_gts = len(gt["annotation"])
if num_preds > 0 and num_gts > 0: if num_preds > 0 and num_gts > 0:
iou_m = np.zeros((num_preds, num_gts)) iou_m = np.zeros((num_preds, num_gts))
for ii in range(num_preds): for ii in range(num_preds):
for jj in range(num_gts): for jj in range(num_gts):
iou_m[ii, jj] = bb_overlap(gt['annotation'][jj], pred['annotation'][ii]) iou_m[ii, jj] = bb_overlap(
gt["annotation"][jj], pred["annotation"][ii]
)
# greedily assign detections to ground truths # greedily assign detections to ground truths
# needs to be greater than some threshold and we cannot assign GT # needs to be greater than some threshold and we cannot assign GT
@ -235,7 +276,9 @@ def assign_to_gt(gt, pred, iou_thresh):
for jj in range(num_gts): for jj in range(num_gts):
max_iou = np.argmax(iou_m[:, jj]) max_iou = np.argmax(iou_m[:, jj])
if iou_m[max_iou, jj] > iou_thresh: if iou_m[max_iou, jj] > iou_thresh:
pred['annotation'][max_iou]['class'] = gt['annotation'][jj]['class'] pred["annotation"][max_iou]["class"] = gt["annotation"][jj][
"class"
]
iou_m[max_iou, :] = -1.0 iou_m[max_iou, :] = -1.0
return pred return pred
@ -244,27 +287,39 @@ def assign_to_gt(gt, pred, iou_thresh):
def parse_data(data, class_names, non_event_classes, is_pred=False): def parse_data(data, class_names, non_event_classes, is_pred=False):
class_names_all = class_names + non_event_classes class_names_all = class_names + non_event_classes
data['class_names'] = np.array([aa['class'] for aa in data['annotation']]) data["class_names"] = np.array([aa["class"] for aa in data["annotation"]])
data['start_times'] = np.array([aa['start_time'] for aa in data['annotation']]) data["start_times"] = np.array(
data['end_times'] = np.array([aa['end_time'] for aa in data['annotation']]) [aa["start_time"] for aa in data["annotation"]]
data['high_freqs'] = np.array([float(aa['high_freq']) for aa in data['annotation']]) )
data['low_freqs'] = np.array([float(aa['low_freq']) for aa in data['annotation']]) data["end_times"] = np.array([aa["end_time"] for aa in data["annotation"]])
data["high_freqs"] = np.array(
[float(aa["high_freq"]) for aa in data["annotation"]]
)
data["low_freqs"] = np.array(
[float(aa["low_freq"]) for aa in data["annotation"]]
)
if is_pred: if is_pred:
# when loading predictions # when loading predictions
data['det_probs'] = np.array([float(aa['det_prob']) for aa in data['annotation']]) data["det_probs"] = np.array(
data['class_probs'] = np.zeros((len(class_names)+1, len(data['annotation']))) [float(aa["det_prob"]) for aa in data["annotation"]]
data['class_ids'] = np.array([class_names_all.index(aa['class']) for aa in data['annotation']]).astype(np.int32) )
data["class_probs"] = np.zeros(
(len(class_names) + 1, len(data["annotation"]))
)
data["class_ids"] = np.array(
[class_names_all.index(aa["class"]) for aa in data["annotation"]]
).astype(np.int32)
else: else:
# when loading ground truth # when loading ground truth
# if the class label is not in the set of interest then set to -1 # if the class label is not in the set of interest then set to -1
labels = [] labels = []
for aa in data['annotation']: for aa in data["annotation"]:
if aa['class'] in class_names: if aa["class"] in class_names:
labels.append(class_names_all.index(aa['class'])) labels.append(class_names_all.index(aa["class"]))
else: else:
labels.append(-1) labels.append(-1)
data['class_ids'] = np.array(labels).astype(np.int32) data["class_ids"] = np.array(labels).astype(np.int32)
return data return data
@ -272,12 +327,17 @@ def parse_data(data, class_names, non_event_classes, is_pred=False):
def load_gt_data(datasets, events_of_interest, class_names, classes_to_ignore): def load_gt_data(datasets, events_of_interest, class_names, classes_to_ignore):
gt_data = [] gt_data = []
for dd in datasets: for dd in datasets:
print('\n' + dd['dataset_name']) print("\n" + dd["dataset_name"])
gt_dataset = tu.load_set_of_anns([dd], events_of_interest=events_of_interest, verbose=True) gt_dataset = tu.load_set_of_anns(
gt_dataset = [parse_data(gg, class_names, classes_to_ignore, False) for gg in gt_dataset] [dd], events_of_interest=events_of_interest, verbose=True
)
gt_dataset = [
parse_data(gg, class_names, classes_to_ignore, False)
for gg in gt_dataset
]
for gt in gt_dataset: for gt in gt_dataset:
gt['dataset_name'] = dd['dataset_name'] gt["dataset_name"] = dd["dataset_name"]
gt_data.extend(gt_dataset) gt_data.extend(gt_dataset)
@ -300,69 +360,103 @@ def train_rf_model(x_train, y_train, num_classes, seed=2001):
clf = RandomForestClassifier(random_state=seed, n_jobs=-1) clf = RandomForestClassifier(random_state=seed, n_jobs=-1)
clf.fit(x_train, y_train) clf.fit(x_train, y_train)
y_pred = clf.predict(x_train) y_pred = clf.predict(x_train)
tr_acc = (y_pred==y_train).mean() tr_acc = (y_pred == y_train).mean()
#print('Train acc', round(tr_acc*100, 2)) # print('Train acc', round(tr_acc*100, 2))
return clf, un_train_class return clf, un_train_class
def eval_rf_model(clf, pred, un_train_class, num_classes): def eval_rf_model(clf, pred, un_train_class, num_classes):
# stores the prediction in place # stores the prediction in place
if pred['feats'].shape[0] > 0: if pred["feats"].shape[0] > 0:
pred['class_probs'] = np.zeros((num_classes, pred['feats'].shape[0])) pred["class_probs"] = np.zeros((num_classes, pred["feats"].shape[0]))
pred['class_probs'][un_train_class, :] = clf.predict_proba(pred['feats']).T pred["class_probs"][un_train_class, :] = clf.predict_proba(
pred['det_probs'] = pred['class_probs'][:-1, :].sum(0) pred["feats"]
).T
pred["det_probs"] = pred["class_probs"][:-1, :].sum(0)
else: else:
pred['class_probs'] = np.zeros((num_classes, 0)) pred["class_probs"] = np.zeros((num_classes, 0))
pred['det_probs'] = np.zeros(0) pred["det_probs"] = np.zeros(0)
return pred return pred
def save_summary_to_json(op_dir, mod_name, results): def save_summary_to_json(op_dir, mod_name, results):
op = {} op = {}
op['avg_prec'] = round(results['avg_prec'], 3) op["avg_prec"] = round(results["avg_prec"], 3)
op['avg_prec_class'] = round(results['avg_prec_class'], 3) op["avg_prec_class"] = round(results["avg_prec_class"], 3)
op['top_class'] = round(results['top_class']['avg_prec'], 3) op["top_class"] = round(results["top_class"]["avg_prec"], 3)
op['file_acc'] = round(results['file_acc'], 3) op["file_acc"] = round(results["file_acc"], 3)
op['model'] = mod_name op["model"] = mod_name
op['per_class'] = {} op["per_class"] = {}
for cc in results['class_pr']: for cc in results["class_pr"]:
op['per_class'][cc['name']] = cc['avg_prec'] op["per_class"][cc["name"]] = cc["avg_prec"]
op_file_name = os.path.join(op_dir, mod_name + '_results.json') op_file_name = os.path.join(op_dir, mod_name + "_results.json")
with open(op_file_name, 'w') as da: with open(op_file_name, "w") as da:
json.dump(op, da, indent=2) json.dump(op, da, indent=2)
def print_results(model_name, mod_str, results, op_dir, class_names, file_type, title_text=''): def print_results(
print('\nResults - ' + model_name) model_name, mod_str, results, op_dir, class_names, file_type, title_text=""
print('avg_prec ', round(results['avg_prec'], 3)) ):
print('avg_prec_class', round(results['avg_prec_class'], 3)) print("\nResults - " + model_name)
print('top_class ', round(results['top_class']['avg_prec'], 3)) print("avg_prec ", round(results["avg_prec"], 3))
print('file_acc ', round(results['file_acc'], 3)) print("avg_prec_class", round(results["avg_prec_class"], 3))
print("top_class ", round(results["top_class"]["avg_prec"], 3))
print("file_acc ", round(results["file_acc"], 3))
print('\nSaving ' + model_name + ' results to: ' + op_dir) print("\nSaving " + model_name + " results to: " + op_dir)
save_summary_to_json(op_dir, mod_str, results) save_summary_to_json(op_dir, mod_str, results)
pu.plot_pr_curve(op_dir, mod_str+'_test_all_det', mod_str+'_test_all_det', results, file_type, title_text + 'Detection PR') pu.plot_pr_curve(
pu.plot_pr_curve(op_dir, mod_str+'_test_all_top_class', mod_str+'_test_all_top_class', results['top_class'], file_type, title_text + 'Top Class') op_dir,
pu.plot_pr_curve_class(op_dir, mod_str+'_test_all_class', mod_str+'_test_all_class', results, file_type, title_text + 'Per-Class PR') mod_str + "_test_all_det",
pu.plot_confusion_matrix(op_dir, mod_str+'_confusion', results['gt_valid_file'], results['pred_valid_file'], mod_str + "_test_all_det",
results['file_acc'], class_names, True, file_type, title_text + 'Confusion Matrix') results,
file_type,
title_text + "Detection PR",
)
pu.plot_pr_curve(
op_dir,
mod_str + "_test_all_top_class",
mod_str + "_test_all_top_class",
results["top_class"],
file_type,
title_text + "Top Class",
)
pu.plot_pr_curve_class(
op_dir,
mod_str + "_test_all_class",
mod_str + "_test_all_class",
results,
file_type,
title_text + "Per-Class PR",
)
pu.plot_confusion_matrix(
op_dir,
mod_str + "_confusion",
results["gt_valid_file"],
results["pred_valid_file"],
results["file_acc"],
class_names,
True,
file_type,
title_text + "Confusion Matrix",
)
def add_root_path_back(data_sets, ann_path, wav_path): def add_root_path_back(data_sets, ann_path, wav_path):
for dd in data_sets: for dd in data_sets:
dd['ann_path'] = os.path.join(ann_path, dd['ann_path']) dd["ann_path"] = os.path.join(ann_path, dd["ann_path"])
dd['wav_path'] = os.path.join(wav_path, dd['wav_path']) dd["wav_path"] = os.path.join(wav_path, dd["wav_path"])
return data_sets return data_sets
def check_classes_in_train(gt_list, class_names): def check_classes_in_train(gt_list, class_names):
num_gt_total = np.sum([gg['start_times'].shape[0] for gg in gt_list]) num_gt_total = np.sum([gg["start_times"].shape[0] for gg in gt_list])
num_with_no_class = 0 num_with_no_class = 0
for gt in gt_list: for gt in gt_list:
for cc in gt['class_names']: for cc in gt["class_names"]:
if cc not in class_names: if cc not in class_names:
num_with_no_class += 1 num_with_no_class += 1
return num_with_no_class return num_with_no_class
@ -371,195 +465,337 @@ def check_classes_in_train(gt_list, class_names):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('op_dir', type=str, default='plots/results_compare/', parser.add_argument(
help='Output directory for plots') "op_dir",
parser.add_argument('data_dir', type=str, type=str,
help='Path to root of datasets') default="plots/results_compare/",
parser.add_argument('ann_dir', type=str, help="Output directory for plots",
help='Path to extracted annotations') )
parser.add_argument('bd_model_path', type=str, parser.add_argument("data_dir", type=str, help="Path to root of datasets")
help='Path to BatDetect model') parser.add_argument(
parser.add_argument('--test_file', type=str, default='', "ann_dir", type=str, help="Path to extracted annotations"
help='Path to json file used for evaluation.') )
parser.add_argument('--sb_ip_dir', type=str, default='', parser.add_argument(
help='Path to sonobat predictions') "bd_model_path", type=str, help="Path to BatDetect model"
parser.add_argument('--sb_region_classifier', type=str, default='south', )
help='Path to sonobat predictions') parser.add_argument(
parser.add_argument('--td_ip_dir', type=str, default='', "--test_file",
help='Path to tadarida_D predictions') type=str,
parser.add_argument('--iou_thresh', type=float, default=0.01, default="",
help='IOU threshold for assigning predictions to ground truth') help="Path to json file used for evaluation.",
parser.add_argument('--file_type', type=str, default='png', )
help='Type of image to save - png or pdf') parser.add_argument(
parser.add_argument('--title_text', type=str, default='', "--sb_ip_dir", type=str, default="", help="Path to sonobat predictions"
help='Text to add as title of plots') )
parser.add_argument('--rand_seed', type=int, default=2001, parser.add_argument(
help='Random seed') "--sb_region_classifier",
type=str,
default="south",
help="Path to sonobat predictions",
)
parser.add_argument(
"--td_ip_dir",
type=str,
default="",
help="Path to tadarida_D predictions",
)
parser.add_argument(
"--iou_thresh",
type=float,
default=0.01,
help="IOU threshold for assigning predictions to ground truth",
)
parser.add_argument(
"--file_type",
type=str,
default="png",
help="Type of image to save - png or pdf",
)
parser.add_argument(
"--title_text",
type=str,
default="",
help="Text to add as title of plots",
)
parser.add_argument(
"--rand_seed", type=int, default=2001, help="Random seed"
)
args = vars(parser.parse_args()) args = vars(parser.parse_args())
np.random.seed(args['rand_seed']) np.random.seed(args["rand_seed"])
if not os.path.isdir(args['op_dir']):
os.makedirs(args['op_dir'])
if not os.path.isdir(args["op_dir"]):
os.makedirs(args["op_dir"])
# load the model # load the model
params_eval = parameters.get_params(False) params_eval = parameters.get_params(False)
_, params_bd = du.load_model(args['bd_model_path']) _, params_bd = du.load_model(args["bd_model_path"])
class_names = params_bd['class_names'] class_names = params_bd["class_names"]
num_classes = len(class_names) + 1 # num classes plus background class num_classes = len(class_names) + 1 # num classes plus background class
classes_to_ignore = ['Not Bat', 'Bat', 'Unknown'] classes_to_ignore = ["Not Bat", "Bat", "Unknown"]
events_of_interest = ['Echolocation'] events_of_interest = ["Echolocation"]
# load test data # load test data
if args['test_file'] == '': if args["test_file"] == "":
# load the test files of interest from the trained model # load the test files of interest from the trained model
test_sets = add_root_path_back(params_bd['test_sets'], args['ann_dir'], args['data_dir']) test_sets = add_root_path_back(
test_sets = [dd for dd in test_sets if not dd['is_binary']] # exclude bat/not datasets params_bd["test_sets"], args["ann_dir"], args["data_dir"]
)
test_sets = [
dd for dd in test_sets if not dd["is_binary"]
] # exclude bat/not datasets
else: else:
# user specified annotation file to evaluate # user specified annotation file to evaluate
test_dict = {} test_dict = {}
test_dict['dataset_name'] = args['test_file'].replace('.json', '') test_dict["dataset_name"] = args["test_file"].replace(".json", "")
test_dict['is_test'] = True test_dict["is_test"] = True
test_dict['is_binary'] = True test_dict["is_binary"] = True
test_dict['ann_path'] = os.path.join(args['ann_dir'], args['test_file']) test_dict["ann_path"] = os.path.join(args["ann_dir"], args["test_file"])
test_dict['wav_path'] = args['data_dir'] test_dict["wav_path"] = args["data_dir"]
test_sets = [test_dict] test_sets = [test_dict]
# load the gt for the test set # load the gt for the test set
gt_test = load_gt_data(test_sets, events_of_interest, class_names, classes_to_ignore) gt_test = load_gt_data(
total_num_calls = np.sum([gg['start_times'].shape[0] for gg in gt_test]) test_sets, events_of_interest, class_names, classes_to_ignore
print('\nTotal number of test files:', len(gt_test)) )
print('Total number of test calls:', np.sum([gg['start_times'].shape[0] for gg in gt_test])) total_num_calls = np.sum([gg["start_times"].shape[0] for gg in gt_test])
print("\nTotal number of test files:", len(gt_test))
print(
"Total number of test calls:",
np.sum([gg["start_times"].shape[0] for gg in gt_test]),
)
# check if test contains classes not in the train set # check if test contains classes not in the train set
num_with_no_class = check_classes_in_train(gt_test, class_names) num_with_no_class = check_classes_in_train(gt_test, class_names)
if total_num_calls == num_with_no_class: if total_num_calls == num_with_no_class:
print('Classes from the test set are not in the train set.') print("Classes from the test set are not in the train set.")
assert False assert False
# only need the train data if evaluating Sonobat or Tadarida # only need the train data if evaluating Sonobat or Tadarida
if args['sb_ip_dir'] != '' or args['td_ip_dir'] != '': if args["sb_ip_dir"] != "" or args["td_ip_dir"] != "":
train_sets = add_root_path_back(params_bd['train_sets'], args['ann_dir'], args['data_dir']) train_sets = add_root_path_back(
train_sets = [dd for dd in train_sets if not dd['is_binary']] # exclude bat/not datasets params_bd["train_sets"], args["ann_dir"], args["data_dir"]
gt_train = load_gt_data(train_sets, events_of_interest, class_names, classes_to_ignore) )
train_sets = [
dd for dd in train_sets if not dd["is_binary"]
] # exclude bat/not datasets
gt_train = load_gt_data(
train_sets, events_of_interest, class_names, classes_to_ignore
)
# #
# evaluate Sonobat by training random forest classifier # evaluate Sonobat by training random forest classifier
# #
# NOTE: Sonobat may only make predictions for a subset of the files # NOTE: Sonobat may only make predictions for a subset of the files
# #
if args['sb_ip_dir'] != '': if args["sb_ip_dir"] != "":
sb_meta = load_sonobat_meta(args['sb_ip_dir'], train_sets + test_sets, args['sb_region_classifier'], class_names) sb_meta = load_sonobat_meta(
args["sb_ip_dir"],
train_sets + test_sets,
args["sb_region_classifier"],
class_names,
)
preds_sb = [] preds_sb = []
keep_inds_sb = [] keep_inds_sb = []
for ii, gt in enumerate(gt_test): for ii, gt in enumerate(gt_test):
sb_pred = load_sonobat_preds(gt['dataset_name'], gt['id'], sb_meta) sb_pred = load_sonobat_preds(gt["dataset_name"], gt["id"], sb_meta)
if sb_pred['class_name'] != '': if sb_pred["class_name"] != "":
sb_pred = parse_data(sb_pred, class_names, classes_to_ignore, True) sb_pred = parse_data(
sb_pred['class_probs'][sb_pred['class_ids'], np.arange(sb_pred['class_probs'].shape[1])] = sb_pred['det_probs'] sb_pred, class_names, classes_to_ignore, True
)
sb_pred["class_probs"][
sb_pred["class_ids"],
np.arange(sb_pred["class_probs"].shape[1]),
] = sb_pred["det_probs"]
preds_sb.append(sb_pred) preds_sb.append(sb_pred)
keep_inds_sb.append(ii) keep_inds_sb.append(ii)
results_sb = evl.evaluate_predictions([gt_test[ii] for ii in keep_inds_sb], preds_sb, class_names, results_sb = evl.evaluate_predictions(
params_eval['detection_overlap'], params_eval['ignore_start_end']) [gt_test[ii] for ii in keep_inds_sb],
print_results('Sonobat', 'sb', results_sb, args['op_dir'], class_names, preds_sb,
args['file_type'], args['title_text'] + ' - Species - ') class_names,
print('Only reporting results for', len(keep_inds_sb), 'files, out of', len(gt_test)) params_eval["detection_overlap"],
params_eval["ignore_start_end"],
)
print_results(
"Sonobat",
"sb",
results_sb,
args["op_dir"],
class_names,
args["file_type"],
args["title_text"] + " - Species - ",
)
print(
"Only reporting results for",
len(keep_inds_sb),
"files, out of",
len(gt_test),
)
# train our own random forest on sonobat features # train our own random forest on sonobat features
x_train = [] x_train = []
y_train = [] y_train = []
for gt in gt_train: for gt in gt_train:
pred = load_sonobat_preds(gt['dataset_name'], gt['id'], sb_meta, 'Not Bat') pred = load_sonobat_preds(
gt["dataset_name"], gt["id"], sb_meta, "Not Bat"
)
if len(pred['annotation']) > 0: if len(pred["annotation"]) > 0:
# compute detection overlap with ground truth to determine which are the TP detections # compute detection overlap with ground truth to determine which are the TP detections
assign_to_gt(gt, pred, args['iou_thresh']) assign_to_gt(gt, pred, args["iou_thresh"])
pred = parse_data(pred, class_names, classes_to_ignore, True) pred = parse_data(pred, class_names, classes_to_ignore, True)
x_train.append(pred['feats']) x_train.append(pred["feats"])
y_train.append(pred['class_ids']) y_train.append(pred["class_ids"])
# train random forest on tadarida predictions # train random forest on tadarida predictions
clf_sb, un_train_class = train_rf_model(x_train, y_train, num_classes, args['rand_seed']) clf_sb, un_train_class = train_rf_model(
x_train, y_train, num_classes, args["rand_seed"]
)
# run the model on the test set # run the model on the test set
preds_sb_rf = [] preds_sb_rf = []
for gt in gt_test: for gt in gt_test:
pred = load_sonobat_preds(gt['dataset_name'], gt['id'], sb_meta, 'Not Bat') pred = load_sonobat_preds(
gt["dataset_name"], gt["id"], sb_meta, "Not Bat"
)
pred = parse_data(pred, class_names, classes_to_ignore, True) pred = parse_data(pred, class_names, classes_to_ignore, True)
pred = eval_rf_model(clf_sb, pred, un_train_class, num_classes) pred = eval_rf_model(clf_sb, pred, un_train_class, num_classes)
preds_sb_rf.append(pred) preds_sb_rf.append(pred)
results_sb_rf = evl.evaluate_predictions(gt_test, preds_sb_rf, class_names, results_sb_rf = evl.evaluate_predictions(
params_eval['detection_overlap'], params_eval['ignore_start_end']) gt_test,
print_results('Sonobat RF', 'sb_rf', results_sb_rf, args['op_dir'], class_names, preds_sb_rf,
args['file_type'], args['title_text'] + ' - Species - ') class_names,
print('\n\nWARNING\nThis is evaluating on the full test set, but there is only dections for a subset of files\n\n') params_eval["detection_overlap"],
params_eval["ignore_start_end"],
)
print_results(
"Sonobat RF",
"sb_rf",
results_sb_rf,
args["op_dir"],
class_names,
args["file_type"],
args["title_text"] + " - Species - ",
)
print(
"\n\nWARNING\nThis is evaluating on the full test set, but there is only dections for a subset of files\n\n"
)
# #
# evaluate Tadarida-D by training random forest classifier # evaluate Tadarida-D by training random forest classifier
# #
if args['td_ip_dir'] != '': if args["td_ip_dir"] != "":
x_train = [] x_train = []
y_train = [] y_train = []
for gt in gt_train: for gt in gt_train:
pred = load_tadarida_pred(args['td_ip_dir'], gt['dataset_name'], gt['id']) pred = load_tadarida_pred(
args["td_ip_dir"], gt["dataset_name"], gt["id"]
)
# compute detection overlap with ground truth to determine which are the TP detections # compute detection overlap with ground truth to determine which are the TP detections
assign_to_gt(gt, pred, args['iou_thresh']) assign_to_gt(gt, pred, args["iou_thresh"])
pred = parse_data(pred, class_names, classes_to_ignore, True) pred = parse_data(pred, class_names, classes_to_ignore, True)
x_train.append(pred['feats']) x_train.append(pred["feats"])
y_train.append(pred['class_ids']) y_train.append(pred["class_ids"])
# train random forest on Tadarida-D predictions # train random forest on Tadarida-D predictions
clf_td, un_train_class = train_rf_model(x_train, y_train, num_classes, args['rand_seed']) clf_td, un_train_class = train_rf_model(
x_train, y_train, num_classes, args["rand_seed"]
)
# run the model on the test set # run the model on the test set
preds_td = [] preds_td = []
for gt in gt_test: for gt in gt_test:
pred = load_tadarida_pred(args['td_ip_dir'], gt['dataset_name'], gt['id']) pred = load_tadarida_pred(
args["td_ip_dir"], gt["dataset_name"], gt["id"]
)
pred = parse_data(pred, class_names, classes_to_ignore, True) pred = parse_data(pred, class_names, classes_to_ignore, True)
pred = eval_rf_model(clf_td, pred, un_train_class, num_classes) pred = eval_rf_model(clf_td, pred, un_train_class, num_classes)
preds_td.append(pred) preds_td.append(pred)
results_td = evl.evaluate_predictions(gt_test, preds_td, class_names, results_td = evl.evaluate_predictions(
params_eval['detection_overlap'], params_eval['ignore_start_end']) gt_test,
print_results('Tadarida', 'td_rf', results_td, args['op_dir'], class_names, preds_td,
args['file_type'], args['title_text'] + ' - Species - ') class_names,
params_eval["detection_overlap"],
params_eval["ignore_start_end"],
)
print_results(
"Tadarida",
"td_rf",
results_td,
args["op_dir"],
class_names,
args["file_type"],
args["title_text"] + " - Species - ",
)
# #
# evaluate BatDetect # evaluate BatDetect
# #
if args['bd_model_path'] != '': if args["bd_model_path"] != "":
# load model # load model
bd_args = du.get_default_bd_args() bd_args = du.get_default_run_config()
model, params_bd = du.load_model(args['bd_model_path']) model, params_bd = du.load_model(args["bd_model_path"])
# check if the class names are the same # check if the class names are the same
if params_bd['class_names'] != class_names: if params_bd["class_names"] != class_names:
print('Warning: Class names are not the same as the trained model') print("Warning: Class names are not the same as the trained model")
assert False assert False
run_config = {
**bd_args,
**params_bd,
"return_raw_preds": True,
}
preds_bd = [] preds_bd = []
for ii, gg in enumerate(gt_test): for ii, gg in enumerate(gt_test):
pred = du.process_file(gg['file_path'], model, params_bd, bd_args, return_raw_preds=True) pred = du.process_file(
gg["file_path"],
model,
run_config,
)
preds_bd.append(pred) preds_bd.append(pred)
results_bd = evl.evaluate_predictions(gt_test, preds_bd, class_names, results_bd = evl.evaluate_predictions(
params_eval['detection_overlap'], params_eval['ignore_start_end']) gt_test,
print_results('BatDetect', 'bd', results_bd, args['op_dir'], preds_bd,
class_names, args['file_type'], args['title_text'] + ' - Species - ') class_names,
params_eval["detection_overlap"],
params_eval["ignore_start_end"],
)
print_results(
"BatDetect",
"bd",
results_bd,
args["op_dir"],
class_names,
args["file_type"],
args["title_text"] + " - Species - ",
)
# evaluate genus level # evaluate genus level
class_names_genus, preds_bd_g, gt_test_g = create_genus_mapping(gt_test, preds_bd, class_names) class_names_genus, preds_bd_g, gt_test_g = create_genus_mapping(
results_bd_genus = evl.evaluate_predictions(gt_test_g, preds_bd_g, class_names_genus, gt_test, preds_bd, class_names
params_eval['detection_overlap'], params_eval['ignore_start_end']) )
print_results('BatDetect Genus', 'bd_genus', results_bd_genus, args['op_dir'], results_bd_genus = evl.evaluate_predictions(
class_names_genus, args['file_type'], args['title_text'] + ' - Genus - ') gt_test_g,
preds_bd_g,
class_names_genus,
params_eval["detection_overlap"],
params_eval["ignore_start_end"],
)
print_results(
"BatDetect Genus",
"bd_genus",
results_bd_genus,
args["op_dir"],
class_names_genus,
args["file_type"],
args["title_text"] + " - Genus - ",
)

View File

View File

@ -1,183 +1,321 @@
import numpy as np import argparse
import matplotlib.pyplot as plt import glob
import json
import os import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
import json
import argparse
import glob
import sys sys.path.append(os.path.join("..", ".."))
sys.path.append(os.path.join('..', '..')) import bat_detect.detector.models as models
import bat_detect.train.train_model as tm import bat_detect.detector.parameters as parameters
import bat_detect.detector.post_process as pp
import bat_detect.train.audio_dataloader as adl import bat_detect.train.audio_dataloader as adl
import bat_detect.train.evaluate as evl import bat_detect.train.evaluate as evl
import bat_detect.train.train_utils as tu
import bat_detect.train.losses as losses import bat_detect.train.losses as losses
import bat_detect.train.train_model as tm
import bat_detect.detector.parameters as parameters import bat_detect.train.train_utils as tu
import bat_detect.detector.models as models
import bat_detect.detector.post_process as pp
import bat_detect.utils.plot_utils as pu
import bat_detect.utils.detector_utils as du import bat_detect.utils.detector_utils as du
import bat_detect.utils.plot_utils as pu
if __name__ == "__main__": if __name__ == "__main__":
info_str = '\nBatDetect - Finetune Model\n' info_str = "\nBatDetect - Finetune Model\n"
print(info_str) print(info_str)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('audio_path', type=str, help='Input directory for audio') parser.add_argument(
parser.add_argument('train_ann_path', type=str, "audio_path", type=str, help="Input directory for audio"
help='Path to where train annotation file is stored') )
parser.add_argument('test_ann_path', type=str, parser.add_argument(
help='Path to where test annotation file is stored') "train_ann_path",
parser.add_argument('model_path', type=str, type=str,
help='Path to pretrained model') help="Path to where train annotation file is stored",
parser.add_argument('--op_model_name', type=str, default='', )
help='Path and name for finetuned model') parser.add_argument(
parser.add_argument('--num_epochs', type=int, default=200, dest='num_epochs', "test_ann_path",
help='Number of finetuning epochs') type=str,
parser.add_argument('--finetune_only_last_layer', action='store_true', help="Path to where test annotation file is stored",
help='Only train final layers') )
parser.add_argument('--train_from_scratch', action='store_true', parser.add_argument("model_path", type=str, help="Path to pretrained model")
help='Do not use pretrained weights') parser.add_argument(
parser.add_argument('--do_not_save_images', action='store_false', "--op_model_name",
help='Do not save images at the end of training') type=str,
parser.add_argument('--notes', type=str, default='', default="",
help='Notes to save in text file') help="Path and name for finetuned model",
)
parser.add_argument(
"--num_epochs",
type=int,
default=200,
dest="num_epochs",
help="Number of finetuning epochs",
)
parser.add_argument(
"--finetune_only_last_layer",
action="store_true",
help="Only train final layers",
)
parser.add_argument(
"--train_from_scratch",
action="store_true",
help="Do not use pretrained weights",
)
parser.add_argument(
"--do_not_save_images",
action="store_false",
help="Do not save images at the end of training",
)
parser.add_argument(
"--notes", type=str, default="", help="Notes to save in text file"
)
args = vars(parser.parse_args()) args = vars(parser.parse_args())
params = parameters.get_params(True, '../../experiments/') params = parameters.get_params(True, "../../experiments/")
if torch.cuda.is_available(): if torch.cuda.is_available():
params['device'] = 'cuda' params["device"] = "cuda"
else: else:
params['device'] = 'cpu' params["device"] = "cpu"
print('\nNote, this will be a lot faster if you use computer with a GPU.\n') print(
"\nNote, this will be a lot faster if you use computer with a GPU.\n"
)
print('\nAudio directory: ' + args['audio_path']) print("\nAudio directory: " + args["audio_path"])
print('Train file: ' + args['train_ann_path']) print("Train file: " + args["train_ann_path"])
print('Test file: ' + args['test_ann_path']) print("Test file: " + args["test_ann_path"])
print('Loading model: ' + args['model_path']) print("Loading model: " + args["model_path"])
dataset_name = os.path.basename(args['train_ann_path']).replace('.json', '').replace('_TRAIN', '') dataset_name = (
os.path.basename(args["train_ann_path"])
.replace(".json", "")
.replace("_TRAIN", "")
)
if args['train_from_scratch']: if args["train_from_scratch"]:
print('\nTraining model from scratch i.e. not using pretrained weights') print("\nTraining model from scratch i.e. not using pretrained weights")
model, params_train = du.load_model(args['model_path'], False) model, params_train = du.load_model(args["model_path"], False)
else: else:
model, params_train = du.load_model(args['model_path'], True) model, params_train = du.load_model(args["model_path"], True)
model.to(params['device']) model.to(params["device"])
params['num_epochs'] = args['num_epochs'] params["num_epochs"] = args["num_epochs"]
if args['op_model_name'] != '': if args["op_model_name"] != "":
params['model_file_name'] = args['op_model_name'] params["model_file_name"] = args["op_model_name"]
classes_to_ignore = params['classes_to_ignore']+params['generic_class'] classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
# save notes file # save notes file
params['notes'] = args['notes'] params["notes"] = args["notes"]
if args['notes'] != '': if args["notes"] != "":
tu.write_notes_file(params['experiment'] + 'notes.txt', args['notes']) tu.write_notes_file(params["experiment"] + "notes.txt", args["notes"])
# load train annotations # load train annotations
train_sets = [] train_sets = []
train_sets.append(tu.get_blank_dataset_dict(dataset_name, False, args['train_ann_path'], args['audio_path'])) train_sets.append(
params['train_sets'] = [tu.get_blank_dataset_dict(dataset_name, False, os.path.basename(args['train_ann_path']), args['audio_path'])] tu.get_blank_dataset_dict(
dataset_name, False, args["train_ann_path"], args["audio_path"]
)
)
params["train_sets"] = [
tu.get_blank_dataset_dict(
dataset_name,
False,
os.path.basename(args["train_ann_path"]),
args["audio_path"],
)
]
print('\nTrain set:') print("\nTrain set:")
data_train, params['class_names'], params['class_inv_freq'] = \ (
tu.load_set_of_anns(train_sets, classes_to_ignore, params['events_of_interest']) data_train,
print('Number of files', len(data_train)) params["class_names"],
params["class_inv_freq"],
) = tu.load_set_of_anns(
train_sets, classes_to_ignore, params["events_of_interest"]
)
print("Number of files", len(data_train))
params['genus_names'], params['genus_mapping'] = tu.get_genus_mapping(params['class_names']) params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
params['class_names_short'] = tu.get_short_class_names(params['class_names']) params["class_names"]
)
params["class_names_short"] = tu.get_short_class_names(
params["class_names"]
)
# load test annotations # load test annotations
test_sets = [] test_sets = []
test_sets.append(tu.get_blank_dataset_dict(dataset_name, True, args['test_ann_path'], args['audio_path'])) test_sets.append(
params['test_sets'] = [tu.get_blank_dataset_dict(dataset_name, True, os.path.basename(args['test_ann_path']), args['audio_path'])] tu.get_blank_dataset_dict(
dataset_name, True, args["test_ann_path"], args["audio_path"]
)
)
params["test_sets"] = [
tu.get_blank_dataset_dict(
dataset_name,
True,
os.path.basename(args["test_ann_path"]),
args["audio_path"],
)
]
print('\nTest set:') print("\nTest set:")
data_test, _, _ = tu.load_set_of_anns(test_sets, classes_to_ignore, params['events_of_interest']) data_test, _, _ = tu.load_set_of_anns(
print('Number of files', len(data_test)) test_sets, classes_to_ignore, params["events_of_interest"]
)
print("Number of files", len(data_test))
# train loader # train loader
train_dataset = adl.AudioLoader(data_train, params, is_train=True) train_dataset = adl.AudioLoader(data_train, params, is_train=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params['batch_size'], train_loader = torch.utils.data.DataLoader(
shuffle=True, num_workers=params['num_workers'], pin_memory=True) train_dataset,
batch_size=params["batch_size"],
shuffle=True,
num_workers=params["num_workers"],
pin_memory=True,
)
# test loader - batch size of one because of variable file length # test loader - batch size of one because of variable file length
test_dataset = adl.AudioLoader(data_test, params, is_train=False) test_dataset = adl.AudioLoader(data_test, params, is_train=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, test_loader = torch.utils.data.DataLoader(
shuffle=False, num_workers=params['num_workers'], pin_memory=True) test_dataset,
batch_size=1,
shuffle=False,
num_workers=params["num_workers"],
pin_memory=True,
)
inputs_train = next(iter(train_loader)) inputs_train = next(iter(train_loader))
params['ip_height'] = inputs_train['spec'].shape[2] params["ip_height"] = inputs_train["spec"].shape[2]
print('\ntrain batch size :', inputs_train['spec'].shape) print("\ntrain batch size :", inputs_train["spec"].shape)
assert(params_train['model_name'] == 'Net2DFast') assert params_train["model_name"] == "Net2DFast"
print('\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n') print(
"\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n"
)
# set the number of output classes # set the number of output classes
num_filts = model.conv_classes_op.in_channels num_filts = model.conv_classes_op.in_channels
k_size = model.conv_classes_op.kernel_size k_size = model.conv_classes_op.kernel_size
pad = model.conv_classes_op.padding pad = model.conv_classes_op.padding
model.conv_classes_op = torch.nn.Conv2d(num_filts, len(params['class_names'])+1, kernel_size=k_size, padding=pad) model.conv_classes_op = torch.nn.Conv2d(
model.conv_classes_op.to(params['device']) num_filts,
len(params["class_names"]) + 1,
kernel_size=k_size,
padding=pad,
)
model.conv_classes_op.to(params["device"])
if args['finetune_only_last_layer']: if args["finetune_only_last_layer"]:
print('\nOnly finetuning the final layers.\n') print("\nOnly finetuning the final layers.\n")
train_layers_i = ['conv_classes', 'conv_classes_op', 'conv_size', 'conv_size_op'] train_layers_i = [
train_layers = [tt + '.weight' for tt in train_layers_i] + [tt + '.bias' for tt in train_layers_i] "conv_classes",
"conv_classes_op",
"conv_size",
"conv_size_op",
]
train_layers = [tt + ".weight" for tt in train_layers_i] + [
tt + ".bias" for tt in train_layers_i
]
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if name in train_layers: if name in train_layers:
param.requires_grad = True param.requires_grad = True
else: else:
param.requires_grad = False param.requires_grad = False
optimizer = torch.optim.Adam(model.parameters(), lr=params['lr']) optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
scheduler = CosineAnnealingLR(optimizer, params['num_epochs'] * len(train_loader)) scheduler = CosineAnnealingLR(
if params['train_loss'] == 'mse': optimizer, params["num_epochs"] * len(train_loader)
)
if params["train_loss"] == "mse":
det_criterion = losses.mse_loss det_criterion = losses.mse_loss
elif params['train_loss'] == 'focal': elif params["train_loss"] == "focal":
det_criterion = losses.focal_loss det_criterion = losses.focal_loss
# plotting # plotting
train_plt_ls = pu.LossPlotter(params['experiment'] + 'train_loss.png', params['num_epochs']+1, train_plt_ls = pu.LossPlotter(
['train_loss'], None, None, ['epoch', 'train_loss'], logy=True) params["experiment"] + "train_loss.png",
test_plt_ls = pu.LossPlotter(params['experiment'] + 'test_loss.png', params['num_epochs']+1, params["num_epochs"] + 1,
['test_loss'], None, None, ['epoch', 'test_loss'], logy=True) ["train_loss"],
test_plt = pu.LossPlotter(params['experiment'] + 'test.png', params['num_epochs']+1, None,
['avg_prec', 'rec_at_x', 'avg_prec_class', 'file_acc', 'top_class'], [0,1], None, ['epoch', '']) None,
test_plt_class = pu.LossPlotter(params['experiment'] + 'test_avg_prec.png', params['num_epochs']+1, ["epoch", "train_loss"],
params['class_names_short'], [0,1], params['class_names_short'], ['epoch', 'avg_prec']) logy=True,
)
test_plt_ls = pu.LossPlotter(
params["experiment"] + "test_loss.png",
params["num_epochs"] + 1,
["test_loss"],
None,
None,
["epoch", "test_loss"],
logy=True,
)
test_plt = pu.LossPlotter(
params["experiment"] + "test.png",
params["num_epochs"] + 1,
["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"],
[0, 1],
None,
["epoch", ""],
)
test_plt_class = pu.LossPlotter(
params["experiment"] + "test_avg_prec.png",
params["num_epochs"] + 1,
params["class_names_short"],
[0, 1],
params["class_names_short"],
["epoch", "avg_prec"],
)
# main train loop # main train loop
for epoch in range(0, params['num_epochs']+1): for epoch in range(0, params["num_epochs"] + 1):
train_loss = tm.train(model, epoch, train_loader, det_criterion, optimizer, scheduler, params) train_loss = tm.train(
train_plt_ls.update_and_save(epoch, [train_loss['train_loss']]) model,
epoch,
train_loader,
det_criterion,
optimizer,
scheduler,
params,
)
train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]])
if epoch % params['num_eval_epochs'] == 0: if epoch % params["num_eval_epochs"] == 0:
# detection accuracy on test set # detection accuracy on test set
test_res, test_loss = tm.test(model, epoch, test_loader, det_criterion, params) test_res, test_loss = tm.test(
test_plt_ls.update_and_save(epoch, [test_loss['test_loss']]) model, epoch, test_loader, det_criterion, params
test_plt.update_and_save(epoch, [test_res['avg_prec'], test_res['rec_at_x'], )
test_res['avg_prec_class'], test_res['file_acc'], test_res['top_class']['avg_prec']]) test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]])
test_plt_class.update_and_save(epoch, [rs['avg_prec'] for rs in test_res['class_pr']]) test_plt.update_and_save(
pu.plot_pr_curve_class(params['experiment'] , 'test_pr', 'test_pr', test_res) epoch,
[
test_res["avg_prec"],
test_res["rec_at_x"],
test_res["avg_prec_class"],
test_res["file_acc"],
test_res["top_class"]["avg_prec"],
],
)
test_plt_class.update_and_save(
epoch, [rs["avg_prec"] for rs in test_res["class_pr"]]
)
pu.plot_pr_curve_class(
params["experiment"], "test_pr", "test_pr", test_res
)
# save finetuned model # save finetuned model
print('saving model to: ' + params['model_file_name']) print("saving model to: " + params["model_file_name"])
op_state = {'epoch': epoch + 1, op_state = {
'state_dict': model.state_dict(), "epoch": epoch + 1,
'params' : params} "state_dict": model.state_dict(),
torch.save(op_state, params['model_file_name']) "params": params,
}
torch.save(op_state, params["model_file_name"])
# save an image with associated prediction for each batch in the test set # save an image with associated prediction for each batch in the test set
if not args['do_not_save_images']: if not args["do_not_save_images"]:
tm.save_images_batch(model, test_loader, params) tm.save_images_batch(model, test_loader, params)

View File

@ -1,32 +1,33 @@
import numpy as np
import argparse import argparse
import os
import json import json
import os
import sys import sys
sys.path.append(os.path.join('..', '..'))
import numpy as np
sys.path.append(os.path.join("..", ".."))
import bat_detect.train.train_utils as tu import bat_detect.train.train_utils as tu
def print_dataset_stats(data, split_name, classes_to_ignore): def print_dataset_stats(data, split_name, classes_to_ignore):
print('\nSplit:', split_name) print("\nSplit:", split_name)
print('Num files:', len(data)) print("Num files:", len(data))
class_cnts = {} class_cnts = {}
for dd in data: for dd in data:
for aa in dd['annotation']: for aa in dd["annotation"]:
if aa['class'] not in classes_to_ignore: if aa["class"] not in classes_to_ignore:
if aa['class'] in class_cnts: if aa["class"] in class_cnts:
class_cnts[aa['class']] += 1 class_cnts[aa["class"]] += 1
else: else:
class_cnts[aa['class']] = 1 class_cnts[aa["class"]] = 1
if len(class_cnts) == 0: if len(class_cnts) == 0:
class_names = [] class_names = []
else: else:
class_names = np.sort([*class_cnts]).tolist() class_names = np.sort([*class_cnts]).tolist()
print('Class count:') print("Class count:")
str_len = np.max([len(cc) for cc in class_names]) + 5 str_len = np.max([len(cc) for cc in class_names]) + 5
for ii, cc in enumerate(class_names): for ii, cc in enumerate(class_names):
@ -41,111 +42,165 @@ def load_file_names(file_name):
with open(file_name) as da: with open(file_name) as da:
files = [line.rstrip() for line in da.readlines()] files = [line.rstrip() for line in da.readlines()]
for ff in files: for ff in files:
if ff.lower()[-3:] != 'wav': if ff.lower()[-3:] != "wav":
print('Error: Filenames need to end in .wav - ', ff) print("Error: Filenames need to end in .wav - ", ff)
assert(False) assert False
else: else:
print('Error: Input file not found - ', file_name) print("Error: Input file not found - ", file_name)
assert(False) assert False
return files return files
if __name__ == "__main__": if __name__ == "__main__":
info_str = '\nBatDetect - Prepare Data for Finetuning\n' info_str = "\nBatDetect - Prepare Data for Finetuning\n"
print(info_str) print(info_str)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('dataset_name', type=str, help='Name to call your dataset') parser.add_argument(
parser.add_argument('audio_dir', type=str, help='Input directory for audio') "dataset_name", type=str, help="Name to call your dataset"
parser.add_argument('ann_dir', type=str, help='Input directory for where the audio annotations are stored') )
parser.add_argument('op_dir', type=str, help='Path where the train and test splits will be stored') parser.add_argument("audio_dir", type=str, help="Input directory for audio")
parser.add_argument('--percent_val', type=float, default=0.20, parser.add_argument(
help='Hold out this much data for validation. Should be number between 0 and 1') "ann_dir",
parser.add_argument('--rand_seed', type=int, default=2001, type=str,
help='Random seed used for creating the validation split') help="Input directory for where the audio annotations are stored",
parser.add_argument('--train_file', type=str, default='', )
help='Text file where each line is a wav file in train split') parser.add_argument(
parser.add_argument('--test_file', type=str, default='', "op_dir",
help='Text file where each line is a wav file in test split') type=str,
parser.add_argument('--input_class_names', type=str, default='', help="Path where the train and test splits will be stored",
help='Specify names of classes that you want to change. Separate with ";"') )
parser.add_argument('--output_class_names', type=str, default='', parser.add_argument(
help='New class names to use instead. One to one mapping with "--input_class_names". \ "--percent_val",
Separate with ";"') type=float,
default=0.20,
help="Hold out this much data for validation. Should be number between 0 and 1",
)
parser.add_argument(
"--rand_seed",
type=int,
default=2001,
help="Random seed used for creating the validation split",
)
parser.add_argument(
"--train_file",
type=str,
default="",
help="Text file where each line is a wav file in train split",
)
parser.add_argument(
"--test_file",
type=str,
default="",
help="Text file where each line is a wav file in test split",
)
parser.add_argument(
"--input_class_names",
type=str,
default="",
help='Specify names of classes that you want to change. Separate with ";"',
)
parser.add_argument(
"--output_class_names",
type=str,
default="",
help='New class names to use instead. One to one mapping with "--input_class_names". \
Separate with ";"',
)
args = vars(parser.parse_args()) args = vars(parser.parse_args())
np.random.seed(args["rand_seed"])
np.random.seed(args['rand_seed']) classes_to_ignore = ["", " ", "Unknown", "Not Bat"]
generic_class = ["Bat"]
events_of_interest = ["Echolocation"]
classes_to_ignore = ['', ' ', 'Unknown', 'Not Bat'] if args["input_class_names"] != "" and args["output_class_names"] != "":
generic_class = ['Bat']
events_of_interest = ['Echolocation']
if args['input_class_names'] != '' and args['output_class_names'] != '':
# change the names of the classes # change the names of the classes
ip_names = args['input_class_names'].split(';') ip_names = args["input_class_names"].split(";")
op_names = args['output_class_names'].split(';') op_names = args["output_class_names"].split(";")
name_dict = dict(zip(ip_names, op_names)) name_dict = dict(zip(ip_names, op_names))
else: else:
name_dict = False name_dict = False
# load annotations # load annotations
data_all, _, _ = tu.load_set_of_anns({'ann_path': args['ann_dir'], 'wav_path': args['audio_dir']}, data_all, _, _ = tu.load_set_of_anns(
classes_to_ignore, events_of_interest, False, False, {"ann_path": args["ann_dir"], "wav_path": args["audio_dir"]},
list_of_anns=True, filter_issues=True, name_replace=name_dict) classes_to_ignore,
events_of_interest,
False,
False,
list_of_anns=True,
filter_issues=True,
name_replace=name_dict,
)
print('Dataset name: ' + args['dataset_name']) print("Dataset name: " + args["dataset_name"])
print('Audio directory: ' + args['audio_dir']) print("Audio directory: " + args["audio_dir"])
print('Annotation directory: ' + args['ann_dir']) print("Annotation directory: " + args["ann_dir"])
print('Ouput directory: ' + args['op_dir']) print("Ouput directory: " + args["op_dir"])
print('Num annotated files: ' + str(len(data_all))) print("Num annotated files: " + str(len(data_all)))
if args['train_file'] != '' and args['test_file'] != '': if args["train_file"] != "" and args["test_file"] != "":
# user has specifed the train / test split # user has specifed the train / test split
train_files = load_file_names(args['train_file']) train_files = load_file_names(args["train_file"])
test_files = load_file_names(args['test_file']) test_files = load_file_names(args["test_file"])
file_names_all = [dd['id'] for dd in data_all] file_names_all = [dd["id"] for dd in data_all]
train_inds = [file_names_all.index(ff) for ff in train_files if ff in file_names_all] train_inds = [
test_inds = [file_names_all.index(ff) for ff in test_files if ff in file_names_all] file_names_all.index(ff)
for ff in train_files
if ff in file_names_all
]
test_inds = [
file_names_all.index(ff)
for ff in test_files
if ff in file_names_all
]
else: else:
# split the data into train and test at the file level # split the data into train and test at the file level
num_exs = len(data_all) num_exs = len(data_all)
test_inds = np.random.choice(np.arange(num_exs), int(num_exs*args['percent_val']), replace=False) test_inds = np.random.choice(
np.arange(num_exs),
int(num_exs * args["percent_val"]),
replace=False,
)
test_inds = np.sort(test_inds) test_inds = np.sort(test_inds)
train_inds = np.setdiff1d(np.arange(num_exs), test_inds) train_inds = np.setdiff1d(np.arange(num_exs), test_inds)
data_train = [data_all[ii] for ii in train_inds] data_train = [data_all[ii] for ii in train_inds]
data_test = [data_all[ii] for ii in test_inds] data_test = [data_all[ii] for ii in test_inds]
if not os.path.isdir(args['op_dir']): if not os.path.isdir(args["op_dir"]):
os.makedirs(args['op_dir']) os.makedirs(args["op_dir"])
op_name = os.path.join(args['op_dir'], args['dataset_name']) op_name = os.path.join(args["op_dir"], args["dataset_name"])
op_name_train = op_name + '_TRAIN.json' op_name_train = op_name + "_TRAIN.json"
op_name_test = op_name + '_TEST.json' op_name_test = op_name + "_TEST.json"
class_un_train = print_dataset_stats(data_train, 'Train', classes_to_ignore) class_un_train = print_dataset_stats(data_train, "Train", classes_to_ignore)
class_un_test = print_dataset_stats(data_test, 'Test', classes_to_ignore) class_un_test = print_dataset_stats(data_test, "Test", classes_to_ignore)
if len(data_train) > 0 and len(data_test) > 0: if len(data_train) > 0 and len(data_test) > 0:
if class_un_train != class_un_test: if class_un_train != class_un_test:
print('\nError: some classes are not in both the training and test sets.\ print(
\nTry a different random seed "--rand_seed".') '\nError: some classes are not in both the training and test sets.\
\nTry a different random seed "--rand_seed".'
)
assert False assert False
print('\n') print("\n")
if len(data_train) == 0: if len(data_train) == 0:
print('No train annotations to save') print("No train annotations to save")
else: else:
print('Saving: ', op_name_train) print("Saving: ", op_name_train)
with open(op_name_train, 'w') as da: with open(op_name_train, "w") as da:
json.dump(data_train, da, indent=2) json.dump(data_train, da, indent=2)
if len(data_test) == 0: if len(data_test) == 0:
print('No test annotations to save') print("No test annotations to save")
else: else:
print('Saving: ', op_name_test) print("Saving: ", op_name_test)
with open(op_name_test, 'w') as da: with open(op_name_test, "w") as da:
json.dump(data_test, da, indent=2) json.dump(data_test, da, indent=2)

View File

@ -1,71 +1,144 @@
import torch
import random
import numpy as np
import copy import copy
from typing import Tuple
import librosa import librosa
import numpy as np
import torch
import torch.nn.functional as F import torch.nn.functional as F
import torchaudio import torchaudio
import os
import sys
sys.path.append(os.path.join('..', '..'))
import bat_detect.utils.audio_utils as au import bat_detect.utils.audio_utils as au
from bat_detect.types import AnnotationGroup, HeatmapParameters
def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params): def generate_gt_heatmaps(
spec_op_shape: Tuple[int, int],
sampling_rate: int,
ann: AnnotationGroup,
params: HeatmapParameters,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AnnotationGroup]:
"""Generate ground truth heatmaps from annotations.
Parameters
----------
spec_op_shape : Tuple[int, int]
Shape of the input spectrogram.
sampling_rate : int
Sampling rate of the input audio in Hz.
ann : AnnotationGroup
Dictionary containing the annotation information.
params : HeatmapParameters
Parameters controlling the generation of the heatmaps.
Returns
-------
y_2d_det : np.ndarray
2D heatmap of the presence of an event.
y_2d_size : np.ndarray
2D heatmap of the size of the bounding box associated to event.
y_2d_classes : np.ndarray
3D array containing the ground-truth class probabilities for each
pixel.
ann_aug : AnnotationGroup
A dictionary containing the annotation information of the
annotations that are within the input spectrogram, augmented with
the x and y indices of their pixel location in the input spectrogram.
"""
# spec may be resized on input into the network # spec may be resized on input into the network
num_classes = len(params['class_names']) num_classes = len(params["class_names"])
op_height = spec_op_shape[0] op_height = spec_op_shape[0]
op_width = spec_op_shape[1] op_width = spec_op_shape[1]
freq_per_bin = (params['max_freq'] - params['min_freq']) / op_height freq_per_bin = (params["max_freq"] - params["min_freq"]) / op_height
# start and end times # start and end times
x_pos_start = au.time_to_x_coords(ann['start_times'], sampling_rate, x_pos_start = au.time_to_x_coords(
params['fft_win_length'], params['fft_overlap']) ann["start_times"],
x_pos_start = (params['resize_factor']*x_pos_start).astype(np.int) sampling_rate,
x_pos_end = au.time_to_x_coords(ann['end_times'], sampling_rate, params["fft_win_length"],
params['fft_win_length'], params['fft_overlap']) params["fft_overlap"],
x_pos_end = (params['resize_factor']*x_pos_end).astype(np.int) )
x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int)
x_pos_end = au.time_to_x_coords(
ann["end_times"],
sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
)
x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int)
# location on y axis i.e. frequency # location on y axis i.e. frequency
y_pos_low = (ann['low_freqs'] - params['min_freq']) / freq_per_bin y_pos_low = (ann["low_freqs"] - params["min_freq"]) / freq_per_bin
y_pos_low = (op_height - y_pos_low).astype(np.int) y_pos_low = (op_height - y_pos_low).astype(np.int)
y_pos_high = (ann['high_freqs'] - params['min_freq']) / freq_per_bin y_pos_high = (ann["high_freqs"] - params["min_freq"]) / freq_per_bin
y_pos_high = (op_height - y_pos_high).astype(np.int) y_pos_high = (op_height - y_pos_high).astype(np.int)
bb_widths = x_pos_end - x_pos_start bb_widths = x_pos_end - x_pos_start
bb_heights = (y_pos_low - y_pos_high) bb_heights = y_pos_low - y_pos_high
valid_inds = np.where((x_pos_start >= 0) & (x_pos_start < op_width) & # Only include annotations that are within the input spectrogram
(y_pos_low >= 0) & (y_pos_low < (op_height-1)))[0] valid_inds = np.where(
(x_pos_start >= 0)
& (x_pos_start < op_width)
& (y_pos_low >= 0)
& (y_pos_low < (op_height - 1))
)[0]
ann_aug = {} ann_aug: AnnotationGroup = {
ann_aug['x_inds'] = x_pos_start[valid_inds] "start_times": ann["start_times"][valid_inds],
ann_aug['y_inds'] = y_pos_low[valid_inds] "end_times": ann["end_times"][valid_inds],
keys = ['start_times', 'end_times', 'high_freqs', 'low_freqs', 'class_ids', 'individual_ids'] "high_freqs": ann["high_freqs"][valid_inds],
for kk in keys: "low_freqs": ann["low_freqs"][valid_inds],
ann_aug[kk] = ann[kk][valid_inds] "class_ids": ann["class_ids"][valid_inds],
"individual_ids": ann["individual_ids"][valid_inds],
}
ann_aug["x_inds"] = x_pos_start[valid_inds]
ann_aug["y_inds"] = y_pos_low[valid_inds]
# keys = [
# "start_times",
# "end_times",
# "high_freqs",
# "low_freqs",
# "class_ids",
# "individual_ids",
# ]
# for kk in keys:
# ann_aug[kk] = ann[kk][valid_inds]
# if the number of calls is only 1, then it is unique # if the number of calls is only 1, then it is unique
# TODO would be better if we found these unique calls at the merging stage # TODO would be better if we found these unique calls at the merging stage
if len(ann_aug['individual_ids']) == 1: if len(ann_aug["individual_ids"]) == 1:
ann_aug['individual_ids'][0] = 0 ann_aug["individual_ids"][0] = 0
y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32) y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32) y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)
# num classes and "background" class # num classes and "background" class
y_2d_classes = np.zeros((num_classes+1, op_height, op_width), dtype=np.float32) y_2d_classes: np.ndarray = np.zeros(
(num_classes + 1, op_height, op_width), dtype=np.float32
)
# create 2D ground truth heatmaps # create 2D ground truth heatmaps
for ii in valid_inds: for ii in valid_inds:
draw_gaussian(y_2d_det[0,:], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma']) draw_gaussian(
#draw_gaussian(y_2d_det[0,:], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2) y_2d_det[0, :],
(x_pos_start[ii], y_pos_low[ii]),
params["target_sigma"],
)
# draw_gaussian(y_2d_det[0,:], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2)
y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii] y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii] y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]
cls_id = ann['class_ids'][ii] cls_id = ann["class_ids"][ii]
if cls_id > -1: if cls_id > -1:
draw_gaussian(y_2d_classes[cls_id, :], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma']) draw_gaussian(
#draw_gaussian(y_2d_classes[cls_id, :], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2) y_2d_classes[cls_id, :],
(x_pos_start[ii], y_pos_low[ii]),
params["target_sigma"],
)
# draw_gaussian(y_2d_classes[cls_id, :], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2)
# be careful as this will have a 1.0 places where we have event but dont know gt class # be careful as this will have a 1.0 places where we have event but dont know gt class
# this will be masked in training anyway # this will be masked in training anyway
@ -96,20 +169,24 @@ def draw_gaussian(heatmap, center, sigmax, sigmay=None):
x = np.arange(0, size, 1, np.float32) x = np.arange(0, size, 1, np.float32)
y = x[:, np.newaxis] y = x[:, np.newaxis]
x0 = y0 = size // 2 x0 = y0 = size // 2
#g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) # g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
g = np.exp(- ((x - x0) ** 2)/(2 * sigmax ** 2) - ((y - y0) ** 2)/(2 * sigmay ** 2)) g = np.exp(
-((x - x0) ** 2) / (2 * sigmax**2)
- ((y - y0) ** 2) / (2 * sigmay**2)
)
g_x = max(0, -ul[0]), min(br[0], h) - ul[0] g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
g_y = max(0, -ul[1]), min(br[1], w) - ul[1] g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
img_x = max(0, ul[0]), min(br[0], h) img_x = max(0, ul[0]), min(br[0], h)
img_y = max(0, ul[1]), min(br[1], w) img_y = max(0, ul[1]), min(br[1], w)
heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]] = np.maximum( heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]] = np.maximum(
heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]], heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]],
g[g_y[0]:g_y[1], g_x[0]:g_x[1]]) g[g_y[0] : g_y[1], g_x[0] : g_x[1]],
)
return True return True
def pad_aray(ip_array, pad_size): def pad_aray(ip_array, pad_size):
return np.hstack((ip_array, np.ones(pad_size, dtype=np.int)*-1)) return np.hstack((ip_array, np.ones(pad_size, dtype=np.int) * -1))
def warp_spec_aug(spec, ann, return_spec_for_viz, params): def warp_spec_aug(spec, ann, return_spec_for_viz, params):
@ -121,24 +198,37 @@ def warp_spec_aug(spec, ann, return_spec_for_viz, params):
if return_spec_for_viz: if return_spec_for_viz:
assert False assert False
delta = params['stretch_squeeze_delta'] delta = params["stretch_squeeze_delta"]
op_size = (spec.shape[1], spec.shape[2]) op_size = (spec.shape[1], spec.shape[2])
resize_fract_r = np.random.rand()*delta*2 - delta + 1.0 resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0
resize_amt = int(spec.shape[2]*resize_fract_r) resize_amt = int(spec.shape[2] * resize_fract_r)
if resize_amt >= spec.shape[2]: if resize_amt >= spec.shape[2]:
spec_r = torch.cat((spec, torch.zeros((1, spec.shape[1], resize_amt-spec.shape[2]), dtype=spec.dtype)), 2) spec_r = torch.cat(
(
spec,
torch.zeros(
(1, spec.shape[1], resize_amt - spec.shape[2]),
dtype=spec.dtype,
),
),
2,
)
else: else:
spec_r = spec[:, :, :resize_amt] spec_r = spec[:, :, :resize_amt]
spec = F.interpolate(spec_r.unsqueeze(0), size=op_size, mode='bilinear', align_corners=False).squeeze(0) spec = F.interpolate(
ann['start_times'] *= (1.0/resize_fract_r) spec_r.unsqueeze(0), size=op_size, mode="bilinear", align_corners=False
ann['end_times'] *= (1.0/resize_fract_r) ).squeeze(0)
ann["start_times"] *= 1.0 / resize_fract_r
ann["end_times"] *= 1.0 / resize_fract_r
return spec return spec
def mask_time_aug(spec, params): def mask_time_aug(spec, params):
# Mask out a random block of time - repeat up to 3 times # Mask out a random block of time - repeat up to 3 times
# SpecAugment: A Simple Data Augmentation Methodfor Automatic Speech Recognition # SpecAugment: A Simple Data Augmentation Methodfor Automatic Speech Recognition
fm = torchaudio.transforms.TimeMasking(int(spec.shape[1]*params['mask_max_time_perc'])) fm = torchaudio.transforms.TimeMasking(
int(spec.shape[1] * params["mask_max_time_perc"])
)
for ii in range(np.random.randint(1, 4)): for ii in range(np.random.randint(1, 4)):
spec = fm(spec) spec = fm(spec)
return spec return spec
@ -147,40 +237,65 @@ def mask_time_aug(spec, params):
def mask_freq_aug(spec, params): def mask_freq_aug(spec, params):
# Mask out a random frequncy range - repeat up to 3 times # Mask out a random frequncy range - repeat up to 3 times
# SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition # SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
fm = torchaudio.transforms.FrequencyMasking(int(spec.shape[1]*params['mask_max_freq_perc'])) fm = torchaudio.transforms.FrequencyMasking(
int(spec.shape[1] * params["mask_max_freq_perc"])
)
for ii in range(np.random.randint(1, 4)): for ii in range(np.random.randint(1, 4)):
spec = fm(spec) spec = fm(spec)
return spec return spec
def scale_vol_aug(spec, params): def scale_vol_aug(spec, params):
return spec * np.random.random()*params['spec_amp_scaling'] return spec * np.random.random() * params["spec_amp_scaling"]
def echo_aug(audio, sampling_rate, params): def echo_aug(audio, sampling_rate, params):
sample_offset = int(params['echo_max_delay']*np.random.random()*sampling_rate) + 1 sample_offset = (
audio[:-sample_offset] += np.random.random()*audio[sample_offset:] int(params["echo_max_delay"] * np.random.random() * sampling_rate) + 1
)
audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
return audio return audio
def resample_aug(audio, sampling_rate, params): def resample_aug(audio, sampling_rate, params):
sampling_rate_old = sampling_rate sampling_rate_old = sampling_rate
sampling_rate = np.random.choice(params['aug_sampling_rates']) sampling_rate = np.random.choice(params["aug_sampling_rates"])
audio = librosa.resample(audio, sampling_rate_old, sampling_rate, res_type='polyphase') audio = librosa.resample(
audio,
orig_sr=sampling_rate_old,
target_sr=sampling_rate,
res_type="polyphase",
)
audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'], audio = au.pad_audio(
params['fft_overlap'], params['resize_factor'], audio,
params['spec_divide_factor'], params['spec_train_width']) sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
params["resize_factor"],
params["spec_divide_factor"],
params["spec_train_width"],
)
duration = audio.shape[0] / float(sampling_rate) duration = audio.shape[0] / float(sampling_rate)
return audio, sampling_rate, duration return audio, sampling_rate, duration
def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2): def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
if sampling_rate != sampling_rate2: if sampling_rate != sampling_rate2:
audio2 = librosa.resample(audio2, sampling_rate2, sampling_rate, res_type='polyphase') audio2 = librosa.resample(
audio2,
orig_sr=sampling_rate2,
target_sr=sampling_rate,
res_type="polyphase",
)
sampling_rate2 = sampling_rate sampling_rate2 = sampling_rate
if audio2.shape[0] < num_samples: if audio2.shape[0] < num_samples:
audio2 = np.hstack((audio2, np.zeros((num_samples-audio2.shape[0]), dtype=audio2.dtype))) audio2 = np.hstack(
(
audio2,
np.zeros((num_samples - audio2.shape[0]), dtype=audio2.dtype),
)
)
elif audio2.shape[0] > num_samples: elif audio2.shape[0] > num_samples:
audio2 = audio2[:num_samples] audio2 = audio2[:num_samples]
return audio2, sampling_rate2 return audio2, sampling_rate2
@ -189,26 +304,32 @@ def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2): def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2):
# resample so they are the same # resample so they are the same
audio2, sampling_rate2 = resample_audio(audio.shape[0], sampling_rate, audio2, sampling_rate2) audio2, sampling_rate2 = resample_audio(
audio.shape[0], sampling_rate, audio2, sampling_rate2
)
# # set mean and std to be the same # # set mean and std to be the same
# audio2 = (audio2 - audio2.mean()) # audio2 = (audio2 - audio2.mean())
# audio2 = (audio2/audio2.std())*audio.std() # audio2 = (audio2/audio2.std())*audio.std()
# audio2 = audio2 + audio.mean() # audio2 = audio2 + audio.mean()
if ann['annotated'] and (ann2['annotated']) and \ if (
(sampling_rate2 == sampling_rate) and (audio.shape[0] == audio2.shape[0]): ann["annotated"]
comb_weight = 0.3 + np.random.random()*0.4 and (ann2["annotated"])
audio = comb_weight*audio + (1-comb_weight)*audio2 and (sampling_rate2 == sampling_rate)
inds = np.argsort(np.hstack((ann['start_times'], ann2['start_times']))) and (audio.shape[0] == audio2.shape[0])
):
comb_weight = 0.3 + np.random.random() * 0.4
audio = comb_weight * audio + (1 - comb_weight) * audio2
inds = np.argsort(np.hstack((ann["start_times"], ann2["start_times"])))
for kk in ann.keys(): for kk in ann.keys():
# when combining calls from different files, assume they come from different individuals # when combining calls from different files, assume they come from different individuals
if kk == 'individual_ids': if kk == "individual_ids":
if (ann[kk]>-1).sum() > 0: if (ann[kk] > -1).sum() > 0:
ann2[kk][ann2[kk]>-1] += np.max(ann[kk][ann[kk]>-1]) + 1 ann2[kk][ann2[kk] > -1] += np.max(ann[kk][ann[kk] > -1]) + 1
if (kk != 'class_id_file') and (kk != 'annotated'): if (kk != "class_id_file") and (kk != "annotated"):
ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds] ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds]
return audio, ann return audio, ann
@ -227,53 +348,70 @@ class AudioLoader(torch.utils.data.Dataset):
# filter out unused annotation here # filter out unused annotation here
filtered_annotations = [] filtered_annotations = []
for ii, aa in enumerate(dd['annotation']): for ii, aa in enumerate(dd["annotation"]):
if 'individual' in aa.keys(): if "individual" in aa.keys():
aa['individual'] = int(aa['individual']) aa["individual"] = int(aa["individual"])
# if only one call labeled it has to be from the same individual # if only one call labeled it has to be from the same individual
if len(dd['annotation']) == 1: if len(dd["annotation"]) == 1:
aa['individual'] = 0 aa["individual"] = 0
# convert class name into class label # convert class name into class label
if aa['class'] in self.params['class_names']: if aa["class"] in self.params["class_names"]:
aa['class_id'] = self.params['class_names'].index(aa['class']) aa["class_id"] = self.params["class_names"].index(
aa["class"]
)
else: else:
aa['class_id'] = -1 aa["class_id"] = -1
if aa['class'] not in self.params['classes_to_ignore']: if aa["class"] not in self.params["classes_to_ignore"]:
filtered_annotations.append(aa) filtered_annotations.append(aa)
dd['annotation'] = filtered_annotations dd["annotation"] = filtered_annotations
dd['start_times'] = np.array([aa['start_time'] for aa in dd['annotation']]) dd["start_times"] = np.array(
dd['end_times'] = np.array([aa['end_time'] for aa in dd['annotation']]) [aa["start_time"] for aa in dd["annotation"]]
dd['high_freqs'] = np.array([float(aa['high_freq']) for aa in dd['annotation']]) )
dd['low_freqs'] = np.array([float(aa['low_freq']) for aa in dd['annotation']]) dd["end_times"] = np.array(
dd['class_ids'] = np.array([aa['class_id'] for aa in dd['annotation']]).astype(np.int) [aa["end_time"] for aa in dd["annotation"]]
dd['individual_ids'] = np.array([aa['individual'] for aa in dd['annotation']]).astype(np.int) )
dd["high_freqs"] = np.array(
[float(aa["high_freq"]) for aa in dd["annotation"]]
)
dd["low_freqs"] = np.array(
[float(aa["low_freq"]) for aa in dd["annotation"]]
)
dd["class_ids"] = np.array(
[aa["class_id"] for aa in dd["annotation"]]
).astype(np.int)
dd["individual_ids"] = np.array(
[aa["individual"] for aa in dd["annotation"]]
).astype(np.int)
# file level class name # file level class name
dd['class_id_file'] = -1 dd["class_id_file"] = -1
if 'class_name' in dd.keys(): if "class_name" in dd.keys():
if dd['class_name'] in self.params['class_names']: if dd["class_name"] in self.params["class_names"]:
dd['class_id_file'] = self.params['class_names'].index(dd['class_name']) dd["class_id_file"] = self.params["class_names"].index(
dd["class_name"]
)
self.data_anns.append(dd) self.data_anns.append(dd)
ann_cnt = [len(aa['annotation']) for aa in self.data_anns] ann_cnt = [len(aa["annotation"]) for aa in self.data_anns]
self.max_num_anns = 2*np.max(ann_cnt) # x2 because we may be combining files during training self.max_num_anns = 2 * np.max(
ann_cnt
) # x2 because we may be combining files during training
print('\n') print("\n")
if dataset_name is not None: if dataset_name is not None:
print('Dataset : ' + dataset_name) print("Dataset : " + dataset_name)
if self.is_train: if self.is_train:
print('Split type : train') print("Split type : train")
else: else:
print('Split type : test') print("Split type : test")
print('Num files : ' + str(len(self.data_anns))) print("Num files : " + str(len(self.data_anns)))
print('Num calls : ' + str(np.sum(ann_cnt))) print("Num calls : " + str(np.sum(ann_cnt)))
def get_file_and_anns(self, index=None): def get_file_and_anns(self, index=None):
@ -281,110 +419,169 @@ class AudioLoader(torch.utils.data.Dataset):
if index == None: if index == None:
index = np.random.randint(0, len(self.data_anns)) index = np.random.randint(0, len(self.data_anns))
audio_file = self.data_anns[index]['file_path'] audio_file = self.data_anns[index]["file_path"]
sampling_rate, audio_raw = au.load_audio_file(audio_file, self.data_anns[index]['time_exp'], sampling_rate, audio_raw = au.load_audio(
self.params['target_samp_rate'], self.params['scale_raw_audio']) audio_file,
self.data_anns[index]["time_exp"],
self.params["target_samp_rate"],
self.params["scale_raw_audio"],
)
# copy annotation # copy annotation
ann = {} ann = {}
ann['annotated'] = self.data_anns[index]['annotated'] ann["annotated"] = self.data_anns[index]["annotated"]
ann['class_id_file'] = self.data_anns[index]['class_id_file'] ann["class_id_file"] = self.data_anns[index]["class_id_file"]
keys = ['start_times', 'end_times', 'high_freqs', 'low_freqs', 'class_ids', 'individual_ids'] keys = [
"start_times",
"end_times",
"high_freqs",
"low_freqs",
"class_ids",
"individual_ids",
]
for kk in keys: for kk in keys:
ann[kk] = self.data_anns[index][kk].copy() ann[kk] = self.data_anns[index][kk].copy()
# if train then grab a random crop # if train then grab a random crop
if self.is_train: if self.is_train:
nfft = int(self.params['fft_win_length']*sampling_rate) nfft = int(self.params["fft_win_length"] * sampling_rate)
noverlap = int(self.params['fft_overlap']*nfft) noverlap = int(self.params["fft_overlap"] * nfft)
length_samples = self.params['spec_train_width']*(nfft - noverlap) + noverlap length_samples = (
self.params["spec_train_width"] * (nfft - noverlap) + noverlap
)
if audio_raw.shape[0] - length_samples > 0: if audio_raw.shape[0] - length_samples > 0:
sample_crop = np.random.randint(audio_raw.shape[0] - length_samples) sample_crop = np.random.randint(
audio_raw.shape[0] - length_samples
)
else: else:
sample_crop = 0 sample_crop = 0
audio_raw = audio_raw[sample_crop:sample_crop+length_samples] audio_raw = audio_raw[sample_crop : sample_crop + length_samples]
ann['start_times'] = ann['start_times'] - sample_crop/float(sampling_rate) ann["start_times"] = ann["start_times"] - sample_crop / float(
ann['end_times'] = ann['end_times'] - sample_crop/float(sampling_rate) sampling_rate
)
ann["end_times"] = ann["end_times"] - sample_crop / float(
sampling_rate
)
# pad audio # pad audio
if self.is_train: if self.is_train:
op_spec_target_size = self.params['spec_train_width'] op_spec_target_size = self.params["spec_train_width"]
else: else:
op_spec_target_size = None op_spec_target_size = None
audio_raw = au.pad_audio(audio_raw, sampling_rate, self.params['fft_win_length'], audio_raw = au.pad_audio(
self.params['fft_overlap'], self.params['resize_factor'], audio_raw,
self.params['spec_divide_factor'], op_spec_target_size) sampling_rate,
self.params["fft_win_length"],
self.params["fft_overlap"],
self.params["resize_factor"],
self.params["spec_divide_factor"],
op_spec_target_size,
)
duration = audio_raw.shape[0] / float(sampling_rate) duration = audio_raw.shape[0] / float(sampling_rate)
# sort based on time # sort based on time
inds = np.argsort(ann['start_times']) inds = np.argsort(ann["start_times"])
for kk in ann.keys(): for kk in ann.keys():
if (kk != 'class_id_file') and (kk != 'annotated'): if (kk != "class_id_file") and (kk != "annotated"):
ann[kk] = ann[kk][inds] ann[kk] = ann[kk][inds]
return audio_raw, sampling_rate, duration, ann return audio_raw, sampling_rate, duration, ann
def __getitem__(self, index): def __getitem__(self, index):
# load audio file # load audio file
audio, sampling_rate, duration, ann = self.get_file_and_anns(index) audio, sampling_rate, duration, ann = self.get_file_and_anns(index)
# augment on raw audio # augment on raw audio
if self.is_train and self.params['augment_at_train']: if self.is_train and self.params["augment_at_train"]:
# augment - combine with random audio file # augment - combine with random audio file
if self.params['augment_at_train_combine'] and np.random.random() < self.params['aug_prob']: if (
audio2, sampling_rate2, duration2, ann2 = self.get_file_and_anns() self.params["augment_at_train_combine"]
audio, ann = combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2) and np.random.random() < self.params["aug_prob"]
):
(
audio2,
sampling_rate2,
duration2,
ann2,
) = self.get_file_and_anns()
audio, ann = combine_audio_aug(
audio, sampling_rate, ann, audio2, sampling_rate2, ann2
)
# simulate echo by adding delayed copy of the file # simulate echo by adding delayed copy of the file
if np.random.random() < self.params['aug_prob']: if np.random.random() < self.params["aug_prob"]:
audio = echo_aug(audio, sampling_rate, self.params) audio = echo_aug(audio, sampling_rate, self.params)
# resample the audio # resample the audio
#if np.random.random() < self.params['aug_prob']: # if np.random.random() < self.params['aug_prob']:
# audio, sampling_rate, duration = resample_aug(audio, sampling_rate, self.params) # audio, sampling_rate, duration = resample_aug(audio, sampling_rate, self.params)
# create spectrogram # create spectrogram
spec, spec_for_viz = au.generate_spectrogram(audio, sampling_rate, self.params, self.return_spec_for_viz) spec, spec_for_viz = au.generate_spectrogram(
rsf = self.params['resize_factor'] audio, sampling_rate, self.params, self.return_spec_for_viz
spec_op_shape = (int(self.params['spec_height']*rsf), int(spec.shape[1]*rsf)) )
rsf = self.params["resize_factor"]
spec_op_shape = (
int(self.params["spec_height"] * rsf),
int(spec.shape[1] * rsf),
)
# resize the spec # resize the spec
spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0) spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0)
spec = F.interpolate(spec, size=spec_op_shape, mode='bilinear', align_corners=False).squeeze(0) spec = F.interpolate(
spec, size=spec_op_shape, mode="bilinear", align_corners=False
).squeeze(0)
# augment spectrogram # augment spectrogram
if self.is_train and self.params['augment_at_train']: if self.is_train and self.params["augment_at_train"]:
if np.random.random() < self.params['aug_prob']: if np.random.random() < self.params["aug_prob"]:
spec = scale_vol_aug(spec, self.params) spec = scale_vol_aug(spec, self.params)
if np.random.random() < self.params['aug_prob']: if np.random.random() < self.params["aug_prob"]:
spec = warp_spec_aug(spec, ann, self.return_spec_for_viz, self.params) spec = warp_spec_aug(
spec, ann, self.return_spec_for_viz, self.params
)
if np.random.random() < self.params['aug_prob']: if np.random.random() < self.params["aug_prob"]:
spec = mask_time_aug(spec, self.params) spec = mask_time_aug(spec, self.params)
if np.random.random() < self.params['aug_prob']: if np.random.random() < self.params["aug_prob"]:
spec = mask_freq_aug(spec, self.params) spec = mask_freq_aug(spec, self.params)
outputs = {} outputs = {}
outputs['spec'] = spec outputs["spec"] = spec
if self.return_spec_for_viz: if self.return_spec_for_viz:
outputs['spec_for_viz'] = torch.from_numpy(spec_for_viz).unsqueeze(0) outputs["spec_for_viz"] = torch.from_numpy(spec_for_viz).unsqueeze(
0
)
# create ground truth heatmaps # create ground truth heatmaps
outputs['y_2d_det'], outputs['y_2d_size'], outputs['y_2d_classes'], ann_aug =\ (
generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, self.params) outputs["y_2d_det"],
outputs["y_2d_size"],
outputs["y_2d_classes"],
ann_aug,
) = generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, self.params)
# hack to get around requirement that all vectors are the same length in # hack to get around requirement that all vectors are the same length in
# the output batch # the output batch
pad_size = self.max_num_anns-len(ann_aug['individual_ids']) pad_size = self.max_num_anns - len(ann_aug["individual_ids"])
outputs['is_valid'] = pad_aray(np.ones(len(ann_aug['individual_ids'])), pad_size) outputs["is_valid"] = pad_aray(
keys = ['class_ids', 'individual_ids', 'x_inds', 'y_inds', np.ones(len(ann_aug["individual_ids"])), pad_size
'start_times', 'end_times', 'low_freqs', 'high_freqs'] )
keys = [
"class_ids",
"individual_ids",
"x_inds",
"y_inds",
"start_times",
"end_times",
"low_freqs",
"high_freqs",
]
for kk in keys: for kk in keys:
outputs[kk] = pad_aray(ann_aug[kk], pad_size) outputs[kk] = pad_aray(ann_aug[kk], pad_size)
@ -394,14 +591,13 @@ class AudioLoader(torch.utils.data.Dataset):
outputs[kk] = torch.from_numpy(outputs[kk]) outputs[kk] = torch.from_numpy(outputs[kk])
# scalars # scalars
outputs['class_id_file'] = ann['class_id_file'] outputs["class_id_file"] = ann["class_id_file"]
outputs['annotated'] = ann['annotated'] outputs["annotated"] = ann["annotated"]
outputs['duration'] = duration outputs["duration"] = duration
outputs['sampling_rate'] = sampling_rate outputs["sampling_rate"] = sampling_rate
outputs['file_id'] = index outputs["file_id"] = index
return outputs return outputs
def __len__(self): def __len__(self):
return len(self.data_anns) return len(self.data_anns)

View File

@ -1,6 +1,10 @@
import numpy as np import numpy as np
from sklearn.metrics import roc_curve, auc from sklearn.metrics import (
from sklearn.metrics import accuracy_score, balanced_accuracy_score accuracy_score,
auc,
balanced_accuracy_score,
roc_curve,
)
def compute_error_auc(op_str, gt, pred, prob): def compute_error_auc(op_str, gt, pred, prob):
@ -13,8 +17,11 @@ def compute_error_auc(op_str, gt, pred, prob):
fpr, tpr, thresholds = roc_curve(gt, pred) fpr, tpr, thresholds = roc_curve(gt, pred)
roc_auc = auc(fpr, tpr) roc_auc = auc(fpr, tpr)
print(op_str + ", class acc = {:.3f}, ROC AUC = {:.3f}".format(class_acc, roc_auc)) print(
#return class_acc, roc_auc op_str
+ ", class acc = {:.3f}, ROC AUC = {:.3f}".format(class_acc, roc_auc)
)
# return class_acc, roc_auc
def calc_average_precision(recall, precision): def calc_average_precision(recall, precision):
@ -25,10 +32,10 @@ def calc_average_precision(recall, precision):
# pascal 12 way # pascal 12 way
mprec = np.hstack((0, precision, 0)) mprec = np.hstack((0, precision, 0))
mrec = np.hstack((0, recall, 1)) mrec = np.hstack((0, recall, 1))
for ii in range(mprec.shape[0]-2, -1,-1): for ii in range(mprec.shape[0] - 2, -1, -1):
mprec[ii] = np.maximum(mprec[ii], mprec[ii+1]) mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0]+1 inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
ave_prec = ((mrec[inds] - mrec[inds-1])*mprec[inds]).sum() ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
return float(ave_prec) return float(ave_prec)
@ -37,7 +44,7 @@ def calc_recall_at_x(recall, precision, x=0.95):
precision[np.isnan(precision)] = 0 precision[np.isnan(precision)] = 0
recall[np.isnan(recall)] = 0 recall[np.isnan(recall)] = 0
inds = np.where(precision[::-1]>x)[0] inds = np.where(precision[::-1] > x)[0]
if len(inds) > 0: if len(inds) > 0:
return float(recall[::-1][inds[0]]) return float(recall[::-1][inds[0]])
else: else:
@ -51,7 +58,15 @@ def compute_affinity_1d(pred_box, gt_boxes, threshold):
return valid_detection, np.argmin(score) return valid_detection, np.argmin(score)
def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, threshold, ignore_start_end): def compute_pre_rec(
gts,
preds,
eval_mode,
class_of_interest,
num_classes,
threshold,
ignore_start_end,
):
""" """
Computes precision and recall. Assumes that each file has been exhaustively Computes precision and recall. Assumes that each file has been exhaustively
annotated. Will not count predicted detection with a start time that is within annotated. Will not count predicted detection with a start time that is within
@ -78,26 +93,40 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres
for pid, pp in enumerate(preds): for pid, pp in enumerate(preds):
# filter predicted calls that are too near the start or end of the file # filter predicted calls that are too near the start or end of the file
file_dur = gts[pid]['duration'] file_dur = gts[pid]["duration"]
valid_inds = (pp['start_times'] >= ignore_start_end) & (pp['start_times'] <= (file_dur - ignore_start_end)) valid_inds = (pp["start_times"] >= ignore_start_end) & (
pp["start_times"] <= (file_dur - ignore_start_end)
)
pred_boxes.append(np.vstack((pp['start_times'][valid_inds], pp['end_times'][valid_inds], pred_boxes.append(
pp['low_freqs'][valid_inds], pp['high_freqs'][valid_inds])).T) np.vstack(
(
pp["start_times"][valid_inds],
pp["end_times"][valid_inds],
pp["low_freqs"][valid_inds],
pp["high_freqs"][valid_inds],
)
).T
)
if eval_mode == 'detection': if eval_mode == "detection":
# overall detection # overall detection
confidence.append(pp['det_probs'][valid_inds]) confidence.append(pp["det_probs"][valid_inds])
elif eval_mode == 'per_class': elif eval_mode == "per_class":
# per class # per class
confidence.append(pp['class_probs'].T[valid_inds, class_of_interest]) confidence.append(
elif eval_mode == 'top_class': pp["class_probs"].T[valid_inds, class_of_interest]
)
elif eval_mode == "top_class":
# per class - note that sometimes 'class_probs' can be num_classes+1 in size # per class - note that sometimes 'class_probs' can be num_classes+1 in size
top_class = np.argmax(pp['class_probs'].T[valid_inds, :num_classes], 1) top_class = np.argmax(
confidence.append(pp['class_probs'].T[valid_inds, top_class]) pp["class_probs"].T[valid_inds, :num_classes], 1
)
confidence.append(pp["class_probs"].T[valid_inds, top_class])
pred_class.append(top_class) pred_class.append(top_class)
# be careful, assuming the order in the list is same as GT # be careful, assuming the order in the list is same as GT
file_ids.append([pid]*valid_inds.sum()) file_ids.append([pid] * valid_inds.sum())
confidence = np.hstack(confidence) confidence = np.hstack(confidence)
file_ids = np.hstack(file_ids).astype(np.int) file_ids = np.hstack(file_ids).astype(np.int)
@ -105,7 +134,6 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres
if len(pred_class) > 0: if len(pred_class) > 0:
pred_class = np.hstack(pred_class) pred_class = np.hstack(pred_class)
# extract relevant ground truth boxes # extract relevant ground truth boxes
gt_boxes = [] gt_boxes = []
gt_assigned = [] gt_assigned = []
@ -115,32 +143,42 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres
for gg in gts: for gg in gts:
# filter ground truth calls that are too near the start or end of the file # filter ground truth calls that are too near the start or end of the file
file_dur = gg['duration'] file_dur = gg["duration"]
valid_inds = (gg['start_times'] >= ignore_start_end) & (gg['start_times'] <= (file_dur - ignore_start_end)) valid_inds = (gg["start_times"] >= ignore_start_end) & (
gg["start_times"] <= (file_dur - ignore_start_end)
)
# note, files with the incorrect duration will cause a problem # note, files with the incorrect duration will cause a problem
if (gg['start_times'] > file_dur).sum() > 0: if (gg["start_times"] > file_dur).sum() > 0:
print('Error: file duration incorrect for', gg['id']) print("Error: file duration incorrect for", gg["id"])
assert(False) assert False
boxes = np.vstack((gg['start_times'][valid_inds], gg['end_times'][valid_inds], boxes = np.vstack(
gg['low_freqs'][valid_inds], gg['high_freqs'][valid_inds])).T (
gen_class = gg['class_ids'][valid_inds] == -1 gg["start_times"][valid_inds],
class_ids = gg['class_ids'][valid_inds] gg["end_times"][valid_inds],
gg["low_freqs"][valid_inds],
gg["high_freqs"][valid_inds],
)
).T
gen_class = gg["class_ids"][valid_inds] == -1
class_ids = gg["class_ids"][valid_inds]
# keep track of the number of relevant ground truth calls # keep track of the number of relevant ground truth calls
if eval_mode == 'detection': if eval_mode == "detection":
# all valid ones # all valid ones
num_positives += len(gg['start_times'][valid_inds]) num_positives += len(gg["start_times"][valid_inds])
elif eval_mode == 'per_class': elif eval_mode == "per_class":
# all valid ones with class of interest # all valid ones with class of interest
num_positives += (gg['class_ids'][valid_inds] == class_of_interest).sum() num_positives += (
elif eval_mode == 'top_class': gg["class_ids"][valid_inds] == class_of_interest
).sum()
elif eval_mode == "top_class":
# all valid ones with non generic class # all valid ones with non generic class
num_positives += (gg['class_ids'][valid_inds] > -1).sum() num_positives += (gg["class_ids"][valid_inds] > -1).sum()
# find relevant classes (i.e. class_of_interest) and events without known class (i.e. generic class, -1) # find relevant classes (i.e. class_of_interest) and events without known class (i.e. generic class, -1)
if eval_mode == 'per_class': if eval_mode == "per_class":
class_inds = (class_ids == class_of_interest) | (class_ids == -1) class_inds = (class_ids == class_of_interest) | (class_ids == -1)
boxes = boxes[class_inds, :] boxes = boxes[class_inds, :]
gen_class = gen_class[class_inds] gen_class = gen_class[class_inds]
@ -151,25 +189,27 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres
gt_generic_class.append(gen_class) gt_generic_class.append(gen_class)
gt_class.append(class_ids) gt_class.append(class_ids)
# loop through detections and keep track of those that have been assigned # loop through detections and keep track of those that have been assigned
true_pos = np.zeros(confidence.shape[0]) true_pos = np.zeros(confidence.shape[0])
valid_inds = np.ones(confidence.shape[0]) == 1 # intialize to True valid_inds = np.ones(confidence.shape[0]) == 1 # intialize to True
sorted_inds = np.argsort(confidence)[::-1] # sort high to low sorted_inds = np.argsort(confidence)[::-1] # sort high to low
for ii, ind in enumerate(sorted_inds): for ii, ind in enumerate(sorted_inds):
gt_id = file_ids[ind] gt_id = file_ids[ind]
valid_det = False valid_det = False
if gt_boxes[gt_id].shape[0] > 0: if gt_boxes[gt_id].shape[0] > 0:
# compute overlap # compute overlap
valid_det, det_ind = compute_affinity_1d(pred_boxes[ind], gt_boxes[gt_id], valid_det, det_ind = compute_affinity_1d(
threshold) pred_boxes[ind], gt_boxes[gt_id], threshold
)
# valid detection that has not already been assigned # valid detection that has not already been assigned
if valid_det and (gt_assigned[gt_id][det_ind] == 0): if valid_det and (gt_assigned[gt_id][det_ind] == 0):
count_as_true_pos = True count_as_true_pos = True
if eval_mode == 'top_class' and (gt_class[gt_id][det_ind] != pred_class[ind]): if eval_mode == "top_class" and (
gt_class[gt_id][det_ind] != pred_class[ind]
):
# needs to be the same class # needs to be the same class
count_as_true_pos = False count_as_true_pos = False
@ -181,40 +221,43 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres
# if event is generic class (i.e. gt_generic_class[gt_id][det_ind] is True) # if event is generic class (i.e. gt_generic_class[gt_id][det_ind] is True)
# and eval_mode != 'detection', then ignore it # and eval_mode != 'detection', then ignore it
if gt_generic_class[gt_id][det_ind]: if gt_generic_class[gt_id][det_ind]:
if eval_mode == 'per_class' or eval_mode == 'top_class': if eval_mode == "per_class" or eval_mode == "top_class":
valid_inds[ii] = False valid_inds[ii] = False
# store threshold values - used for plotting # store threshold values - used for plotting
conf_sorted = np.sort(confidence)[::-1][valid_inds] conf_sorted = np.sort(confidence)[::-1][valid_inds]
thresholds = np.linspace(0.1, 0.9, 9) thresholds = np.linspace(0.1, 0.9, 9)
thresholds_inds = np.zeros(len(thresholds), dtype=np.int) thresholds_inds = np.zeros(len(thresholds), dtype=np.int)
for ii, tt in enumerate(thresholds): for ii, tt in enumerate(thresholds):
thresholds_inds[ii] = np.argmin(conf_sorted > tt) thresholds_inds[ii] = np.argmin(conf_sorted > tt)
thresholds_inds[thresholds_inds==0] = -1 thresholds_inds[thresholds_inds == 0] = -1
# compute precision and recall # compute precision and recall
true_pos = true_pos[valid_inds] true_pos = true_pos[valid_inds]
false_pos_c = np.cumsum(1-true_pos) false_pos_c = np.cumsum(1 - true_pos)
true_pos_c = np.cumsum(true_pos) true_pos_c = np.cumsum(true_pos)
recall = true_pos_c / num_positives recall = true_pos_c / num_positives
precision = true_pos_c / np.maximum(true_pos_c + false_pos_c, np.finfo(np.float64).eps) precision = true_pos_c / np.maximum(
true_pos_c + false_pos_c, np.finfo(np.float64).eps
)
results = {} results = {}
results['recall'] = recall results["recall"] = recall
results['precision'] = precision results["precision"] = precision
results['num_gt'] = num_positives results["num_gt"] = num_positives
results['thresholds'] = thresholds results["thresholds"] = thresholds
results['thresholds_inds'] = thresholds_inds results["thresholds_inds"] = thresholds_inds
if num_positives == 0: if num_positives == 0:
results['avg_prec'] = np.nan results["avg_prec"] = np.nan
results['rec_at_x'] = np.nan results["rec_at_x"] = np.nan
else: else:
results['avg_prec'] = np.round(calc_average_precision(recall, precision), 5) results["avg_prec"] = np.round(
results['rec_at_x'] = np.round(calc_recall_at_x(recall, precision), 5) calc_average_precision(recall, precision), 5
)
results["rec_at_x"] = np.round(calc_recall_at_x(recall, precision), 5)
return results return results
@ -230,19 +273,19 @@ def compute_file_accuracy_simple(gts, preds, num_classes):
gt_valid = [] gt_valid = []
pred_valid = [] pred_valid = []
for ii in range(len(gts)): for ii in range(len(gts)):
gt_class = np.unique(gts[ii]['class_ids']) gt_class = np.unique(gts[ii]["class_ids"])
if len(gt_class) == 1 and gt_class[0] != -1: if len(gt_class) == 1 and gt_class[0] != -1:
gt_valid.append(gt_class[0]) gt_valid.append(gt_class[0])
pred = preds[ii]['class_probs'][:num_classes, :].T pred = preds[ii]["class_probs"][:num_classes, :].T
pred_valid.append(np.argmax(pred.mean(0))) pred_valid.append(np.argmax(pred.mean(0)))
acc = (np.array(gt_valid) == np.array(pred_valid)).mean() acc = (np.array(gt_valid) == np.array(pred_valid)).mean()
res = {} res = {}
res['num_valid_files'] = len(gt_valid) res["num_valid_files"] = len(gt_valid)
res['num_total_files'] = len(gts) res["num_total_files"] = len(gts)
res['gt_valid_file'] = gt_valid res["gt_valid_file"] = gt_valid
res['pred_valid_file'] = pred_valid res["pred_valid_file"] = pred_valid
res['file_acc'] = np.round(acc, 5) res["file_acc"] = np.round(acc, 5)
return res return res
@ -256,12 +299,20 @@ def compute_file_accuracy(gts, preds, num_classes):
# compute min and max scoring range - then threshold # compute min and max scoring range - then threshold
min_val = 0 min_val = 0
mins = [pp['class_probs'].min() for pp in preds if pp['class_probs'].shape[1] > 0] mins = [
pp["class_probs"].min()
for pp in preds
if pp["class_probs"].shape[1] > 0
]
if len(mins) > 0: if len(mins) > 0:
min_val = np.min(mins) min_val = np.min(mins)
max_val = 1.0 max_val = 1.0
maxes = [pp['class_probs'].max() for pp in preds if pp['class_probs'].shape[1] > 0] maxes = [
pp["class_probs"].max()
for pp in preds
if pp["class_probs"].shape[1] > 0
]
if len(maxes) > 0: if len(maxes) > 0:
max_val = np.max(maxes) max_val = np.max(maxes)
@ -272,33 +323,37 @@ def compute_file_accuracy(gts, preds, num_classes):
gt_valid = [] gt_valid = []
pred_valid_all = [] pred_valid_all = []
for ii in range(len(gts)): for ii in range(len(gts)):
gt_class = np.unique(gts[ii]['class_ids']) gt_class = np.unique(gts[ii]["class_ids"])
if len(gt_class) == 1 and gt_class[0] != -1: if len(gt_class) == 1 and gt_class[0] != -1:
gt_valid.append(gt_class[0]) gt_valid.append(gt_class[0])
pred = preds[ii]['class_probs'][:num_classes, :].T pred = preds[ii]["class_probs"][:num_classes, :].T
p_class = np.zeros(len(thresh)) p_class = np.zeros(len(thresh))
for tt in range(len(thresh)): for tt in range(len(thresh)):
p_class[tt] = (pred*(pred>=thresh[tt])).sum(0).argmax() p_class[tt] = (pred * (pred >= thresh[tt])).sum(0).argmax()
pred_valid_all.append(p_class) pred_valid_all.append(p_class)
# pick the result corresponding to the overall best threshold # pick the result corresponding to the overall best threshold
pred_valid_all = np.vstack(pred_valid_all) pred_valid_all = np.vstack(pred_valid_all)
acc_per_thresh = (np.array(gt_valid)[..., np.newaxis] == pred_valid_all).mean(0) acc_per_thresh = (
np.array(gt_valid)[..., np.newaxis] == pred_valid_all
).mean(0)
best_thresh = np.argmax(acc_per_thresh) best_thresh = np.argmax(acc_per_thresh)
best_acc = acc_per_thresh[best_thresh] best_acc = acc_per_thresh[best_thresh]
pred_valid = pred_valid_all[:, best_thresh].astype(np.int).tolist() pred_valid = pred_valid_all[:, best_thresh].astype(np.int).tolist()
res = {} res = {}
res['num_valid_files'] = len(gt_valid) res["num_valid_files"] = len(gt_valid)
res['num_total_files'] = len(gts) res["num_total_files"] = len(gts)
res['gt_valid_file'] = gt_valid res["gt_valid_file"] = gt_valid
res['pred_valid_file'] = pred_valid res["pred_valid_file"] = pred_valid
res['file_acc'] = np.round(best_acc, 5) res["file_acc"] = np.round(best_acc, 5)
return res return res
def evaluate_predictions(gts, preds, class_names, detection_overlap, ignore_start_end=0.0): def evaluate_predictions(
gts, preds, class_names, detection_overlap, ignore_start_end=0.0
):
""" """
Computes metrics derived from the precision and recall. Computes metrics derived from the precision and recall.
Assumes that gts and preds are both lists of the same lengths, with ground Assumes that gts and preds are both lists of the same lengths, with ground
@ -307,24 +362,50 @@ def evaluate_predictions(gts, preds, class_names, detection_overlap, ignore_star
Returns the overall detection results, and per class results Returns the overall detection results, and per class results
""" """
assert(len(gts) == len(preds)) assert len(gts) == len(preds)
num_classes = len(class_names) num_classes = len(class_names)
# evaluate detection on its own i.e. ignoring class # evaluate detection on its own i.e. ignoring class
det_results = compute_pre_rec(gts, preds, 'detection', None, num_classes, detection_overlap, ignore_start_end) det_results = compute_pre_rec(
top_class = compute_pre_rec(gts, preds, 'top_class', None, num_classes, detection_overlap, ignore_start_end) gts,
det_results['top_class'] = top_class preds,
"detection",
None,
num_classes,
detection_overlap,
ignore_start_end,
)
top_class = compute_pre_rec(
gts,
preds,
"top_class",
None,
num_classes,
detection_overlap,
ignore_start_end,
)
det_results["top_class"] = top_class
# per class evaluation # per class evaluation
det_results['class_pr'] = [] det_results["class_pr"] = []
for cc in range(num_classes): for cc in range(num_classes):
res = compute_pre_rec(gts, preds, 'per_class', cc, num_classes, detection_overlap, ignore_start_end) res = compute_pre_rec(
res['name'] = class_names[cc] gts,
det_results['class_pr'].append(res) preds,
"per_class",
cc,
num_classes,
detection_overlap,
ignore_start_end,
)
res["name"] = class_names[cc]
det_results["class_pr"].append(res)
# ignores classes that are not present in the test set # ignores classes that are not present in the test set
det_results['avg_prec_class'] = np.mean([rs['avg_prec'] for rs in det_results['class_pr'] if rs['num_gt'] > 0]) det_results["avg_prec_class"] = np.mean(
det_results['avg_prec_class'] = np.round(det_results['avg_prec_class'], 5) [rs["avg_prec"] for rs in det_results["class_pr"] if rs["num_gt"] > 0]
)
det_results["avg_prec_class"] = np.round(det_results["avg_prec_class"], 5)
# file level evaluation # file level evaluation
res_file = compute_file_accuracy(gts, preds, num_classes) res_file = compute_file_accuracy(gts, preds, num_classes)

View File

@ -7,7 +7,9 @@ def bbox_size_loss(pred_size, gt_size):
Bounding box size loss. Only compute loss where there is a bounding box. Bounding box size loss. Only compute loss where there is a bounding box.
""" """
gt_size_mask = (gt_size > 0).float() gt_size_mask = (gt_size > 0).float()
return (F.l1_loss(pred_size*gt_size_mask, gt_size, reduction='sum') / (gt_size_mask.sum() + 1e-5)) return F.l1_loss(pred_size * gt_size_mask, gt_size, reduction="sum") / (
gt_size_mask.sum() + 1e-5
)
def focal_loss(pred, gt, weights=None, valid_mask=None): def focal_loss(pred, gt, weights=None, valid_mask=None):
@ -24,20 +26,25 @@ def focal_loss(pred, gt, weights=None, valid_mask=None):
neg_inds = gt.lt(1).float() neg_inds = gt.lt(1).float()
pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds
neg_loss = torch.log(1 - pred + eps) * torch.pow(pred, alpha) * torch.pow(1 - gt, beta) * neg_inds neg_loss = (
torch.log(1 - pred + eps)
* torch.pow(pred, alpha)
* torch.pow(1 - gt, beta)
* neg_inds
)
if weights is not None: if weights is not None:
pos_loss = pos_loss*weights pos_loss = pos_loss * weights
#neg_loss = neg_loss*weights # neg_loss = neg_loss*weights
if valid_mask is not None: if valid_mask is not None:
pos_loss = pos_loss*valid_mask pos_loss = pos_loss * valid_mask
neg_loss = neg_loss*valid_mask neg_loss = neg_loss * valid_mask
pos_loss = pos_loss.sum() pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum() neg_loss = neg_loss.sum()
num_pos = pos_inds.float().sum() num_pos = pos_inds.float().sum()
if num_pos == 0: if num_pos == 0:
loss = -neg_loss loss = -neg_loss
else: else:
@ -47,10 +54,10 @@ def focal_loss(pred, gt, weights=None, valid_mask=None):
def mse_loss(pred, gt, weights=None, valid_mask=None): def mse_loss(pred, gt, weights=None, valid_mask=None):
""" """
Mean squared error loss. Mean squared error loss.
""" """
if valid_mask is None: if valid_mask is None:
op = ((gt-pred)**2).mean() op = ((gt - pred) ** 2).mean()
else: else:
op = (valid_mask*((gt-pred)**2)).sum() / valid_mask.sum() op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
return op return op

View File

@ -1,32 +1,27 @@
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
import json
import argparse import argparse
import json
import warnings
import sys import matplotlib.pyplot as plt
sys.path.append(os.path.join('..', '..')) import numpy as np
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
import bat_detect.detector.parameters as parameters from bat_detect.detector import models
import bat_detect.detector.models as models from bat_detect.detector import parameters
from bat_detect.train import losses
import bat_detect.detector.post_process as pp import bat_detect.detector.post_process as pp
import bat_detect.utils.plot_utils as pu
import bat_detect.train.audio_dataloader as adl import bat_detect.train.audio_dataloader as adl
import bat_detect.train.evaluate as evl import bat_detect.train.evaluate as evl
import bat_detect.train.train_utils as tu
import bat_detect.train.train_split as ts import bat_detect.train.train_split as ts
import bat_detect.train.losses as losses import bat_detect.train.train_utils as tu
import bat_detect.utils.plot_utils as pu
import warnings
warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=UserWarning)
def save_images_batch(model, data_loader, params): def save_images_batch(model, data_loader, params):
print('\nsaving images ...') print("\nsaving images ...")
is_train_state = data_loader.dataset.is_train is_train_state = data_loader.dataset.is_train
data_loader.dataset.is_train = False data_loader.dataset.is_train = False
@ -36,67 +31,112 @@ def save_images_batch(model, data_loader, params):
ind = 0 # first image in each batch ind = 0 # first image in each batch
with torch.no_grad(): with torch.no_grad():
for batch_idx, inputs in enumerate(data_loader): for batch_idx, inputs in enumerate(data_loader):
data = inputs['spec'].to(params['device']) data = inputs["spec"].to(params["device"])
outputs = model(data) outputs = model(data)
spec_viz = inputs['spec_for_viz'].data.cpu().numpy() spec_viz = inputs["spec_for_viz"].data.cpu().numpy()
orig_index = inputs['file_id'][ind] orig_index = inputs["file_id"][ind]
plot_title = data_loader.dataset.data_anns[orig_index]['id'] plot_title = data_loader.dataset.data_anns[orig_index]["id"]
op_file_name = params['op_im_dir_test'] + data_loader.dataset.data_anns[orig_index]['id'] + '.jpg' op_file_name = (
save_image(spec_viz, outputs, ind, inputs, params, op_file_name, plot_title) params["op_im_dir_test"]
+ data_loader.dataset.data_anns[orig_index]["id"]
+ ".jpg"
)
save_image(
spec_viz,
outputs,
ind,
inputs,
params,
op_file_name,
plot_title,
)
data_loader.dataset.is_train = is_train_state data_loader.dataset.is_train = is_train_state
data_loader.dataset.return_spec_for_viz = False data_loader.dataset.return_spec_for_viz = False
def save_image(spec_viz, outputs, ind, inputs, params, op_file_name, plot_title): def save_image(
pred_nms, _ = pp.run_nms(outputs, params, inputs['sampling_rate'].float()) spec_viz, outputs, ind, inputs, params, op_file_name, plot_title
pred_hm = outputs['pred_det'][ind, 0, :].data.cpu().numpy() ):
pred_nms, _ = pp.run_nms(outputs, params, inputs["sampling_rate"].float())
pred_hm = outputs["pred_det"][ind, 0, :].data.cpu().numpy()
spec_viz = spec_viz[ind, 0, :] spec_viz = spec_viz[ind, 0, :]
gt = parse_gt_data(inputs)[ind] gt = parse_gt_data(inputs)[ind]
sampling_rate = inputs['sampling_rate'][ind].item() sampling_rate = inputs["sampling_rate"][ind].item()
duration = inputs['duration'][ind].item() duration = inputs["duration"][ind].item()
pu.plot_spec(spec_viz, sampling_rate, duration, gt, pred_nms[ind], pu.plot_spec(
params, plot_title, op_file_name, pred_hm, plot_boxes=True, fixed_aspect=False) spec_viz,
sampling_rate,
duration,
gt,
pred_nms[ind],
params,
plot_title,
op_file_name,
pred_hm,
plot_boxes=True,
fixed_aspect=False,
)
def loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq): def loss_fun(
outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq
):
# detection loss # detection loss
loss = params['det_loss_weight']*det_criterion(outputs['pred_det'], gt_det) loss = params["det_loss_weight"] * det_criterion(
outputs["pred_det"], gt_det
)
# bounding box size loss # bounding box size loss
loss += params['size_loss_weight']*losses.bbox_size_loss(outputs['pred_size'], gt_size) loss += params["size_loss_weight"] * losses.bbox_size_loss(
outputs["pred_size"], gt_size
)
# classification loss # classification loss
valid_mask = (gt_class[:, :-1, :, :].sum(1) > 0).float().unsqueeze(1) valid_mask = (gt_class[:, :-1, :, :].sum(1) > 0).float().unsqueeze(1)
p_class = outputs['pred_class'][:, :-1, :] p_class = outputs["pred_class"][:, :-1, :]
loss += params['class_loss_weight']*det_criterion(p_class, gt_class[:, :-1, :], valid_mask=valid_mask) loss += params["class_loss_weight"] * det_criterion(
p_class, gt_class[:, :-1, :], valid_mask=valid_mask
)
return loss return loss
def train(model, epoch, data_loader, det_criterion, optimizer, scheduler, params): def train(
model, epoch, data_loader, det_criterion, optimizer, scheduler, params
):
model.train() model.train()
train_loss = tu.AverageMeter() train_loss = tu.AverageMeter()
class_inv_freq = torch.from_numpy(np.array(params['class_inv_freq'], dtype=np.float32)).to(params['device']) class_inv_freq = torch.from_numpy(
np.array(params["class_inv_freq"], dtype=np.float32)
).to(params["device"])
class_inv_freq = class_inv_freq.unsqueeze(0).unsqueeze(2).unsqueeze(2) class_inv_freq = class_inv_freq.unsqueeze(0).unsqueeze(2).unsqueeze(2)
print('\nEpoch', epoch) print("\nEpoch", epoch)
for batch_idx, inputs in enumerate(data_loader): for batch_idx, inputs in enumerate(data_loader):
data = inputs['spec'].to(params['device']) data = inputs["spec"].to(params["device"])
gt_det = inputs['y_2d_det'].to(params['device']) gt_det = inputs["y_2d_det"].to(params["device"])
gt_size = inputs['y_2d_size'].to(params['device']) gt_size = inputs["y_2d_size"].to(params["device"])
gt_class = inputs['y_2d_classes'].to(params['device']) gt_class = inputs["y_2d_classes"].to(params["device"])
optimizer.zero_grad() optimizer.zero_grad()
outputs = model(data) outputs = model(data)
loss = loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq) loss = loss_fun(
outputs,
gt_det,
gt_size,
gt_class,
det_criterion,
params,
class_inv_freq,
)
train_loss.update(loss.item(), data.shape[0]) train_loss.update(loss.item(), data.shape[0])
loss.backward() loss.backward()
@ -104,13 +144,18 @@ def train(model, epoch, data_loader, det_criterion, optimizer, scheduler, params
scheduler.step() scheduler.step()
if batch_idx % 50 == 0 and batch_idx != 0: if batch_idx % 50 == 0 and batch_idx != 0:
print('[{}/{}]\tLoss: {:.4f}'.format( print(
batch_idx * len(data), len(data_loader.dataset), train_loss.avg)) "[{}/{}]\tLoss: {:.4f}".format(
batch_idx * len(data),
len(data_loader.dataset),
train_loss.avg,
)
)
print('Train loss : {:.4f}'.format(train_loss.avg)) print("Train loss : {:.4f}".format(train_loss.avg))
res = {} res = {}
res['train_loss'] = float(train_loss.avg) res["train_loss"] = float(train_loss.avg)
return res return res
@ -120,16 +165,18 @@ def test(model, epoch, data_loader, det_criterion, params):
ground_truths = [] ground_truths = []
test_loss = tu.AverageMeter() test_loss = tu.AverageMeter()
class_inv_freq = torch.from_numpy(np.array(params['class_inv_freq'], dtype=np.float32)).to(params['device']) class_inv_freq = torch.from_numpy(
np.array(params["class_inv_freq"], dtype=np.float32)
).to(params["device"])
class_inv_freq = class_inv_freq.unsqueeze(0).unsqueeze(2).unsqueeze(2) class_inv_freq = class_inv_freq.unsqueeze(0).unsqueeze(2).unsqueeze(2)
with torch.no_grad(): with torch.no_grad():
for batch_idx, inputs in enumerate(data_loader): for batch_idx, inputs in enumerate(data_loader):
data = inputs['spec'].to(params['device']) data = inputs["spec"].to(params["device"])
gt_det = inputs['y_2d_det'].to(params['device']) gt_det = inputs["y_2d_det"].to(params["device"])
gt_size = inputs['y_2d_size'].to(params['device']) gt_size = inputs["y_2d_size"].to(params["device"])
gt_class = inputs['y_2d_classes'].to(params['device']) gt_class = inputs["y_2d_classes"].to(params["device"])
outputs = model(data) outputs = model(data)
@ -139,41 +186,79 @@ def test(model, epoch, data_loader, det_criterion, params):
# for kk in ['pred_det', 'pred_size', 'pred_class']: # for kk in ['pred_det', 'pred_size', 'pred_class']:
# outputs[kk] = torch.cat([oo for oo in outputs[kk]], 2).unsqueeze(0) # outputs[kk] = torch.cat([oo for oo in outputs[kk]], 2).unsqueeze(0)
if params['save_test_image_during_train'] and batch_idx == 0: if params["save_test_image_during_train"] and batch_idx == 0:
# for visualization - save the first prediction # for visualization - save the first prediction
ind = 0 ind = 0
orig_index = inputs['file_id'][ind] orig_index = inputs["file_id"][ind]
plot_title = data_loader.dataset.data_anns[orig_index]['id'] plot_title = data_loader.dataset.data_anns[orig_index]["id"]
op_file_name = params['op_im_dir'] + str(orig_index.item()).zfill(4) + '_' + str(epoch).zfill(4) + '_pred.jpg' op_file_name = (
save_image(data, outputs, ind, inputs, params, op_file_name, plot_title) params["op_im_dir"]
+ str(orig_index.item()).zfill(4)
+ "_"
+ str(epoch).zfill(4)
+ "_pred.jpg"
)
save_image(
data,
outputs,
ind,
inputs,
params,
op_file_name,
plot_title,
)
loss = loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq) loss = loss_fun(
outputs,
gt_det,
gt_size,
gt_class,
det_criterion,
params,
class_inv_freq,
)
test_loss.update(loss.item(), data.shape[0]) test_loss.update(loss.item(), data.shape[0])
# do NMS # do NMS
pred_nms, _ = pp.run_nms(outputs, params, inputs['sampling_rate'].float()) pred_nms, _ = pp.run_nms(
outputs, params, inputs["sampling_rate"].float()
)
predictions.extend(pred_nms) predictions.extend(pred_nms)
ground_truths.extend(parse_gt_data(inputs)) ground_truths.extend(parse_gt_data(inputs))
res_det = evl.evaluate_predictions(ground_truths, predictions, params['class_names'], res_det = evl.evaluate_predictions(
params['detection_overlap'], params['ignore_start_end']) ground_truths,
predictions,
params["class_names"],
params["detection_overlap"],
params["ignore_start_end"],
)
print('\nTest loss : {:.4f}'.format(test_loss.avg)) print("\nTest loss : {:.4f}".format(test_loss.avg))
print('Rec at 0.95 (det) : {:.4f}'.format(res_det['rec_at_x'])) print("Rec at 0.95 (det) : {:.4f}".format(res_det["rec_at_x"]))
print('Avg prec (cls) : {:.4f}'.format(res_det['avg_prec'])) print("Avg prec (cls) : {:.4f}".format(res_det["avg_prec"]))
print('File acc (cls) : {:.2f} - for {} out of {}'.format(res_det['file_acc'], print(
res_det['num_valid_files'], res_det['num_total_files'])) "File acc (cls) : {:.2f} - for {} out of {}".format(
print('Cls Avg prec (cls) : {:.4f}'.format(res_det['avg_prec_class'])) res_det["file_acc"],
res_det["num_valid_files"],
res_det["num_total_files"],
)
)
print("Cls Avg prec (cls) : {:.4f}".format(res_det["avg_prec_class"]))
print('\nPer class average precision') print("\nPer class average precision")
str_len = np.max([len(rs['name']) for rs in res_det['class_pr']]) + 5 str_len = np.max([len(rs["name"]) for rs in res_det["class_pr"]]) + 5
for cc, rs in enumerate(res_det['class_pr']): for cc, rs in enumerate(res_det["class_pr"]):
if rs['num_gt'] > 0: if rs["num_gt"] > 0:
print(str(cc).ljust(5) + rs['name'].ljust(str_len) + '{:.4f}'.format(rs['avg_prec'])) print(
str(cc).ljust(5)
+ rs["name"].ljust(str_len)
+ "{:.4f}".format(rs["avg_prec"])
)
res = {} res = {}
res['test_loss'] = float(test_loss.avg) res["test_loss"] = float(test_loss.avg)
return res_det, res return res_det, res
@ -181,176 +266,287 @@ def test(model, epoch, data_loader, det_criterion, params):
def parse_gt_data(inputs): def parse_gt_data(inputs):
# reads the torch arrays into a dictionary of numpy arrays, taking care to # reads the torch arrays into a dictionary of numpy arrays, taking care to
# remove padding data i.e. not valid ones # remove padding data i.e. not valid ones
keys = ['start_times', 'end_times', 'low_freqs', 'high_freqs', 'class_ids', 'individual_ids'] keys = [
"start_times",
"end_times",
"low_freqs",
"high_freqs",
"class_ids",
"individual_ids",
]
batch_data = [] batch_data = []
for ind in range(inputs['start_times'].shape[0]): for ind in range(inputs["start_times"].shape[0]):
is_valid = inputs['is_valid'][ind]==1 is_valid = inputs["is_valid"][ind] == 1
gt = {} gt = {}
for kk in keys: for kk in keys:
gt[kk] = inputs[kk][ind][is_valid].numpy().astype(np.float32) gt[kk] = inputs[kk][ind][is_valid].numpy().astype(np.float32)
gt['duration'] = inputs['duration'][ind].item() gt["duration"] = inputs["duration"][ind].item()
gt['file_id'] = inputs['file_id'][ind].item() gt["file_id"] = inputs["file_id"][ind].item()
gt['class_id_file'] = inputs['class_id_file'][ind].item() gt["class_id_file"] = inputs["class_id_file"][ind].item()
batch_data.append(gt) batch_data.append(gt)
return batch_data return batch_data
def select_model(params): def select_model(params):
num_classes = len(params['class_names']) num_classes = len(params["class_names"])
if params['model_name'] == 'Net2DFast': if params["model_name"] == "Net2DFast":
model = models.Net2DFast(params['num_filters'], num_classes=num_classes, model = models.Net2DFast(
emb_dim=params['emb_dim'], ip_height=params['ip_height'], params["num_filters"],
resize_factor=params['resize_factor']) num_classes=num_classes,
elif params['model_name'] == 'Net2DFastNoAttn': emb_dim=params["emb_dim"],
model = models.Net2DFastNoAttn(params['num_filters'], num_classes=num_classes, ip_height=params["ip_height"],
emb_dim=params['emb_dim'], ip_height=params['ip_height'], resize_factor=params["resize_factor"],
resize_factor=params['resize_factor']) )
elif params['model_name'] == 'Net2DFastNoCoordConv': elif params["model_name"] == "Net2DFastNoAttn":
model = models.Net2DFastNoCoordConv(params['num_filters'], num_classes=num_classes, model = models.Net2DFastNoAttn(
emb_dim=params['emb_dim'], ip_height=params['ip_height'], params["num_filters"],
resize_factor=params['resize_factor']) num_classes=num_classes,
emb_dim=params["emb_dim"],
ip_height=params["ip_height"],
resize_factor=params["resize_factor"],
)
elif params["model_name"] == "Net2DFastNoCoordConv":
model = models.Net2DFastNoCoordConv(
params["num_filters"],
num_classes=num_classes,
emb_dim=params["emb_dim"],
ip_height=params["ip_height"],
resize_factor=params["resize_factor"],
)
else: else:
print('No valid network specified') print("No valid network specified")
return model return model
if __name__ == "__main__": if __name__ == "__main__":
plt.close('all') plt.close("all")
params = parameters.get_params(True) params = parameters.get_params(True)
if torch.cuda.is_available(): if torch.cuda.is_available():
params['device'] = 'cuda' params["device"] = "cuda"
else: else:
params['device'] = 'cpu' params["device"] = "cpu"
# setup arg parser and populate it with exiting parameters - will not work with lists # setup arg parser and populate it with exiting parameters - will not work with lists
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('data_dir', type=str, parser.add_argument("data_dir", type=str, help="Path to root of datasets")
help='Path to root of datasets') parser.add_argument(
parser.add_argument('ann_dir', type=str, "ann_dir", type=str, help="Path to extracted annotations"
help='Path to extracted annotations') )
parser.add_argument('--train_split', type=str, default='diff', # diff, same parser.add_argument(
help='Which train split to use') "--train_split",
parser.add_argument('--notes', type=str, default='', type=str,
help='Notes to save in text file') default="diff", # diff, same
parser.add_argument('--do_not_save_images', action='store_false', help="Which train split to use",
help='Do not save images at the end of training') )
parser.add_argument('--standardize_classs_names_ip', type=str, parser.add_argument(
default='Rhinolophus ferrumequinum;Rhinolophus hipposideros', "--notes", type=str, default="", help="Notes to save in text file"
help='Will set low and high frequency the same for these classes. Separate names with ";"') )
parser.add_argument(
"--do_not_save_images",
action="store_false",
help="Do not save images at the end of training",
)
parser.add_argument(
"--standardize_classs_names_ip",
type=str,
default="Rhinolophus ferrumequinum;Rhinolophus hipposideros",
help='Will set low and high frequency the same for these classes. Separate names with ";"',
)
for key, val in params.items(): for key, val in params.items():
parser.add_argument('--'+key, type=type(val), default=val) parser.add_argument("--" + key, type=type(val), default=val)
params = vars(parser.parse_args()) params = vars(parser.parse_args())
# save notes file # save notes file
if params['notes'] != '': if params["notes"] != "":
tu.write_notes_file(params['experiment'] + 'notes.txt', params['notes']) tu.write_notes_file(params["experiment"] + "notes.txt", params["notes"])
# load the training and test meta data - there are different splits defined # load the training and test meta data - there are different splits defined
train_sets, test_sets = ts.get_train_test_data(params['ann_dir'], params['data_dir'], params['train_split']) train_sets, test_sets = ts.get_train_test_data(
train_sets_no_path, test_sets_no_path = ts.get_train_test_data('', '', params['train_split']) params["ann_dir"], params["data_dir"], params["train_split"]
)
train_sets_no_path, test_sets_no_path = ts.get_train_test_data(
"", "", params["train_split"]
)
# keep track of what we have trained on # keep track of what we have trained on
params['train_sets'] = train_sets_no_path params["train_sets"] = train_sets_no_path
params['test_sets'] = test_sets_no_path params["test_sets"] = test_sets_no_path
# load train annotations - merge them all together # load train annotations - merge them all together
print('\nTraining on:') print("\nTraining on:")
for tt in train_sets: for tt in train_sets:
print(tt['ann_path']) print(tt["ann_path"])
classes_to_ignore = params['classes_to_ignore']+params['generic_class'] classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
data_train, params['class_names'], params['class_inv_freq'] = \ (
tu.load_set_of_anns(train_sets, classes_to_ignore, params['events_of_interest'], params['convert_to_genus']) data_train,
params['genus_names'], params['genus_mapping'] = tu.get_genus_mapping(params['class_names']) params["class_names"],
params['class_names_short'] = tu.get_short_class_names(params['class_names']) params["class_inv_freq"],
) = tu.load_set_of_anns(
train_sets,
classes_to_ignore,
params["events_of_interest"],
params["convert_to_genus"],
)
params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
params["class_names"]
)
params["class_names_short"] = tu.get_short_class_names(
params["class_names"]
)
# standardize the low and high frequency value for specified classes # standardize the low and high frequency value for specified classes
params['standardize_classs_names'] = params['standardize_classs_names_ip'].split(';') params["standardize_classs_names"] = params[
for cc in params['standardize_classs_names']: "standardize_classs_names_ip"
if cc in params['class_names']: ].split(";")
for cc in params["standardize_classs_names"]:
if cc in params["class_names"]:
data_train = tu.standardize_low_freq(data_train, cc) data_train = tu.standardize_low_freq(data_train, cc)
else: else:
print(cc, 'not found') print(cc, "not found")
# train loader # train loader
train_dataset = adl.AudioLoader(data_train, params, is_train=True) train_dataset = adl.AudioLoader(data_train, params, is_train=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params['batch_size'], train_loader = torch.utils.data.DataLoader(
shuffle=True, num_workers=params['num_workers'], pin_memory=True) train_dataset,
batch_size=params["batch_size"],
shuffle=True,
num_workers=params["num_workers"],
pin_memory=True,
)
# test set # test set
print('\nTesting on:') print("\nTesting on:")
for tt in test_sets: for tt in test_sets:
print(tt['ann_path']) print(tt["ann_path"])
data_test, _, _ = tu.load_set_of_anns(test_sets, classes_to_ignore, params['events_of_interest'], params['convert_to_genus']) data_test, _, _ = tu.load_set_of_anns(
test_sets,
classes_to_ignore,
params["events_of_interest"],
params["convert_to_genus"],
)
data_train = tu.remove_dupes(data_train, data_test) data_train = tu.remove_dupes(data_train, data_test)
test_dataset = adl.AudioLoader(data_test, params, is_train=False) test_dataset = adl.AudioLoader(data_test, params, is_train=False)
# batch size of 1 because of variable file length # batch size of 1 because of variable file length
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, test_loader = torch.utils.data.DataLoader(
shuffle=False, num_workers=params['num_workers'], pin_memory=True) test_dataset,
batch_size=1,
shuffle=False,
num_workers=params["num_workers"],
pin_memory=True,
)
inputs_train = next(iter(train_loader)) inputs_train = next(iter(train_loader))
# TODO remove params['ip_height'], this is just legacy # TODO remove params['ip_height'], this is just legacy
params['ip_height'] = int(params['spec_height']*params['resize_factor']) params["ip_height"] = int(params["spec_height"] * params["resize_factor"])
print('\ntrain batch spec size :', inputs_train['spec'].shape) print("\ntrain batch spec size :", inputs_train["spec"].shape)
print('class target size :', inputs_train['y_2d_classes'].shape) print("class target size :", inputs_train["y_2d_classes"].shape)
# select network # select network
model = select_model(params) model = select_model(params)
model = model.to(params['device']) model = model.to(params["device"])
optimizer = torch.optim.Adam(model.parameters(), lr=params['lr']) optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
#optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=0.9) # optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=0.9)
scheduler = CosineAnnealingLR(optimizer, params['num_epochs'] * len(train_loader)) scheduler = CosineAnnealingLR(
if params['train_loss'] == 'mse': optimizer, params["num_epochs"] * len(train_loader)
)
if params["train_loss"] == "mse":
det_criterion = losses.mse_loss det_criterion = losses.mse_loss
elif params['train_loss'] == 'focal': elif params["train_loss"] == "focal":
det_criterion = losses.focal_loss det_criterion = losses.focal_loss
# save parameters to file # save parameters to file
with open(params['experiment'] + 'params.json', 'w') as da: with open(params["experiment"] + "params.json", "w") as da:
json.dump(params, da, indent=2, sort_keys=True) json.dump(params, da, indent=2, sort_keys=True)
# plotting # plotting
train_plt_ls = pu.LossPlotter(params['experiment'] + 'train_loss.png', params['num_epochs']+1, train_plt_ls = pu.LossPlotter(
['train_loss'], None, None, ['epoch', 'train_loss'], logy=True) params["experiment"] + "train_loss.png",
test_plt_ls = pu.LossPlotter(params['experiment'] + 'test_loss.png', params['num_epochs']+1, params["num_epochs"] + 1,
['test_loss'], None, None, ['epoch', 'test_loss'], logy=True) ["train_loss"],
test_plt = pu.LossPlotter(params['experiment'] + 'test.png', params['num_epochs']+1, None,
['avg_prec', 'rec_at_x', 'avg_prec_class', 'file_acc', 'top_class'], [0,1], None, ['epoch', '']) None,
test_plt_class = pu.LossPlotter(params['experiment'] + 'test_avg_prec.png', params['num_epochs']+1, ["epoch", "train_loss"],
params['class_names_short'], [0,1], params['class_names_short'], ['epoch', 'avg_prec']) logy=True,
)
test_plt_ls = pu.LossPlotter(
params["experiment"] + "test_loss.png",
params["num_epochs"] + 1,
["test_loss"],
None,
None,
["epoch", "test_loss"],
logy=True,
)
test_plt = pu.LossPlotter(
params["experiment"] + "test.png",
params["num_epochs"] + 1,
["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"],
[0, 1],
None,
["epoch", ""],
)
test_plt_class = pu.LossPlotter(
params["experiment"] + "test_avg_prec.png",
params["num_epochs"] + 1,
params["class_names_short"],
[0, 1],
params["class_names_short"],
["epoch", "avg_prec"],
)
# #
# main train loop # main train loop
for epoch in range(0, params['num_epochs']+1): for epoch in range(0, params["num_epochs"] + 1):
train_loss = train(model, epoch, train_loader, det_criterion, optimizer, scheduler, params) train_loss = train(
train_plt_ls.update_and_save(epoch, [train_loss['train_loss']]) model,
epoch,
train_loader,
det_criterion,
optimizer,
scheduler,
params,
)
train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]])
if epoch % params['num_eval_epochs'] == 0: if epoch % params["num_eval_epochs"] == 0:
# detection accuracy on test set # detection accuracy on test set
test_res, test_loss = test(model, epoch, test_loader, det_criterion, params) test_res, test_loss = test(
test_plt_ls.update_and_save(epoch, [test_loss['test_loss']]) model, epoch, test_loader, det_criterion, params
test_plt.update_and_save(epoch, [test_res['avg_prec'], test_res['rec_at_x'], )
test_res['avg_prec_class'], test_res['file_acc'], test_res['top_class']['avg_prec']]) test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]])
test_plt_class.update_and_save(epoch, [rs['avg_prec'] for rs in test_res['class_pr']]) test_plt.update_and_save(
pu.plot_pr_curve_class(params['experiment'] , 'test_pr', 'test_pr', test_res) epoch,
[
test_res["avg_prec"],
test_res["rec_at_x"],
test_res["avg_prec_class"],
test_res["file_acc"],
test_res["top_class"]["avg_prec"],
],
)
test_plt_class.update_and_save(
epoch, [rs["avg_prec"] for rs in test_res["class_pr"]]
)
pu.plot_pr_curve_class(
params["experiment"], "test_pr", "test_pr", test_res
)
# save trained model # save trained model
print('saving model to: ' + params['model_file_name']) print("saving model to: " + params["model_file_name"])
op_state = {'epoch': epoch + 1, op_state = {
'state_dict': model.state_dict(), "epoch": epoch + 1,
#'optimizer' : optimizer.state_dict(), "state_dict": model.state_dict(),
'params' : params} #'optimizer' : optimizer.state_dict(),
torch.save(op_state, params['model_file_name']) "params": params,
}
torch.save(op_state, params["model_file_name"])
# save an image with associated prediction for each batch in the test set # save an image with associated prediction for each batch in the test set
if not args['do_not_save_images']: # TODO: args variable does not exist
save_images_batch(model, test_loader, params) # if not args["do_not_save_images"]:
# save_images_batch(model, test_loader, params)

View File

@ -2,13 +2,14 @@
Run scripts/extract_anns.py to generate these json files. Run scripts/extract_anns.py to generate these json files.
""" """
def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True): def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):
if split_name == 'diff': if split_name == "diff":
train_sets, test_sets = split_diff(ann_dir, wav_dir, load_extra) train_sets, test_sets = split_diff(ann_dir, wav_dir, load_extra)
elif split_name == 'same': elif split_name == "same":
train_sets, test_sets = split_same(ann_dir, wav_dir, load_extra) train_sets, test_sets = split_same(ann_dir, wav_dir, load_extra)
else: else:
print('Split not defined') print("Split not defined")
assert False assert False
return train_sets, test_sets return train_sets, test_sets
@ -18,73 +19,126 @@ def split_diff(ann_dir, wav_dir, load_extra=True):
train_sets = [] train_sets = []
if load_extra: if load_extra:
train_sets.append({'dataset_name': 'BatDetective', train_sets.append(
'is_test': False, {
'is_binary': True, # just a bat / not bat dataset ie no classes "dataset_name": "BatDetective",
'ann_path': ann_dir + 'train_set_bulgaria_batdetective_with_bbs.json', "is_test": False,
'wav_path': wav_dir + 'bat_detective/audio/'}) "is_binary": True, # just a bat / not bat dataset ie no classes
train_sets.append({'dataset_name': 'bat_logger_qeop_empty', "ann_path": ann_dir
'is_test': False, + "train_set_bulgaria_batdetective_with_bbs.json",
'is_binary': True, "wav_path": wav_dir + "bat_detective/audio/",
'ann_path': ann_dir + 'bat_logger_qeop_empty.json', }
'wav_path': wav_dir + 'bat_logger_qeop_empty/audio/'}) )
train_sets.append({'dataset_name': 'bat_logger_2016_empty', train_sets.append(
'is_test': False, {
'is_binary': True, "dataset_name": "bat_logger_qeop_empty",
'ann_path': ann_dir + 'train_set_bat_logger_2016_empty.json', "is_test": False,
'wav_path': wav_dir + 'bat_logger_2016/audio/'}) "is_binary": True,
"ann_path": ann_dir + "bat_logger_qeop_empty.json",
"wav_path": wav_dir + "bat_logger_qeop_empty/audio/",
}
)
train_sets.append(
{
"dataset_name": "bat_logger_2016_empty",
"is_test": False,
"is_binary": True,
"ann_path": ann_dir + "train_set_bat_logger_2016_empty.json",
"wav_path": wav_dir + "bat_logger_2016/audio/",
}
)
# train_sets.append({'dataset_name': 'brazil_data_binary', # train_sets.append({'dataset_name': 'brazil_data_binary',
# 'is_test': False, # 'is_test': False,
# 'ann_path': ann_dir + 'brazil_data_binary.json', # 'ann_path': ann_dir + 'brazil_data_binary.json',
# 'wav_path': wav_dir + 'brazil_data/audio/'}) # 'wav_path': wav_dir + 'brazil_data/audio/'})
train_sets.append({'dataset_name': 'echobank', train_sets.append(
'is_test': False, {
'is_binary': False, "dataset_name": "echobank",
'ann_path': ann_dir + 'Echobank_train_expert.json', "is_test": False,
'wav_path': wav_dir + 'echobank/audio/'}) "is_binary": False,
train_sets.append({'dataset_name': 'sn_scot_nor', "ann_path": ann_dir + "Echobank_train_expert.json",
'is_test': False, "wav_path": wav_dir + "echobank/audio/",
'is_binary': False, }
'ann_path': ann_dir + 'sn_scot_nor_0.5_expert.json', )
'wav_path': wav_dir + 'sn_scot_nor/audio/'}) train_sets.append(
train_sets.append({'dataset_name': 'BCT_1_sec', {
'is_test': False, "dataset_name": "sn_scot_nor",
'is_binary': False, "is_test": False,
'ann_path': ann_dir + 'BCT_1_sec_train_expert.json', "is_binary": False,
'wav_path': wav_dir + 'BCT_1_sec/audio/'}) "ann_path": ann_dir + "sn_scot_nor_0.5_expert.json",
train_sets.append({'dataset_name': 'bcireland', "wav_path": wav_dir + "sn_scot_nor/audio/",
'is_test': False, }
'is_binary': False, )
'ann_path': ann_dir + 'bcireland_expert.json', train_sets.append(
'wav_path': wav_dir + 'bcireland/audio/'}) {
train_sets.append({'dataset_name': 'rhinolophus_steve_BCT', "dataset_name": "BCT_1_sec",
'is_test': False, "is_test": False,
'is_binary': False, "is_binary": False,
'ann_path': ann_dir + 'rhinolophus_steve_BCT_expert.json', "ann_path": ann_dir + "BCT_1_sec_train_expert.json",
'wav_path': wav_dir + 'rhinolophus_steve_BCT/audio/'}) "wav_path": wav_dir + "BCT_1_sec/audio/",
}
)
train_sets.append(
{
"dataset_name": "bcireland",
"is_test": False,
"is_binary": False,
"ann_path": ann_dir + "bcireland_expert.json",
"wav_path": wav_dir + "bcireland/audio/",
}
)
train_sets.append(
{
"dataset_name": "rhinolophus_steve_BCT",
"is_test": False,
"is_binary": False,
"ann_path": ann_dir + "rhinolophus_steve_BCT_expert.json",
"wav_path": wav_dir + "rhinolophus_steve_BCT/audio/",
}
)
test_sets = [] test_sets = []
test_sets.append({'dataset_name': 'bat_data_martyn_2018', test_sets.append(
'is_test': True, {
'is_binary': False, "dataset_name": "bat_data_martyn_2018",
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_train_expert.json', "is_test": True,
'wav_path': wav_dir + 'bat_data_martyn_2018/audio/'}) "is_binary": False,
test_sets.append({'dataset_name': 'bat_data_martyn_2018_test', "ann_path": ann_dir
'is_test': True, + "BritishBatCalls_MartynCooke_2018_1_sec_train_expert.json",
'is_binary': False, "wav_path": wav_dir + "bat_data_martyn_2018/audio/",
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_test_expert.json', }
'wav_path': wav_dir + 'bat_data_martyn_2018_test/audio/'}) )
test_sets.append({'dataset_name': 'bat_data_martyn_2019', test_sets.append(
'is_test': True, {
'is_binary': False, "dataset_name": "bat_data_martyn_2018_test",
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_train_expert.json', "is_test": True,
'wav_path': wav_dir + 'bat_data_martyn_2019/audio/'}) "is_binary": False,
test_sets.append({'dataset_name': 'bat_data_martyn_2019_test', "ann_path": ann_dir
'is_test': True, + "BritishBatCalls_MartynCooke_2018_1_sec_test_expert.json",
'is_binary': False, "wav_path": wav_dir + "bat_data_martyn_2018_test/audio/",
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_test_expert.json', }
'wav_path': wav_dir + 'bat_data_martyn_2019_test/audio/'}) )
test_sets.append(
{
"dataset_name": "bat_data_martyn_2019",
"is_test": True,
"is_binary": False,
"ann_path": ann_dir
+ "BritishBatCalls_MartynCooke_2019_1_sec_train_expert.json",
"wav_path": wav_dir + "bat_data_martyn_2019/audio/",
}
)
test_sets.append(
{
"dataset_name": "bat_data_martyn_2019_test",
"is_test": True,
"is_binary": False,
"ann_path": ann_dir
+ "BritishBatCalls_MartynCooke_2019_1_sec_test_expert.json",
"wav_path": wav_dir + "bat_data_martyn_2019_test/audio/",
}
)
return train_sets, test_sets return train_sets, test_sets
@ -93,71 +147,124 @@ def split_same(ann_dir, wav_dir, load_extra=True):
train_sets = [] train_sets = []
if load_extra: if load_extra:
train_sets.append({'dataset_name': 'BatDetective', train_sets.append(
'is_test': False, {
'is_binary': True, "dataset_name": "BatDetective",
'ann_path': ann_dir + 'train_set_bulgaria_batdetective_with_bbs.json', "is_test": False,
'wav_path': wav_dir + 'bat_detective/audio/'}) "is_binary": True,
train_sets.append({'dataset_name': 'bat_logger_qeop_empty', "ann_path": ann_dir
'is_test': False, + "train_set_bulgaria_batdetective_with_bbs.json",
'is_binary': True, "wav_path": wav_dir + "bat_detective/audio/",
'ann_path': ann_dir + 'bat_logger_qeop_empty.json', }
'wav_path': wav_dir + 'bat_logger_qeop_empty/audio/'}) )
train_sets.append({'dataset_name': 'bat_logger_2016_empty', train_sets.append(
'is_test': False, {
'is_binary': True, "dataset_name": "bat_logger_qeop_empty",
'ann_path': ann_dir + 'train_set_bat_logger_2016_empty.json', "is_test": False,
'wav_path': wav_dir + 'bat_logger_2016/audio/'}) "is_binary": True,
"ann_path": ann_dir + "bat_logger_qeop_empty.json",
"wav_path": wav_dir + "bat_logger_qeop_empty/audio/",
}
)
train_sets.append(
{
"dataset_name": "bat_logger_2016_empty",
"is_test": False,
"is_binary": True,
"ann_path": ann_dir + "train_set_bat_logger_2016_empty.json",
"wav_path": wav_dir + "bat_logger_2016/audio/",
}
)
# train_sets.append({'dataset_name': 'brazil_data_binary', # train_sets.append({'dataset_name': 'brazil_data_binary',
# 'is_test': False, # 'is_test': False,
# 'ann_path': ann_dir + 'brazil_data_binary.json', # 'ann_path': ann_dir + 'brazil_data_binary.json',
# 'wav_path': wav_dir + 'brazil_data/audio/'}) # 'wav_path': wav_dir + 'brazil_data/audio/'})
train_sets.append({'dataset_name': 'echobank', train_sets.append(
'is_test': False, {
'is_binary': False, "dataset_name": "echobank",
'ann_path': ann_dir + 'Echobank_train_expert_TRAIN.json', "is_test": False,
'wav_path': wav_dir + 'echobank/audio/'}) "is_binary": False,
train_sets.append({'dataset_name': 'sn_scot_nor', "ann_path": ann_dir + "Echobank_train_expert_TRAIN.json",
'is_test': False, "wav_path": wav_dir + "echobank/audio/",
'is_binary': False, }
'ann_path': ann_dir + 'sn_scot_nor_0.5_expert_TRAIN.json', )
'wav_path': wav_dir + 'sn_scot_nor/audio/'}) train_sets.append(
train_sets.append({'dataset_name': 'BCT_1_sec', {
'is_test': False, "dataset_name": "sn_scot_nor",
'is_binary': False, "is_test": False,
'ann_path': ann_dir + 'BCT_1_sec_train_expert_TRAIN.json', "is_binary": False,
'wav_path': wav_dir + 'BCT_1_sec/audio/'}) "ann_path": ann_dir + "sn_scot_nor_0.5_expert_TRAIN.json",
train_sets.append({'dataset_name': 'bcireland', "wav_path": wav_dir + "sn_scot_nor/audio/",
'is_test': False, }
'is_binary': False, )
'ann_path': ann_dir + 'bcireland_expert_TRAIN.json', train_sets.append(
'wav_path': wav_dir + 'bcireland/audio/'}) {
train_sets.append({'dataset_name': 'rhinolophus_steve_BCT', "dataset_name": "BCT_1_sec",
'is_test': False, "is_test": False,
'is_binary': False, "is_binary": False,
'ann_path': ann_dir + 'rhinolophus_steve_BCT_expert_TRAIN.json', "ann_path": ann_dir + "BCT_1_sec_train_expert_TRAIN.json",
'wav_path': wav_dir + 'rhinolophus_steve_BCT/audio/'}) "wav_path": wav_dir + "BCT_1_sec/audio/",
train_sets.append({'dataset_name': 'bat_data_martyn_2018', }
'is_test': False, )
'is_binary': False, train_sets.append(
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TRAIN.json', {
'wav_path': wav_dir + 'bat_data_martyn_2018/audio/'}) "dataset_name": "bcireland",
train_sets.append({'dataset_name': 'bat_data_martyn_2018_test', "is_test": False,
'is_test': False, "is_binary": False,
'is_binary': False, "ann_path": ann_dir + "bcireland_expert_TRAIN.json",
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TRAIN.json', "wav_path": wav_dir + "bcireland/audio/",
'wav_path': wav_dir + 'bat_data_martyn_2018_test/audio/'}) }
train_sets.append({'dataset_name': 'bat_data_martyn_2019', )
'is_test': False, train_sets.append(
'is_binary': False, {
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TRAIN.json', "dataset_name": "rhinolophus_steve_BCT",
'wav_path': wav_dir + 'bat_data_martyn_2019/audio/'}) "is_test": False,
train_sets.append({'dataset_name': 'bat_data_martyn_2019_test', "is_binary": False,
'is_test': False, "ann_path": ann_dir + "rhinolophus_steve_BCT_expert_TRAIN.json",
'is_binary': False, "wav_path": wav_dir + "rhinolophus_steve_BCT/audio/",
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TRAIN.json', }
'wav_path': wav_dir + 'bat_data_martyn_2019_test/audio/'}) )
train_sets.append(
{
"dataset_name": "bat_data_martyn_2018",
"is_test": False,
"is_binary": False,
"ann_path": ann_dir
+ "BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TRAIN.json",
"wav_path": wav_dir + "bat_data_martyn_2018/audio/",
}
)
train_sets.append(
{
"dataset_name": "bat_data_martyn_2018_test",
"is_test": False,
"is_binary": False,
"ann_path": ann_dir
+ "BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TRAIN.json",
"wav_path": wav_dir + "bat_data_martyn_2018_test/audio/",
}
)
train_sets.append(
{
"dataset_name": "bat_data_martyn_2019",
"is_test": False,
"is_binary": False,
"ann_path": ann_dir
+ "BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TRAIN.json",
"wav_path": wav_dir + "bat_data_martyn_2019/audio/",
}
)
train_sets.append(
{
"dataset_name": "bat_data_martyn_2019_test",
"is_test": False,
"is_binary": False,
"ann_path": ann_dir
+ "BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TRAIN.json",
"wav_path": wav_dir + "bat_data_martyn_2019_test/audio/",
}
)
# train_sets.append({'dataset_name': 'bat_data_martyn_2021_train', # train_sets.append({'dataset_name': 'bat_data_martyn_2021_train',
# 'is_test': False, # 'is_test': False,
@ -171,51 +278,91 @@ def split_same(ann_dir, wav_dir, load_extra=True):
# 'wav_path': wav_dir + 'volunteers_2021/audio/'}) # 'wav_path': wav_dir + 'volunteers_2021/audio/'})
test_sets = [] test_sets = []
test_sets.append({'dataset_name': 'echobank', test_sets.append(
'is_test': True, {
'is_binary': False, "dataset_name": "echobank",
'ann_path': ann_dir + 'Echobank_train_expert_TEST.json', "is_test": True,
'wav_path': wav_dir + 'echobank/audio/'}) "is_binary": False,
test_sets.append({'dataset_name': 'sn_scot_nor', "ann_path": ann_dir + "Echobank_train_expert_TEST.json",
'is_test': True, "wav_path": wav_dir + "echobank/audio/",
'is_binary': False, }
'ann_path': ann_dir + 'sn_scot_nor_0.5_expert_TEST.json', )
'wav_path': wav_dir + 'sn_scot_nor/audio/'}) test_sets.append(
test_sets.append({'dataset_name': 'BCT_1_sec', {
'is_test': True, "dataset_name": "sn_scot_nor",
'is_binary': False, "is_test": True,
'ann_path': ann_dir + 'BCT_1_sec_train_expert_TEST.json', "is_binary": False,
'wav_path': wav_dir + 'BCT_1_sec/audio/'}) "ann_path": ann_dir + "sn_scot_nor_0.5_expert_TEST.json",
test_sets.append({'dataset_name': 'bcireland', "wav_path": wav_dir + "sn_scot_nor/audio/",
'is_test': True, }
'is_binary': False, )
'ann_path': ann_dir + 'bcireland_expert_TEST.json', test_sets.append(
'wav_path': wav_dir + 'bcireland/audio/'}) {
test_sets.append({'dataset_name': 'rhinolophus_steve_BCT', "dataset_name": "BCT_1_sec",
'is_test': True, "is_test": True,
'is_binary': False, "is_binary": False,
'ann_path': ann_dir + 'rhinolophus_steve_BCT_expert_TEST.json', "ann_path": ann_dir + "BCT_1_sec_train_expert_TEST.json",
'wav_path': wav_dir + 'rhinolophus_steve_BCT/audio/'}) "wav_path": wav_dir + "BCT_1_sec/audio/",
test_sets.append({'dataset_name': 'bat_data_martyn_2018', }
'is_test': True, )
'is_binary': False, test_sets.append(
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TEST.json', {
'wav_path': wav_dir + 'bat_data_martyn_2018/audio/'}) "dataset_name": "bcireland",
test_sets.append({'dataset_name': 'bat_data_martyn_2018_test', "is_test": True,
'is_test': True, "is_binary": False,
'is_binary': False, "ann_path": ann_dir + "bcireland_expert_TEST.json",
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TEST.json', "wav_path": wav_dir + "bcireland/audio/",
'wav_path': wav_dir + 'bat_data_martyn_2018_test/audio/'}) }
test_sets.append({'dataset_name': 'bat_data_martyn_2019', )
'is_test': True, test_sets.append(
'is_binary': False, {
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TEST.json', "dataset_name": "rhinolophus_steve_BCT",
'wav_path': wav_dir + 'bat_data_martyn_2019/audio/'}) "is_test": True,
test_sets.append({'dataset_name': 'bat_data_martyn_2019_test', "is_binary": False,
'is_test': True, "ann_path": ann_dir + "rhinolophus_steve_BCT_expert_TEST.json",
'is_binary': False, "wav_path": wav_dir + "rhinolophus_steve_BCT/audio/",
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TEST.json', }
'wav_path': wav_dir + 'bat_data_martyn_2019_test/audio/'}) )
test_sets.append(
{
"dataset_name": "bat_data_martyn_2018",
"is_test": True,
"is_binary": False,
"ann_path": ann_dir
+ "BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TEST.json",
"wav_path": wav_dir + "bat_data_martyn_2018/audio/",
}
)
test_sets.append(
{
"dataset_name": "bat_data_martyn_2018_test",
"is_test": True,
"is_binary": False,
"ann_path": ann_dir
+ "BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TEST.json",
"wav_path": wav_dir + "bat_data_martyn_2018_test/audio/",
}
)
test_sets.append(
{
"dataset_name": "bat_data_martyn_2019",
"is_test": True,
"is_binary": False,
"ann_path": ann_dir
+ "BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TEST.json",
"wav_path": wav_dir + "bat_data_martyn_2019/audio/",
}
)
test_sets.append(
{
"dataset_name": "bat_data_martyn_2019_test",
"is_test": True,
"is_binary": False,
"ann_path": ann_dir
+ "BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TEST.json",
"wav_path": wav_dir + "bat_data_martyn_2019_test/audio/",
}
)
# test_sets.append({'dataset_name': 'bat_data_martyn_2021_test', # test_sets.append({'dataset_name': 'bat_data_martyn_2021_test',
# 'is_test': True, # 'is_test': True,

View File

@ -1,42 +1,52 @@
import numpy as np
import random
import os
import glob import glob
import json import json
import os
import random
import numpy as np
def write_notes_file(file_name, text): def write_notes_file(file_name, text):
with open(file_name, 'a') as da: with open(file_name, "a") as da:
da.write(text + '\n') da.write(text + "\n")
def get_blank_dataset_dict(dataset_name, is_test, ann_path, wav_path): def get_blank_dataset_dict(dataset_name, is_test, ann_path, wav_path):
ddict = {'dataset_name': dataset_name, 'is_test': is_test, 'is_binary': False, ddict = {
'ann_path': ann_path, 'wav_path': wav_path} "dataset_name": dataset_name,
"is_test": is_test,
"is_binary": False,
"ann_path": ann_path,
"wav_path": wav_path,
}
return ddict return ddict
def get_short_class_names(class_names, str_len=3): def get_short_class_names(class_names, str_len=3):
class_names_short = [] class_names_short = []
for cc in class_names: for cc in class_names:
class_names_short.append(' '.join([sp[:str_len] for sp in cc.split(' ')])) class_names_short.append(
" ".join([sp[:str_len] for sp in cc.split(" ")])
)
return class_names_short return class_names_short
def remove_dupes(data_train, data_test): def remove_dupes(data_train, data_test):
test_ids = [dd['id'] for dd in data_test] test_ids = [dd["id"] for dd in data_test]
data_train_prune = [] data_train_prune = []
for aa in data_train: for aa in data_train:
if aa['id'] not in test_ids: if aa["id"] not in test_ids:
data_train_prune.append(aa) data_train_prune.append(aa)
diff = len(data_train) - len(data_train_prune) diff = len(data_train) - len(data_train_prune)
if diff != 0: if diff != 0:
print(diff, 'items removed from train set') print(diff, "items removed from train set")
return data_train_prune return data_train_prune
def get_genus_mapping(class_names): def get_genus_mapping(class_names):
genus_names, genus_mapping = np.unique([cc.split(' ')[0] for cc in class_names], return_inverse=True) genus_names, genus_mapping = np.unique(
[cc.split(" ")[0] for cc in class_names], return_inverse=True
)
return genus_names.tolist(), genus_mapping.tolist() return genus_names.tolist(), genus_mapping.tolist()
@ -47,97 +57,110 @@ def standardize_low_freq(data, class_of_interest):
low_freqs = [] low_freqs = []
high_freqs = [] high_freqs = []
for dd in data: for dd in data:
for aa in dd['annotation']: for aa in dd["annotation"]:
if aa['class'] == class_of_interest: if aa["class"] == class_of_interest:
low_freqs.append(aa['low_freq']) low_freqs.append(aa["low_freq"])
high_freqs.append(aa['high_freq']) high_freqs.append(aa["high_freq"])
low_mean = np.mean(low_freqs) low_mean = np.mean(low_freqs)
high_mean = np.mean(high_freqs) high_mean = np.mean(high_freqs)
assert(low_mean < high_mean) assert low_mean < high_mean
print('\nStandardizing low and high frequency for:') print("\nStandardizing low and high frequency for:")
print(class_of_interest) print(class_of_interest)
print('low: ', round(low_mean, 2)) print("low: ", round(low_mean, 2))
print('high: ', round(high_mean, 2)) print("high: ", round(high_mean, 2))
# only set the low freq, high stays the same # only set the low freq, high stays the same
# assumes that low_mean < high_mean # assumes that low_mean < high_mean
for dd in data: for dd in data:
for aa in dd['annotation']: for aa in dd["annotation"]:
if aa['class'] == class_of_interest: if aa["class"] == class_of_interest:
aa['low_freq'] = low_mean aa["low_freq"] = low_mean
if aa['high_freq'] < low_mean: if aa["high_freq"] < low_mean:
aa['high_freq'] = high_mean aa["high_freq"] = high_mean
return data return data
def load_set_of_anns(data, classes_to_ignore=[], events_of_interest=None, def load_set_of_anns(
convert_to_genus=False, verbose=True, list_of_anns=False, data,
filter_issues=False, name_replace=False): classes_to_ignore=[],
events_of_interest=None,
convert_to_genus=False,
verbose=True,
list_of_anns=False,
filter_issues=False,
name_replace=False,
):
# load the annotations # load the annotations
anns = [] anns = []
if list_of_anns: if list_of_anns:
# path to list of individual json files # path to list of individual json files
anns.extend(load_anns_from_path(data['ann_path'], data['wav_path'])) anns.extend(load_anns_from_path(data["ann_path"], data["wav_path"]))
else: else:
# dictionary of datasets # dictionary of datasets
for dd in data: for dd in data:
anns.extend(load_anns(dd['ann_path'], dd['wav_path'])) anns.extend(load_anns(dd["ann_path"], dd["wav_path"]))
# discarding unannoated files # discarding unannoated files
anns = [aa for aa in anns if aa['annotated'] is True] anns = [aa for aa in anns if aa["annotated"] is True]
# filter files that have annotation issues - is the input is a dictionary of # filter files that have annotation issues - is the input is a dictionary of
# datasets, this will lilely have already been done # datasets, this will lilely have already been done
if filter_issues: if filter_issues:
anns = [aa for aa in anns if aa['issues'] is False] anns = [aa for aa in anns if aa["issues"] is False]
# check for some basic formatting errors with class names # check for some basic formatting errors with class names
for ann in anns: for ann in anns:
for aa in ann['annotation']: for aa in ann["annotation"]:
aa['class'] = aa['class'].strip() aa["class"] = aa["class"].strip()
# only load specified events - i.e. types of calls # only load specified events - i.e. types of calls
if events_of_interest is not None: if events_of_interest is not None:
for ann in anns: for ann in anns:
filtered_events = [] filtered_events = []
for aa in ann['annotation']: for aa in ann["annotation"]:
if aa['event'] in events_of_interest: if aa["event"] in events_of_interest:
filtered_events.append(aa) filtered_events.append(aa)
ann['annotation'] = filtered_events ann["annotation"] = filtered_events
# change class names # change class names
# replace_names will be a dictionary mapping input name to output # replace_names will be a dictionary mapping input name to output
if type(name_replace) is dict: if type(name_replace) is dict:
for ann in anns: for ann in anns:
for aa in ann['annotation']: for aa in ann["annotation"]:
if aa['class'] in name_replace: if aa["class"] in name_replace:
aa['class'] = name_replace[aa['class']] aa["class"] = name_replace[aa["class"]]
# convert everything to genus name # convert everything to genus name
if convert_to_genus: if convert_to_genus:
for ann in anns: for ann in anns:
for aa in ann['annotation']: for aa in ann["annotation"]:
aa['class'] = aa['class'].split(' ')[0] aa["class"] = aa["class"].split(" ")[0]
# get unique class names # get unique class names
class_names_all = [] class_names_all = []
for ann in anns: for ann in anns:
for aa in ann['annotation']: for aa in ann["annotation"]:
if aa['class'] not in classes_to_ignore: if aa["class"] not in classes_to_ignore:
class_names_all.append(aa['class']) class_names_all.append(aa["class"])
class_names, class_cnts = np.unique(class_names_all, return_counts=True) class_names, class_cnts = np.unique(class_names_all, return_counts=True)
class_inv_freq = (class_cnts.sum() / (len(class_names) * class_cnts.astype(np.float32))) class_inv_freq = class_cnts.sum() / (
len(class_names) * class_cnts.astype(np.float32)
)
if verbose: if verbose:
print('Class count:') print("Class count:")
str_len = np.max([len(cc) for cc in class_names]) + 5 str_len = np.max([len(cc) for cc in class_names]) + 5
for cc in range(len(class_names)): for cc in range(len(class_names)):
print(str(cc).ljust(5) + class_names[cc].ljust(str_len) + str(class_cnts[cc])) print(
str(cc).ljust(5)
+ class_names[cc].ljust(str_len)
+ str(class_cnts[cc])
)
if len(classes_to_ignore) == 0: if len(classes_to_ignore) == 0:
return anns return anns
@ -150,36 +173,37 @@ def load_anns(ann_file_name, raw_audio_dir):
anns = json.load(da) anns = json.load(da)
for aa in anns: for aa in anns:
aa['file_path'] = raw_audio_dir + aa['id'] aa["file_path"] = raw_audio_dir + aa["id"]
return anns return anns
def load_anns_from_path(ann_file_dir, raw_audio_dir): def load_anns_from_path(ann_file_dir, raw_audio_dir):
files = glob.glob(ann_file_dir + '*.json') files = glob.glob(ann_file_dir + "*.json")
anns = [] anns = []
for ff in files: for ff in files:
with open(ff) as da: with open(ff) as da:
ann = json.load(da) ann = json.load(da)
ann['file_path'] = raw_audio_dir + ann['id'] ann["file_path"] = raw_audio_dir + ann["id"]
anns.append(ann) anns.append(ann)
return anns return anns
class AverageMeter(object): class AverageMeter(object):
"""Computes and stores the average and current value""" """Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self): def __init__(self):
self.val = 0 self.reset()
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1): def reset(self):
self.val = val self.val = 0
self.sum += val * n self.avg = 0
self.count += n self.sum = 0
self.avg = self.sum / self.count self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count

475
bat_detect/types.py Normal file
View File

@ -0,0 +1,475 @@
"""Types used in the code base."""
from typing import List, NamedTuple, Optional
import numpy as np
import torch
try:
from typing import TypedDict
except ImportError:
from typing_extensions import TypedDict
try:
from typing import Protocol
except ImportError:
from typing_extensions import Protocol
try:
from typing import NotRequired
except ImportError:
from typing_extensions import NotRequired
__all__ = [
"Annotation",
"DetectionModel",
"FileAnnotations",
"ModelOutput",
"ModelParameters",
"NonMaximumSuppressionConfig",
"PredictionResults",
"ProcessingConfiguration",
"ResultParams",
"RunResults",
"SpectrogramParameters",
]
class SpectrogramParameters(TypedDict):
"""Parameters for generating spectrograms."""
fft_win_length: float
"""Length of the FFT window in seconds."""
fft_overlap: float
"""Percentage of overlap between FFT windows."""
spec_height: int
"""Height of the spectrogram in pixels."""
resize_factor: float
"""Factor to resize the spectrogram by."""
spec_divide_factor: int
"""Factor to divide the spectrogram by."""
max_freq: int
"""Maximum frequency to display in the spectrogram."""
min_freq: int
"""Minimum frequency to display in the spectrogram."""
spec_scale: str
"""Scale to use for the spectrogram."""
denoise_spec_avg: bool
"""Whether to denoise the spectrogram by averaging."""
max_scale_spec: bool
"""Whether to scale the spectrogram so that its max is 1."""
class ModelParameters(TypedDict):
"""Model parameters."""
model_name: str
"""Model name."""
num_filters: int
"""Number of filters."""
emb_dim: int
"""Embedding dimension."""
ip_height: int
"""Input height in pixels."""
resize_factor: float
"""Resize factor."""
class_names: List[str]
"""Class names. The model is trained to detect these classes."""
DictWithClass = TypedDict("DictWithClass", {"class": str})
class Annotation(DictWithClass):
"""Format of annotations.
This is the format of a single annotation as expected by the annotation
tool.
"""
start_time: float
"""Start time in seconds."""
end_time: float
"""End time in seconds."""
low_freq: int
"""Low frequency in Hz."""
high_freq: int
"""High frequency in Hz."""
class_prob: float
"""Probability of class assignment."""
det_prob: float
"""Probability of detection."""
individual: str
"""Individual ID."""
event: str
"""Type of detected event."""
class FileAnnotations(TypedDict):
"""Format of results.
This is the format of the results expected by the annotation tool.
"""
id: str
"""File ID."""
annotated: bool
"""Whether file has been annotated."""
duration: float
"""Duration of audio file."""
issues: bool
"""Whether file has issues."""
time_exp: float
"""Time expansion factor."""
class_name: str
"""Class predicted at file level"""
notes: str
"""Notes of file."""
annotation: List[Annotation]
"""List of annotations."""
class RunResults(TypedDict):
"""Run results."""
pred_dict: FileAnnotations
"""Predictions in the format expected by the annotation tool."""
spec_feats: NotRequired[List[np.ndarray]]
"""Spectrogram features."""
spec_feat_names: NotRequired[List[str]]
"""Spectrogram feature names."""
cnn_feats: NotRequired[List[np.ndarray]]
"""CNN features."""
cnn_feat_names: NotRequired[List[str]]
"""CNN feature names."""
spec_slices: NotRequired[List[np.ndarray]]
"""Spectrogram slices."""
class ResultParams(TypedDict):
"""Result parameters."""
class_names: List[str]
"""Class names."""
spec_features: bool
"""Whether to return spectrogram features."""
cnn_features: bool
"""Whether to return CNN features."""
spec_slices: bool
"""Whether to return spectrogram slices."""
class ProcessingConfiguration(TypedDict):
"""Parameters for processing audio files."""
# audio parameters
target_samp_rate: int
"""Target sampling rate of the audio."""
fft_win_length: float
"""Length of the FFT window in seconds."""
fft_overlap: float
"""Length of the FFT window in samples."""
resize_factor: float
"""Factor to resize the spectrogram by."""
spec_divide_factor: int
"""Factor to divide the spectrogram by."""
spec_height: int
"""Height of the spectrogram in pixels."""
spec_scale: str
"""Scale to use for the spectrogram."""
denoise_spec_avg: bool
"""Whether to denoise the spectrogram by averaging."""
max_scale_spec: bool
"""Whether to scale the spectrogram so that its max is 1."""
scale_raw_audio: bool
"""Whether to scale the raw audio to be between -1 and 1."""
class_names: List[str]
"""Names of the classes the model can detect."""
detection_threshold: float
"""Threshold for detection probability."""
time_expansion: Optional[float]
"""Time expansion factor of the processed recordings."""
top_n: int
"""Number of top detections to keep."""
return_raw_preds: bool
"""Whether to return raw predictions."""
max_duration: Optional[float]
"""Maximum duration of audio file to process in seconds."""
nms_kernel_size: int
"""Size of the kernel for non-maximum suppression."""
max_freq: int
"""Maximum frequency to consider in Hz."""
min_freq: int
"""Minimum frequency to consider in Hz."""
nms_top_k_per_sec: float
"""Number of top detections to keep per second."""
quiet: bool
"""Whether to suppress output."""
chunk_size: float
"""Size of chunks to process in seconds."""
cnn_features: bool
"""Whether to return CNN features."""
spec_features: bool
"""Whether to return spectrogram features."""
spec_slices: bool
"""Whether to return spectrogram slices."""
class ModelOutput(NamedTuple):
"""Output of the detection model.
Each of the tensors has a shape of
`(batch_size, num_channels,spec_height, spec_width)`.
Where `spec_height` and `spec_width` are the height and width of the
input spectrograms.
They contain localised information of:
1. The probability of a bounding box detection at the given location.
2. The predicted size of the bounding box at the given location.
3. The probabilities of each class at the given location.
4. Same as 3. but before softmax.
5. Features used to make the predictions at the given location.
"""
pred_det: torch.Tensor
"""Tensor with predict detection probabilities."""
pred_size: torch.Tensor
"""Tensor with predicted bounding box sizes."""
pred_class: torch.Tensor
"""Tensor with predicted class probabilities."""
pred_class_un_norm: torch.Tensor
"""Tensor with predicted class probabilities before softmax."""
features: torch.Tensor
"""Tensor with intermediate features."""
class PredictionResults(TypedDict):
"""Results of the prediction.
Each key is a list of length `num_detections` containing the
corresponding values for each detection.
"""
det_probs: np.ndarray
"""Detection probabilities."""
x_pos: np.ndarray
"""X position of the detection in pixels."""
y_pos: np.ndarray
"""Y position of the detection in pixels."""
bb_width: np.ndarray
"""Width of the detection in pixels."""
bb_height: np.ndarray
"""Height of the detection in pixels."""
start_times: np.ndarray
"""Start times of the detections in seconds."""
end_times: np.ndarray
"""End times of the detections in seconds."""
low_freqs: np.ndarray
"""Low frequencies of the detections in Hz."""
high_freqs: np.ndarray
"""High frequencies of the detections in Hz."""
class_probs: np.ndarray
"""Class probabilities."""
class DetectionModel(Protocol):
"""Protocol for detection models.
This protocol is used to define the interface for the detection models.
This allows us to use the same code for training and inference, even
though the models are different.
"""
num_classes: int
"""Number of classes the model can classify."""
emb_dim: int
"""Dimension of the embedding vector."""
num_filts: int
"""Number of filters in the model."""
resize_factor: float
"""Factor by which the input is resized."""
ip_height_rs: int
"""Height of the input image."""
def forward(
self,
ip: torch.Tensor,
return_feats: bool = False,
) -> ModelOutput:
"""Forward pass of the model."""
...
def __call__(
self,
ip: torch.Tensor,
return_feats: bool = False,
) -> ModelOutput:
"""Forward pass of the model."""
...
class NonMaximumSuppressionConfig(TypedDict):
"""Configuration for non-maximum suppression."""
nms_kernel_size: int
"""Size of the kernel for non-maximum suppression."""
max_freq: int
"""Maximum frequency to consider in Hz."""
min_freq: int
"""Minimum frequency to consider in Hz."""
fft_win_length: float
"""Length of the FFT window in seconds."""
fft_overlap: float
"""Overlap of the FFT windows in seconds."""
resize_factor: float
"""Factor by which the input was resized."""
nms_top_k_per_sec: float
"""Number of top detections to keep per second."""
detection_threshold: float
"""Threshold for detection probability."""
class HeatmapParameters(TypedDict):
"""Parameters that control the heatmap generation function."""
class_names: List[str]
fft_win_length: float
"""Length of the FFT window in seconds."""
fft_overlap: float
"""Percentage of the FFT windows overlap."""
resize_factor: float
"""Factor by which the input was resized."""
min_freq: int
"""Minimum frequency to consider in Hz."""
max_freq: int
"""Maximum frequency to consider in Hz."""
target_sigma: float
"""Sigma for the Gaussian kernel. Controls the width of the points in
the heatmap."""
class AnnotationGroup(TypedDict):
"""Group of annotations.
Each key is a numpy array of length `num_annotations` containing the
corresponding values for each annotation.
"""
start_times: np.ndarray
"""Start times of the annotations in seconds."""
end_times: np.ndarray
"""End times of the annotations in seconds."""
low_freqs: np.ndarray
"""Low frequencies of the annotations in Hz."""
high_freqs: np.ndarray
"""High frequencies of the annotations in Hz."""
class_ids: np.ndarray
"""Class IDs of the annotations."""
individual_ids: np.ndarray
"""Individual IDs of the annotations."""
x_inds: NotRequired[np.ndarray]
"""X coordinate of the annotations in the spectrogram."""
y_inds: NotRequired[np.ndarray]
"""Y coordinate of the annotations in the spectrogram."""

View File

@ -1,91 +1,207 @@
import numpy as np
from . import wavfile
import warnings import warnings
import torch from typing import Optional, Tuple
import librosa import librosa
import librosa.core.spectrum
import numpy as np
import torch
from bat_detect.detector.parameters import (
DENOISE_SPEC_AVG,
DETECTION_THRESHOLD,
FFT_OVERLAP,
FFT_WIN_LENGTH_S,
MAX_FREQ_HZ,
MAX_SCALE_SPEC,
MIN_FREQ_HZ,
NMS_KERNEL_SIZE,
NMS_TOP_K_PER_SEC,
RESIZE_FACTOR,
SCALE_RAW_AUDIO,
SPEC_DIVIDE_FACTOR,
SPEC_HEIGHT,
SPEC_SCALE,
)
from . import wavfile
try:
from typing import TypedDict
except ImportError:
from typing_extensions import TypedDict
__all__ = [
"load_audio",
"generate_spectrogram",
"pad_audio",
"SpectrogramParameters",
"DEFAULT_SPECTROGRAM_PARAMETERS",
]
def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap): def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
nfft = np.floor(fft_win_length*sampling_rate) # int() uses floor nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor
noverlap = np.floor(fft_overlap*nfft) noverlap = np.floor(fft_overlap * nfft)
return (time_in_file*sampling_rate-noverlap) / (nfft - noverlap) return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)
# NOTE this is also defined in post_process # NOTE this is also defined in post_process
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap): def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
nfft = np.floor(fft_win_length*sampling_rate) nfft = np.floor(fft_win_length * sampling_rate)
noverlap = np.floor(fft_overlap*nfft) noverlap = np.floor(fft_overlap * nfft)
return ((x_pos*(nfft - noverlap)) + noverlap) / sampling_rate return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
#return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
def generate_spectrogram(audio, sampling_rate, params, return_spec_for_viz=False, check_spec_size=True): def generate_spectrogram(
audio,
sampling_rate,
params,
return_spec_for_viz=False,
check_spec_size=True,
):
# generate spectrogram # generate spectrogram
spec = gen_mag_spectrogram(audio, sampling_rate, params['fft_win_length'], params['fft_overlap']) spec = gen_mag_spectrogram(
audio,
sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
)
# crop to min/max freq # crop to min/max freq
max_freq = round(params['max_freq']*params['fft_win_length']) max_freq = round(params["max_freq"] * params["fft_win_length"])
min_freq = round(params['min_freq']*params['fft_win_length']) min_freq = round(params["min_freq"] * params["fft_win_length"])
if spec.shape[0] < max_freq: if spec.shape[0] < max_freq:
freq_pad = max_freq - spec.shape[0] freq_pad = max_freq - spec.shape[0]
spec = np.vstack((np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec)) spec = np.vstack(
spec_cropped = spec[-max_freq:spec.shape[0]-min_freq, :] (np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec)
)
spec_cropped = spec[-max_freq : spec.shape[0] - min_freq, :]
if params['spec_scale'] == 'log': if params["spec_scale"] == "log":
log_scaling = 2.0 * (1.0 / sampling_rate) * (1.0/(np.abs(np.hanning(int(params['fft_win_length']*sampling_rate)))**2).sum()) log_scaling = (
#log_scaling = (1.0 / sampling_rate)*0.1 2.0
#log_scaling = (1.0 / sampling_rate)*10e4 * (1.0 / sampling_rate)
spec = np.log1p(log_scaling*spec_cropped) * (
elif params['spec_scale'] == 'pcen': 1.0
/ (
np.abs(
np.hanning(
int(params["fft_win_length"] * sampling_rate)
)
)
** 2
).sum()
)
)
# log_scaling = (1.0 / sampling_rate)*0.1
# log_scaling = (1.0 / sampling_rate)*10e4
spec = np.log1p(log_scaling * spec_cropped)
elif params["spec_scale"] == "pcen":
spec = pcen(spec_cropped, sampling_rate) spec = pcen(spec_cropped, sampling_rate)
elif params['spec_scale'] == 'none':
elif params["spec_scale"] == "none":
pass pass
if params['denoise_spec_avg']: if params["denoise_spec_avg"]:
spec = spec - np.mean(spec, 1)[:, np.newaxis] spec = spec - np.mean(spec, 1)[:, np.newaxis]
spec.clip(min=0, out=spec) spec.clip(min=0, out=spec)
if params['max_scale_spec']: if params["max_scale_spec"]:
spec = spec / (spec.max() + 10e-6) spec = spec / (spec.max() + 10e-6)
# needs to be divisible by specific factor - if not it should have been padded # needs to be divisible by specific factor - if not it should have been padded
#if check_spec_size: # if check_spec_size:
#assert((int(spec.shape[0]*params['resize_factor']) % params['spec_divide_factor']) == 0) # assert((int(spec.shape[0]*params['resize_factor']) % params['spec_divide_factor']) == 0)
#assert((int(spec.shape[1]*params['resize_factor']) % params['spec_divide_factor']) == 0) # assert((int(spec.shape[1]*params['resize_factor']) % params['spec_divide_factor']) == 0)
# for visualization purposes - use log scaled spectrogram # for visualization purposes - use log scaled spectrogram
if return_spec_for_viz: if return_spec_for_viz:
log_scaling = 2.0 * (1.0 / sampling_rate) * (1.0/(np.abs(np.hanning(int(params['fft_win_length']*sampling_rate)))**2).sum()) log_scaling = (
spec_for_viz = np.log1p(log_scaling*spec_cropped).astype(np.float32) 2.0
* (1.0 / sampling_rate)
* (
1.0
/ (
np.abs(
np.hanning(
int(params["fft_win_length"] * sampling_rate)
)
)
** 2
).sum()
)
)
spec_for_viz = np.log1p(log_scaling * spec_cropped).astype(np.float32)
else: else:
spec_for_viz = None spec_for_viz = None
return spec, spec_for_viz return spec, spec_for_viz
def load_audio_file(audio_file, time_exp_fact, target_samp_rate, scale=False, max_duration=False): def load_audio(
audio_file: str,
time_exp_fact: float,
target_samp_rate: int,
scale: bool = False,
max_duration: Optional[float] = None,
) -> Tuple[int, np.ndarray]:
"""Load an audio file and resample it to the target sampling rate.
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
Only mono files are supported.
Args:
audio_file (str): Path to the audio file.
target_samp_rate (int): Target sampling rate.
scale (bool): Whether to scale the audio to [-1, 1].
max_duration (float): Maximum duration of the audio in seconds.
Returns:
sampling_rate: The sampling rate of the audio.
audio_raw: The audio signal in a numpy array.
Raises:
ValueError: If the audio file is stereo.
"""
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=wavfile.WavFileWarning) warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
#sampling_rate, audio_raw = wavfile.read(audio_file) # sampling_rate, audio_raw = wavfile.read(audio_file)
audio_raw, sampling_rate = librosa.load(audio_file, sr=None) audio_raw, sampling_rate = librosa.load(
audio_file,
sr=None,
dtype=np.float32,
)
if len(audio_raw.shape) > 1: if len(audio_raw.shape) > 1:
raise Exception('Currently does not handle stereo files') raise ValueError("Currently does not handle stereo files")
sampling_rate = sampling_rate * time_exp_fact sampling_rate = sampling_rate * time_exp_fact
# resample - need to do this after correcting for time expansion # resample - need to do this after correcting for time expansion
sampling_rate_old = sampling_rate sampling_rate_old = sampling_rate
sampling_rate = target_samp_rate sampling_rate = target_samp_rate
audio_raw = librosa.resample(audio_raw, orig_sr=sampling_rate_old, target_sr=sampling_rate, res_type='polyphase') if sampling_rate_old != sampling_rate:
audio_raw = librosa.resample(
audio_raw,
orig_sr=sampling_rate_old,
target_sr=sampling_rate,
res_type="polyphase",
)
# clipping maximum duration # clipping maximum duration
if max_duration is not False: if max_duration is not None:
max_duration = np.minimum(int(sampling_rate*max_duration), audio_raw.shape[0]) max_duration = int(
np.minimum(
int(sampling_rate * max_duration),
audio_raw.shape[0],
)
)
audio_raw = audio_raw[:max_duration] audio_raw = audio_raw[:max_duration]
# convert to float32 and scale # scale to [-1, 1]
audio_raw = audio_raw.astype(np.float32)
if scale: if scale:
audio_raw = audio_raw - audio_raw.mean() audio_raw = audio_raw - audio_raw.mean()
audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6) audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
@ -93,38 +209,53 @@ def load_audio_file(audio_file, time_exp_fact, target_samp_rate, scale=False, ma
return sampling_rate, audio_raw return sampling_rate, audio_raw
def pad_audio(audio_raw, fs, ms, overlap_perc, resize_factor, divide_factor, fixed_width=None): def pad_audio(
audio_raw,
fs,
ms,
overlap_perc,
resize_factor,
divide_factor,
fixed_width=None,
):
# Adds zeros to the end of the raw data so that the generated sepctrogram # Adds zeros to the end of the raw data so that the generated sepctrogram
# will be evenly divisible by `divide_factor` # will be evenly divisible by `divide_factor`
# Also deals with very short audio clips and fixed_width during training # Also deals with very short audio clips and fixed_width during training
# This code could be clearer, clean up # This code could be clearer, clean up
nfft = int(ms*fs) nfft = int(ms * fs)
noverlap = int(overlap_perc*nfft) noverlap = int(overlap_perc * nfft)
step = nfft - noverlap step = nfft - noverlap
min_size = int(divide_factor*(1.0/resize_factor)) min_size = int(divide_factor * (1.0 / resize_factor))
spec_width = ((audio_raw.shape[0]-noverlap)//step) spec_width = (audio_raw.shape[0] - noverlap) // step
spec_width_rs = spec_width * resize_factor spec_width_rs = spec_width * resize_factor
if fixed_width is not None and spec_width < fixed_width: if fixed_width is not None and spec_width < fixed_width:
# too small # too small
# used during training to ensure all the batches are the same size # used during training to ensure all the batches are the same size
diff = fixed_width*step + noverlap - audio_raw.shape[0] diff = fixed_width * step + noverlap - audio_raw.shape[0]
audio_raw = np.hstack((audio_raw, np.zeros(diff, dtype=audio_raw.dtype))) audio_raw = np.hstack(
(audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
)
elif fixed_width is not None and spec_width > fixed_width: elif fixed_width is not None and spec_width > fixed_width:
# too big # too big
# used during training to ensure all the batches are the same size # used during training to ensure all the batches are the same size
diff = fixed_width*step + noverlap - audio_raw.shape[0] diff = fixed_width * step + noverlap - audio_raw.shape[0]
audio_raw = audio_raw[:diff] audio_raw = audio_raw[:diff]
elif spec_width_rs < min_size or (np.floor(spec_width_rs) % divide_factor) != 0: elif (
spec_width_rs < min_size
or (np.floor(spec_width_rs) % divide_factor) != 0
):
# need to be at least min_size # need to be at least min_size
div_amt = np.ceil(spec_width_rs / float(divide_factor)) div_amt = np.ceil(spec_width_rs / float(divide_factor))
div_amt = np.maximum(1, div_amt) div_amt = np.maximum(1, div_amt)
target_size = int(div_amt*divide_factor*(1.0/resize_factor)) target_size = int(div_amt * divide_factor * (1.0 / resize_factor))
diff = target_size*step + noverlap - audio_raw.shape[0] diff = target_size * step + noverlap - audio_raw.shape[0]
audio_raw = np.hstack((audio_raw, np.zeros(diff, dtype=audio_raw.dtype))) audio_raw = np.hstack(
(audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
)
return audio_raw return audio_raw
@ -133,14 +264,16 @@ def gen_mag_spectrogram(x, fs, ms, overlap_perc):
# Computes magnitude spectrogram by specifying time. # Computes magnitude spectrogram by specifying time.
x = x.astype(np.float32) x = x.astype(np.float32)
nfft = int(ms*fs) nfft = int(ms * fs)
noverlap = int(overlap_perc*nfft) noverlap = int(overlap_perc * nfft)
# window data # window data
step = nfft - noverlap step = nfft - noverlap
# compute spec # compute spec
spec, _ = librosa.core.spectrum._spectrogram(y=x, power=1, n_fft=nfft, hop_length=step, center=False) spec, _ = librosa.core.spectrum._spectrogram(
y=x, power=1, n_fft=nfft, hop_length=step, center=False
)
# remove DC component and flip vertical orientation # remove DC component and flip vertical orientation
spec = np.flipud(spec[1:, :]) spec = np.flipud(spec[1:, :])
@ -149,8 +282,8 @@ def gen_mag_spectrogram(x, fs, ms, overlap_perc):
def gen_mag_spectrogram_pt(x, fs, ms, overlap_perc): def gen_mag_spectrogram_pt(x, fs, ms, overlap_perc):
nfft = int(ms*fs) nfft = int(ms * fs)
nstep = round((1.0-overlap_perc)*nfft) nstep = round((1.0 - overlap_perc) * nfft)
han_win = torch.hann_window(nfft, periodic=False).to(x.device) han_win = torch.hann_window(nfft, periodic=False).to(x.device)
@ -158,12 +291,14 @@ def gen_mag_spectrogram_pt(x, fs, ms, overlap_perc):
spec = complex_spec.pow(2.0).sum(-1) spec = complex_spec.pow(2.0).sum(-1)
# remove DC component and flip vertically # remove DC component and flip vertically
spec = torch.flipud(spec[0, 1:,:]) spec = torch.flipud(spec[0, 1:, :])
return spec return spec
def pcen(spec_cropped, sampling_rate): def pcen(spec_cropped, sampling_rate):
# TODO should be passing hop_length too i.e. step # TODO should be passing hop_length too i.e. step
spec = librosa.pcen(spec_cropped * (2**31), sr=sampling_rate/10).astype(np.float32) spec = librosa.pcen(spec_cropped * (2**31), sr=sampling_rate / 10).astype(
np.float32
)
return spec return spec

File diff suppressed because it is too large Load Diff

View File

@ -1,63 +1,107 @@
import numpy as np
import matplotlib.pyplot as plt
import json import json
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import patches from matplotlib import patches
from matplotlib.collections import PatchCollection from matplotlib.collections import PatchCollection
from sklearn.metrics import confusion_matrix
from . import audio_utils as au
def create_box_image(spec, fig, detections_ip, start_time, end_time, duration, params, max_val, hide_axis=True, plot_class_names=False): def create_box_image(
spec,
fig,
detections_ip,
start_time,
end_time,
duration,
params,
max_val,
hide_axis=True,
plot_class_names=False,
):
# filter detections # filter detections
stop_time = start_time + duration stop_time = start_time + duration
detections = [] detections = []
for bb in detections_ip: for bb in detections_ip:
if (bb['start_time'] >= start_time) and (bb['start_time'] < stop_time-0.02): #(bb['end_time'] < end_time): if (bb["start_time"] >= start_time) and (
bb["start_time"] < stop_time - 0.02
): # (bb['end_time'] < end_time):
detections.append(bb) detections.append(bb)
# create figure # create figure
freq_scale = 1000 # turn Hz to kHz freq_scale = 1000 # turn Hz to kHz
min_freq = params['min_freq']//freq_scale min_freq = params["min_freq"] // freq_scale
max_freq = params['max_freq']//freq_scale max_freq = params["max_freq"] // freq_scale
y_extent = [0, duration, min_freq, max_freq] y_extent = [0, duration, min_freq, max_freq]
if hide_axis: if hide_axis:
ax = plt.Axes(fig, [0., 0., 1., 1.]) ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
ax.set_axis_off() ax.set_axis_off()
fig.add_axes(ax) fig.add_axes(ax)
else: else:
ax = plt.gca() ax = plt.gca()
plt.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent, vmin=0, vmax=max_val) plt.imshow(
spec,
aspect="auto",
cmap="plasma",
extent=y_extent,
vmin=0,
vmax=max_val,
)
boxes = plot_bounding_box_patch_ann(detections, freq_scale, start_time) boxes = plot_bounding_box_patch_ann(detections, freq_scale, start_time)
ax.add_collection(PatchCollection(boxes, match_original=True)) ax.add_collection(PatchCollection(boxes, match_original=True))
plt.grid(False) plt.grid(False)
if plot_class_names: if plot_class_names:
for ii, bb in enumerate(boxes): for ii, bb in enumerate(boxes):
txt = ' '.join([sp[:3] for sp in detections_ip[ii]['class'].split(' ')]) txt = " ".join(
font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()} [sp[:3] for sp in detections_ip[ii]["class"].split(" ")]
)
font_info = {
"color": "white",
"size": 10,
"weight": "bold",
"alpha": bb.get_alpha(),
}
y_pos = bb.get_xy()[1] + bb.get_height() y_pos = bb.get_xy()[1] + bb.get_height()
if y_pos > (max_freq - 10): if y_pos > (max_freq - 10):
y_pos = max_freq - 10 y_pos = max_freq - 10
plt.gca().text(bb.get_xy()[0], y_pos, txt, fontdict=font_info) plt.gca().text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)
def save_ann_spec(op_path, spec, min_freq, max_freq, duration, start_time, title_text='', anns=None): def save_ann_spec(
op_path,
spec,
min_freq,
max_freq,
duration,
start_time,
title_text="",
anns=None,
):
# create figure and plot boxes # create figure and plot boxes
freq_scale = 1000 # turn Hz to kHz freq_scale = 1000 # turn Hz to kHz
min_freq = min_freq//freq_scale min_freq = min_freq // freq_scale
max_freq = max_freq//freq_scale max_freq = max_freq // freq_scale
y_extent = [0, duration, min_freq, max_freq] y_extent = [0, duration, min_freq, max_freq]
plt.close('all') plt.close("all")
fig = plt.figure(0, figsize=(spec.shape[1]/100, spec.shape[0]/100), dpi=100) fig = plt.figure(
plt.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent, vmin=0, vmax=spec.max()*1.1) 0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100
)
plt.imshow(
spec,
aspect="auto",
cmap="plasma",
extent=y_extent,
vmin=0,
vmax=spec.max() * 1.1,
)
plt.ylabel('Freq - kHz') plt.ylabel("Freq - kHz")
plt.xlabel('Time - secs') plt.xlabel("Time - secs")
if title_text != '': if title_text != "":
plt.title(title_text) plt.title(title_text)
plt.tight_layout() plt.tight_layout()
@ -66,122 +110,185 @@ def save_ann_spec(op_path, spec, min_freq, max_freq, duration, start_time, title
boxes = plot_bounding_box_patch_ann(anns, freq_scale, start_time) boxes = plot_bounding_box_patch_ann(anns, freq_scale, start_time)
plt.gca().add_collection(PatchCollection(boxes, match_original=True)) plt.gca().add_collection(PatchCollection(boxes, match_original=True))
for ii, bb in enumerate(boxes): for ii, bb in enumerate(boxes):
txt = ' '.join([sp[:3] for sp in anns[ii]['class'].split(' ')]) txt = " ".join([sp[:3] for sp in anns[ii]["class"].split(" ")])
font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()} font_info = {
"color": "white",
"size": 10,
"weight": "bold",
"alpha": bb.get_alpha(),
}
y_pos = bb.get_xy()[1] + bb.get_height() y_pos = bb.get_xy()[1] + bb.get_height()
if y_pos > (max_freq - 10): if y_pos > (max_freq - 10):
y_pos = max_freq - 10 y_pos = max_freq - 10
plt.gca().text(bb.get_xy()[0], y_pos, txt, fontdict=font_info) plt.gca().text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)
print('Saving figure to:', op_path) print("Saving figure to:", op_path)
plt.savefig(op_path) plt.savefig(op_path)
def plot_pts(fig_id, feats, class_names, colors, marker_size=4.0, plot_legend=False): def plot_pts(
fig_id, feats, class_names, colors, marker_size=4.0, plot_legend=False
):
plt.figure(fig_id) plt.figure(fig_id)
un_class, labels = np.unique(class_names, return_inverse=True) un_class, labels = np.unique(class_names, return_inverse=True)
un_labels = np.unique(labels) un_labels = np.unique(labels)
if un_labels.shape[0] > len(colors): if un_labels.shape[0] > len(colors):
colors = [plt.cm.jet(float(ii)/un_labels.shape[0]) for ii in un_labels] colors = [
plt.cm.jet(float(ii) / un_labels.shape[0]) for ii in un_labels
]
for ii, u in enumerate(un_labels): for ii, u in enumerate(un_labels):
inds = np.where(labels==u)[0] inds = np.where(labels == u)[0]
plt.scatter(feats[inds, 0], feats[inds, 1], c=colors[ii], label=str(un_class[ii]), s=marker_size) plt.scatter(
feats[inds, 0],
feats[inds, 1],
c=colors[ii],
label=str(un_class[ii]),
s=marker_size,
)
if plot_legend: if plot_legend:
plt.legend() plt.legend()
plt.xticks([]) plt.xticks([])
plt.yticks([]) plt.yticks([])
plt.title('downsampled features') plt.title("downsampled features")
def plot_bounding_box_patch(pred, freq_scale, ecolor='w'): def plot_bounding_box_patch(pred, freq_scale, ecolor="w"):
patch_collect = [] patch_collect = []
for bb in range(len(pred['start_times'])): for bb in range(len(pred["start_times"])):
xx = pred['start_times'][bb] xx = pred["start_times"][bb]
ww = pred['end_times'][bb] - pred['start_times'][bb] ww = pred["end_times"][bb] - pred["start_times"][bb]
yy = pred['low_freqs'][bb] / freq_scale yy = pred["low_freqs"][bb] / freq_scale
hh = (pred['high_freqs'][bb] - pred['low_freqs'][bb]) / freq_scale hh = (pred["high_freqs"][bb] - pred["low_freqs"][bb]) / freq_scale
if 'det_probs' in pred.keys(): if "det_probs" in pred.keys():
alpha_val = pred['det_probs'][bb] alpha_val = pred["det_probs"][bb]
else: else:
alpha_val = 1.0 alpha_val = 1.0
patch_collect.append(patches.Rectangle((xx, yy), ww, hh, linewidth=1, patch_collect.append(
edgecolor=ecolor, facecolor='none', alpha=alpha_val)) patches.Rectangle(
(xx, yy),
ww,
hh,
linewidth=1,
edgecolor=ecolor,
facecolor="none",
alpha=alpha_val,
)
)
return patch_collect return patch_collect
def plot_bounding_box_patch_ann(anns, freq_scale, start_time): def plot_bounding_box_patch_ann(anns, freq_scale, start_time):
patch_collect = [] patch_collect = []
for aa in range(len(anns)): for aa in range(len(anns)):
xx = anns[aa]['start_time'] - start_time xx = anns[aa]["start_time"] - start_time
ww = anns[aa]['end_time'] - anns[aa]['start_time'] ww = anns[aa]["end_time"] - anns[aa]["start_time"]
yy = anns[aa]['low_freq'] / freq_scale yy = anns[aa]["low_freq"] / freq_scale
hh = (anns[aa]['high_freq'] - anns[aa]['low_freq']) / freq_scale hh = (anns[aa]["high_freq"] - anns[aa]["low_freq"]) / freq_scale
if 'det_prob' in anns[aa]: if "det_prob" in anns[aa]:
alpha = anns[aa]['det_prob'] alpha = anns[aa]["det_prob"]
else: else:
alpha = 1.0 alpha = 1.0
patch_collect.append(patches.Rectangle((xx,yy), ww, hh, linewidth=1, patch_collect.append(
edgecolor='w', facecolor='none', alpha=alpha)) patches.Rectangle(
(xx, yy),
ww,
hh,
linewidth=1,
edgecolor="w",
facecolor="none",
alpha=alpha,
)
)
return patch_collect return patch_collect
def plot_spec(spec, sampling_rate, duration, gt, pred, params, plot_title, def plot_spec(
op_file_name, pred_2d_hm, plot_boxes=True, fixed_aspect=True): spec,
sampling_rate,
duration,
gt,
pred,
params,
plot_title,
op_file_name,
pred_2d_hm,
plot_boxes=True,
fixed_aspect=True,
):
if fixed_aspect: if fixed_aspect:
# ouptut image will be this width irrespective of the duration of the audio file # ouptut image will be this width irrespective of the duration of the audio file
width = 12 width = 12
else: else:
width = 12*duration width = 12 * duration
fig = plt.figure(1, figsize=(width, 8)) fig = plt.figure(1, figsize=(width, 8))
ax0 = plt.axes([0.05, 0.65, 0.9, 0.30]) # l b w h ax0 = plt.axes([0.05, 0.65, 0.9, 0.30]) # l b w h
ax1 = plt.axes([0.05, 0.33, 0.9, 0.30]) ax1 = plt.axes([0.05, 0.33, 0.9, 0.30])
ax2 = plt.axes([0.05, 0.01, 0.9, 0.30]) ax2 = plt.axes([0.05, 0.01, 0.9, 0.30])
freq_scale = 1000 # turn Hz in kHz freq_scale = 1000 # turn Hz in kHz
#duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap']) # duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])
y_extent = [0, duration, params['min_freq']//freq_scale, params['max_freq']//freq_scale] y_extent = [
0,
duration,
params["min_freq"] // freq_scale,
params["max_freq"] // freq_scale,
]
# plot gt boxes # plot gt boxes
ax0.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent) ax0.imshow(spec, aspect="auto", cmap="plasma", extent=y_extent)
ax0.xaxis.set_ticklabels([]) ax0.xaxis.set_ticklabels([])
font_info = {'color': 'white', 'size': 12, 'weight': 'bold'} font_info = {"color": "white", "size": 12, "weight": "bold"}
ax0.text(0, params['min_freq']//freq_scale, 'Ground Truth', fontdict=font_info) ax0.text(
0, params["min_freq"] // freq_scale, "Ground Truth", fontdict=font_info
)
plt.grid(False) plt.grid(False)
if plot_boxes: if plot_boxes:
boxes = plot_bounding_box_patch(gt, freq_scale) boxes = plot_bounding_box_patch(gt, freq_scale)
ax0.add_collection(PatchCollection(boxes, match_original=True)) ax0.add_collection(PatchCollection(boxes, match_original=True))
for ii, bb in enumerate(boxes): for ii, bb in enumerate(boxes):
class_id = int(gt['class_ids'][ii]) class_id = int(gt["class_ids"][ii])
if class_id < 0: if class_id < 0:
txt = params['generic_class'][0] txt = params["generic_class"][0]
else: else:
txt = params['class_names_short'][class_id] txt = params["class_names_short"][class_id]
font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()} font_info = {
"color": "white",
"size": 10,
"weight": "bold",
"alpha": bb.get_alpha(),
}
y_pos = bb.get_xy()[1] + bb.get_height() y_pos = bb.get_xy()[1] + bb.get_height()
ax0.text(bb.get_xy()[0], y_pos, txt, fontdict=font_info) ax0.text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)
# plot predicted boxes # plot predicted boxes
ax1.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent) ax1.imshow(spec, aspect="auto", cmap="plasma", extent=y_extent)
ax1.xaxis.set_ticklabels([]) ax1.xaxis.set_ticklabels([])
font_info = {'color': 'white', 'size': 12, 'weight': 'bold'} font_info = {"color": "white", "size": 12, "weight": "bold"}
ax1.text(0, params['min_freq']//freq_scale, 'Prediction', fontdict=font_info) ax1.text(
0, params["min_freq"] // freq_scale, "Prediction", fontdict=font_info
)
plt.grid(False) plt.grid(False)
if plot_boxes: if plot_boxes:
boxes = plot_bounding_box_patch(pred, freq_scale) boxes = plot_bounding_box_patch(pred, freq_scale)
ax1.add_collection(PatchCollection(boxes, match_original=True)) ax1.add_collection(PatchCollection(boxes, match_original=True))
for ii, bb in enumerate(boxes): for ii, bb in enumerate(boxes):
if pred['class_probs'].shape[0] > len(params['class_names_short']): if pred["class_probs"].shape[0] > len(params["class_names_short"]):
class_id = pred['class_probs'][:-1, ii].argmax() class_id = pred["class_probs"][:-1, ii].argmax()
else: else:
class_id = pred['class_probs'][:, ii].argmax() class_id = pred["class_probs"][:, ii].argmax()
txt = params['class_names_short'][class_id] txt = params["class_names_short"][class_id]
font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()} font_info = {
"color": "white",
"size": 10,
"weight": "bold",
"alpha": bb.get_alpha(),
}
y_pos = bb.get_xy()[1] + bb.get_height() y_pos = bb.get_xy()[1] + bb.get_height()
ax1.text(bb.get_xy()[0], y_pos, txt, fontdict=font_info) ax1.text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)
@ -190,10 +297,18 @@ def plot_spec(spec, sampling_rate, duration, gt, pred, params, plot_title,
min_val = 0.0 if pred_2d_hm.min() > 0.0 else pred_2d_hm.min() min_val = 0.0 if pred_2d_hm.min() > 0.0 else pred_2d_hm.min()
max_val = 1.0 if pred_2d_hm.max() < 1.0 else pred_2d_hm.max() max_val = 1.0 if pred_2d_hm.max() < 1.0 else pred_2d_hm.max()
ax2.imshow(pred_2d_hm, aspect='auto', cmap='plasma', extent=y_extent, clim=[min_val, max_val]) ax2.imshow(
#ax2.xaxis.set_ticklabels([]) pred_2d_hm,
font_info = {'color': 'white', 'size': 12, 'weight': 'bold'} aspect="auto",
ax2.text(0, params['min_freq']//freq_scale, 'Heatmap', fontdict=font_info) cmap="plasma",
extent=y_extent,
clim=[min_val, max_val],
)
# ax2.xaxis.set_ticklabels([])
font_info = {"color": "white", "size": 12, "weight": "bold"}
ax2.text(
0, params["min_freq"] // freq_scale, "Heatmap", fontdict=font_info
)
plt.grid(False) plt.grid(False)
@ -204,107 +319,149 @@ def plot_spec(spec, sampling_rate, duration, gt, pred, params, plot_title,
plt.close(1) plt.close(1)
def plot_pr_curve(op_dir, plt_title, file_name, results, file_type='png', title_text=''): def plot_pr_curve(
precision = results['precision'] op_dir, plt_title, file_name, results, file_type="png", title_text=""
recall = results['recall'] ):
avg_prec = results['avg_prec'] precision = results["precision"]
recall = results["recall"]
avg_prec = results["avg_prec"]
plt.figure(0, figsize=(10,8)) plt.figure(0, figsize=(10, 8))
plt.plot(recall, precision) plt.plot(recall, precision)
plt.ylabel('Precision', fontsize=20) plt.ylabel("Precision", fontsize=20)
plt.xlabel('Recall', fontsize=20) plt.xlabel("Recall", fontsize=20)
if title_text != '': if title_text != "":
plt.title(title_text, fontdict={'fontsize': 28}) plt.title(title_text, fontdict={"fontsize": 28})
else: else:
plt.title(plt_title + ' {:.3f}\n'.format(avg_prec)) plt.title(plt_title + " {:.3f}\n".format(avg_prec))
plt.xlim(0,1.02) plt.xlim(0, 1.02)
plt.ylim(0,1.02) plt.ylim(0, 1.02)
plt.grid(True) plt.grid(True)
plt.tight_layout() plt.tight_layout()
plt.savefig(op_dir + file_name + '.' + file_type) plt.savefig(op_dir + file_name + "." + file_type)
plt.close(0) plt.close(0)
def plot_pr_curve_class(op_dir, plt_title, file_name, results, file_type='png', title_text=''): def plot_pr_curve_class(
plt.figure(0, figsize=(10,8)) op_dir, plt_title, file_name, results, file_type="png", title_text=""
plt.ylabel('Precision', fontsize=20) ):
plt.xlabel('Recall', fontsize=20) plt.figure(0, figsize=(10, 8))
plt.xlim(0,1.02) plt.ylabel("Precision", fontsize=20)
plt.ylim(0,1.02) plt.xlabel("Recall", fontsize=20)
plt.xlim(0, 1.02)
plt.ylim(0, 1.02)
plt.grid(True) plt.grid(True)
linestyles = ['-', ':', '--'] linestyles = ["-", ":", "--"]
markers = ['o', 'v', '>', '^', '<', 's', 'P', 'X', '*'] markers = ["o", "v", ">", "^", "<", "s", "P", "X", "*"]
colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
# plot the PR curves # plot the PR curves
for ii, rr in enumerate(results['class_pr']): for ii, rr in enumerate(results["class_pr"]):
class_name = ' '.join([sp[:3] for sp in rr['name'].split(' ')]) class_name = " ".join([sp[:3] for sp in rr["name"].split(" ")])
cur_color = colors[int(ii%10)] cur_color = colors[int(ii % 10)]
plt.plot(rr['recall'], rr['precision'], label=class_name, color=cur_color, plt.plot(
linestyle=linestyles[int(ii//10)], lw=2.5) rr["recall"],
rr["precision"],
label=class_name,
color=cur_color,
linestyle=linestyles[int(ii // 10)],
lw=2.5,
)
#print(class_name) # print(class_name)
# plot the location of the confidence threshold values # plot the location of the confidence threshold values
for jj, tt in enumerate(rr['thresholds']): for jj, tt in enumerate(rr["thresholds"]):
ind = rr['thresholds_inds'][jj] ind = rr["thresholds_inds"][jj]
if ind > -1: if ind > -1:
plt.plot(rr['recall'][ind], rr['precision'][ind], markers[jj], plt.plot(
color=cur_color, ms=10) rr["recall"][ind],
#print(np.round(tt,2), np.round(rr['recall'][ind],3), np.round(rr['precision'][ind],3)) rr["precision"][ind],
markers[jj],
color=cur_color,
ms=10,
)
# print(np.round(tt,2), np.round(rr['recall'][ind],3), np.round(rr['precision'][ind],3))
if title_text != '': if title_text != "":
plt.title(title_text, fontdict={'fontsize': 28}) plt.title(title_text, fontdict={"fontsize": 28})
else: else:
plt.title(plt_title + ' {:.3f}\n'.format(results['avg_prec_class'])) plt.title(plt_title + " {:.3f}\n".format(results["avg_prec_class"]))
plt.legend(loc='lower left', prop={'size': 14}) plt.legend(loc="lower left", prop={"size": 14})
plt.tight_layout() plt.tight_layout()
plt.savefig(op_dir + file_name + '.' + file_type) plt.savefig(op_dir + file_name + "." + file_type)
plt.close(0) plt.close(0)
def plot_confusion_matrix(op_dir, op_file, gt, pred, file_acc, class_names_long, verbose=False, file_type='png', title_text=''): def plot_confusion_matrix(
op_dir,
op_file,
gt,
pred,
file_acc,
class_names_long,
verbose=False,
file_type="png",
title_text="",
):
# shorten the class names for plotting # shorten the class names for plotting
class_names = [] class_names = []
for cc in class_names_long: for cc in class_names_long:
class_name_sm = ''.join([cc_sm[:3] + ' ' for cc_sm in cc.split(' ')])[:-1] class_name_sm = "".join([cc_sm[:3] + " " for cc_sm in cc.split(" ")])[
:-1
]
class_names.append(class_name_sm) class_names.append(class_name_sm)
num_classes = len(class_names) num_classes = len(class_names)
cm = confusion_matrix(gt, pred, labels=np.arange(num_classes)).astype(np.float32) cm = confusion_matrix(gt, pred, labels=np.arange(num_classes)).astype(
np.float32
)
cm_norm = cm.sum(1) cm_norm = cm.sum(1)
valid_inds = np.where(cm_norm > 0)[0] valid_inds = np.where(cm_norm > 0)[0]
cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis] cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
cm[np.where(cm_norm ==- 0)[0], :] = np.nan cm[np.where(cm_norm == -0)[0], :] = np.nan
if verbose: if verbose:
print('Per class accuracy:') print("Per class accuracy:")
str_len = np.max([len(cc) for cc in class_names_long]) + 5 str_len = np.max([len(cc) for cc in class_names_long]) + 5
accs = np.diag(cm) accs = np.diag(cm)
for ii, cc in enumerate(class_names_long): for ii, cc in enumerate(class_names_long):
if np.isnan(accs[ii]): if np.isnan(accs[ii]):
print(str(ii).ljust(5) + cc.ljust(str_len)) print(str(ii).ljust(5) + cc.ljust(str_len))
else: else:
print(str(ii).ljust(5) + cc.ljust(str_len) + '{:.2f}'.format(accs[ii]*100)) print(
str(ii).ljust(5)
+ cc.ljust(str_len)
+ "{:.2f}".format(accs[ii] * 100)
)
plt.figure(0, figsize=(10,8)) plt.figure(0, figsize=(10, 8))
plt.imshow(cm, vmin=0, vmax=1, cmap='plasma') plt.imshow(cm, vmin=0, vmax=1, cmap="plasma")
plt.colorbar() plt.colorbar()
plt.xticks(np.arange(cm.shape[1]), class_names, rotation='vertical') plt.xticks(np.arange(cm.shape[1]), class_names, rotation="vertical")
plt.yticks(np.arange(cm.shape[0]), class_names) plt.yticks(np.arange(cm.shape[0]), class_names)
plt.xlabel('Predicted', fontsize=20) plt.xlabel("Predicted", fontsize=20)
plt.ylabel('Ground Truth', fontsize=20) plt.ylabel("Ground Truth", fontsize=20)
if title_text != '': if title_text != "":
plt.title(title_text, fontdict={'fontsize': 28}) plt.title(title_text, fontdict={"fontsize": 28})
else: else:
plt.title(op_file + ' {:.3f}\n'.format(file_acc)) plt.title(op_file + " {:.3f}\n".format(file_acc))
plt.tight_layout() plt.tight_layout()
plt.savefig(op_dir + op_file + '.' + file_type) plt.savefig(op_dir + op_file + "." + file_type)
plt.close('all') plt.close("all")
class LossPlotter(object): class LossPlotter(object):
def __init__(self, op_file_name, duration, labels, ylim, class_names, axis_labels=None, logy=False): def __init__(
self,
op_file_name,
duration,
labels,
ylim,
class_names,
axis_labels=None,
logy=False,
):
self.reset() self.reset()
self.op_file_name = op_file_name self.op_file_name = op_file_name
self.duration = duration # length of x axis self.duration = duration # length of x axis
@ -327,11 +484,16 @@ class LossPlotter(object):
self.save_confusion_matrix(gt, pred) self.save_confusion_matrix(gt, pred)
def save_plot(self): def save_plot(self):
linestyles = ['-', ':', '--'] linestyles = ["-", ":", "--"]
plt.figure(0, figsize=(8,5)) plt.figure(0, figsize=(8, 5))
for ii in range(len(self.vals[0])): for ii in range(len(self.vals[0])):
l_vals = [vv[ii] for vv in self.vals] l_vals = [vv[ii] for vv in self.vals]
plt.plot(self.epochs, l_vals, label=self.labels[ii], linestyle=linestyles[int(ii//10)]) plt.plot(
self.epochs,
l_vals,
label=self.labels[ii],
linestyle=linestyles[int(ii // 10)],
)
plt.xlim(0, np.maximum(self.duration, len(self.vals))) plt.xlim(0, np.maximum(self.duration, len(self.vals)))
if self.ylim is not None: if self.ylim is not None:
plt.ylim(self.ylim[0], self.ylim[1]) plt.ylim(self.ylim[0], self.ylim[1])
@ -339,33 +501,41 @@ class LossPlotter(object):
plt.xlabel(self.axis_labels[0]) plt.xlabel(self.axis_labels[0])
plt.ylabel(self.axis_labels[1]) plt.ylabel(self.axis_labels[1])
if self.logy: if self.logy:
plt.gca().set_yscale('log') plt.gca().set_yscale("log")
plt.grid(True) plt.grid(True)
plt.legend(bbox_to_anchor=(1.01, 1), loc='upper left', borderaxespad=0.0) plt.legend(
bbox_to_anchor=(1.01, 1), loc="upper left", borderaxespad=0.0
)
plt.tight_layout() plt.tight_layout()
plt.savefig(self.op_file_name) plt.savefig(self.op_file_name)
plt.close(0) plt.close(0)
def save_json(self): def save_json(self):
data = {} data = {}
data['epochs'] = self.epochs data["epochs"] = self.epochs
for ii in range(len(self.vals[0])): for ii in range(len(self.vals[0])):
data[self.labels[ii]] = [round(vv[ii],4) for vv in self.vals] data[self.labels[ii]] = [round(vv[ii], 4) for vv in self.vals]
with open(self.op_file_name[:-4] + '.json', 'w') as da: with open(self.op_file_name[:-4] + ".json", "w") as da:
json.dump(data, da, indent=2) json.dump(data, da, indent=2)
def save_confusion_matrix(self, gt, pred): def save_confusion_matrix(self, gt, pred):
plt.figure(0) plt.figure(0)
cm = confusion_matrix(gt, pred, np.arange(len(self.class_names))).astype(np.float32) cm = confusion_matrix(
gt, pred, labels=np.arange(len(self.class_names))
).astype(np.float32)
cm_norm = cm.sum(1) cm_norm = cm.sum(1)
valid_inds = np.where(cm_norm > 0)[0] valid_inds = np.where(cm_norm > 0)[0]
cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis] cm[valid_inds, :] = (
plt.imshow(cm, vmin=0, vmax=1, cmap='plasma') cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
)
plt.imshow(cm, vmin=0, vmax=1, cmap="plasma")
plt.colorbar() plt.colorbar()
plt.xticks(np.arange(cm.shape[1]), self.class_names, rotation='vertical') plt.xticks(
np.arange(cm.shape[1]), self.class_names, rotation="vertical"
)
plt.yticks(np.arange(cm.shape[0]), self.class_names) plt.yticks(np.arange(cm.shape[0]), self.class_names)
plt.xlabel('Predicted') plt.xlabel("Predicted")
plt.ylabel('Ground Truth') plt.ylabel("Ground Truth")
plt.tight_layout() plt.tight_layout()
plt.savefig(self.op_file_name[:-4] + '_cm.png') plt.savefig(self.op_file_name[:-4] + "_cm.png")
plt.close(0) plt.close(0)

View File

@ -1,19 +1,46 @@
import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np
from matplotlib import patches from matplotlib import patches
from sklearn.svm import LinearSVC
from matplotlib.axes._axes import _log as matplotlib_axes_logger from matplotlib.axes._axes import _log as matplotlib_axes_logger
matplotlib_axes_logger.setLevel('ERROR') from sklearn.svm import LinearSVC
matplotlib_axes_logger.setLevel("ERROR")
colors = ['#e6194B', '#3cb44b', '#ffe119', '#4363d8', '#f58231', '#911eb4', colors = [
'#42d4f4', '#f032e6', '#bfef45', '#fabebe', '#469990', '#e6beff', "#e6194B",
'#9A6324', '#fffac8', '#800000', '#aaffc3', '#808000', '#ffd8b1', "#3cb44b",
'#000075', '#a9a9a9'] "#ffe119",
"#4363d8",
"#f58231",
"#911eb4",
"#42d4f4",
"#f032e6",
"#bfef45",
"#fabebe",
"#469990",
"#e6beff",
"#9A6324",
"#fffac8",
"#800000",
"#aaffc3",
"#808000",
"#ffd8b1",
"#000075",
"#a9a9a9",
]
class InteractivePlotter: class InteractivePlotter:
def __init__(self, feats_ds, feats, spec_slices, call_info, freq_lims, allow_training): def __init__(
self,
feats_ds,
feats,
spec_slices,
call_info,
freq_lims,
allow_training,
):
""" """
Plots 2D low dimensional features on left and corresponding spectgrams on Plots 2D low dimensional features on left and corresponding spectgrams on
the right. the right.
@ -24,78 +51,123 @@ class InteractivePlotter:
self.spec_slices = spec_slices self.spec_slices = spec_slices
self.call_info = call_info self.call_info = call_info
#_, self.labels = np.unique([cc['class'] for cc in call_info], return_inverse=True) # _, self.labels = np.unique([cc['class'] for cc in call_info], return_inverse=True)
self.labels = np.zeros(len(call_info), dtype=np.int) self.labels = np.zeros(len(call_info), dtype=np.int)
self.annotated = np.zeros(self.labels.shape[0], dtype=np.int) # can populate this with 1's where we have labels self.annotated = np.zeros(
self.labels_cols = [colors[self.labels[ii]] for ii in range(len(self.labels))] self.labels.shape[0], dtype=np.int
) # can populate this with 1's where we have labels
self.labels_cols = [
colors[self.labels[ii]] for ii in range(len(self.labels))
]
self.freq_lims = freq_lims self.freq_lims = freq_lims
self.allow_training = allow_training self.allow_training = allow_training
self.pt_size = 5.0 self.pt_size = 5.0
self.spec_pad = 0.2 # this much padding has been applied to the spec slices self.spec_pad = (
0.2 # this much padding has been applied to the spec slices
)
self.fig_width = 12 self.fig_width = 12
self.fig_height = 8 self.fig_height = 8
self.current_id = 0 self.current_id = 0
max_ind = np.argmax([ss.shape[1] for ss in self.spec_slices]) max_ind = np.argmax([ss.shape[1] for ss in self.spec_slices])
self.max_width = self.spec_slices[max_ind].shape[1] self.max_width = self.spec_slices[max_ind].shape[1]
self.blank_spec = np.zeros((self.spec_slices[0].shape[0], self.max_width)) self.blank_spec = np.zeros(
(self.spec_slices[0].shape[0], self.max_width)
)
def plot(self, fig_id): def plot(self, fig_id):
self.fig, self.ax = plt.subplots(nrows=1, ncols=2, num=fig_id, figsize=(self.fig_width, self.fig_height), self.fig, self.ax = plt.subplots(
gridspec_kw={'width_ratios': [2, 1]}) nrows=1,
ncols=2,
num=fig_id,
figsize=(self.fig_width, self.fig_height),
gridspec_kw={"width_ratios": [2, 1]},
)
plt.tight_layout() plt.tight_layout()
# plot 2D TNSE features # plot 2D TNSE features
self.low_dim_plt = self.ax[0].scatter(self.feats_ds[:, 0], self.feats_ds[:, 1], self.low_dim_plt = self.ax[0].scatter(
c=self.labels_cols, s=self.pt_size, picker=5) self.feats_ds[:, 0],
self.ax[0].set_title('TSNE of Call Features') self.feats_ds[:, 1],
c=self.labels_cols,
s=self.pt_size,
picker=5,
)
self.ax[0].set_title("TSNE of Call Features")
self.ax[0].set_xticks([]) self.ax[0].set_xticks([])
self.ax[0].set_yticks([]) self.ax[0].set_yticks([])
# plot clip from spectrogram # plot clip from spectrogram
spec_min_max = (0, self.blank_spec.shape[1], self.freq_lims[0], self.freq_lims[1]) spec_min_max = (
self.ax[1].imshow(self.blank_spec, extent=spec_min_max, cmap='plasma', aspect='auto') 0,
self.blank_spec.shape[1],
self.freq_lims[0],
self.freq_lims[1],
)
self.ax[1].imshow(
self.blank_spec, extent=spec_min_max, cmap="plasma", aspect="auto"
)
self.spec_im = self.ax[1].get_images()[0] self.spec_im = self.ax[1].get_images()[0]
self.ax[1].set_title('Spectrogram') self.ax[1].set_title("Spectrogram")
self.ax[1].grid(color='w', linewidth=0.5) self.ax[1].grid(color="w", linewidth=0.5)
self.ax[1].set_xticks([]) self.ax[1].set_xticks([])
self.ax[1].set_ylabel('kHz') self.ax[1].set_ylabel("kHz")
bbox_orig = patches.Rectangle((0,0),0,0, edgecolor='w', linewidth=0, fill=False) bbox_orig = patches.Rectangle(
(0, 0), 0, 0, edgecolor="w", linewidth=0, fill=False
)
self.ax[1].add_patch(bbox_orig) self.ax[1].add_patch(bbox_orig)
self.annot = self.ax[0].annotate('', xy=(0,0), xytext=(20,20),textcoords='offset points', self.annot = self.ax[0].annotate(
bbox=dict(boxstyle='round', fc='w'), arrowprops=dict(arrowstyle='->')) "",
xy=(0, 0),
xytext=(20, 20),
textcoords="offset points",
bbox=dict(boxstyle="round", fc="w"),
arrowprops=dict(arrowstyle="->"),
)
self.annot.set_visible(False) self.annot.set_visible(False)
self.fig.canvas.mpl_connect('motion_notify_event', self.mouse_hover) self.fig.canvas.mpl_connect("motion_notify_event", self.mouse_hover)
self.fig.canvas.mpl_connect('key_press_event', self.key_press) self.fig.canvas.mpl_connect("key_press_event", self.key_press)
def mouse_hover(self, event): def mouse_hover(self, event):
vis = self.annot.get_visible() vis = self.annot.get_visible()
if event.inaxes == self.ax[0]: if event.inaxes == self.ax[0]:
cont, ind = self.low_dim_plt.contains(event) cont, ind = self.low_dim_plt.contains(event)
if cont: if cont:
self.current_id = ind['ind'][0] self.current_id = ind["ind"][0]
# copy spec into full window - probably a better way of doing this # copy spec into full window - probably a better way of doing this
new_spec = self.blank_spec.copy() new_spec = self.blank_spec.copy()
w_diff = (self.blank_spec.shape[1] - self.spec_slices[self.current_id].shape[1])//2 w_diff = (
new_spec[:, w_diff:self.spec_slices[self.current_id].shape[1]+w_diff] = self.spec_slices[self.current_id] self.blank_spec.shape[1]
- self.spec_slices[self.current_id].shape[1]
) // 2
new_spec[
:,
w_diff : self.spec_slices[self.current_id].shape[1]
+ w_diff,
] = self.spec_slices[self.current_id]
self.spec_im.set_data(new_spec) self.spec_im.set_data(new_spec)
self.spec_im.set_clim(vmin=0, vmax=new_spec.max()) self.spec_im.set_clim(vmin=0, vmax=new_spec.max())
# draw bounding box around call # draw bounding box around call
self.ax[1].patches[0].remove() self.ax[1].patches[0].remove()
spec_width_orig = self.spec_slices[self.current_id].shape[1]/(1.0+2.0*self.spec_pad) spec_width_orig = self.spec_slices[self.current_id].shape[1] / (
xx = w_diff + self.spec_pad*spec_width_orig 1.0 + 2.0 * self.spec_pad
)
xx = w_diff + self.spec_pad * spec_width_orig
ww = spec_width_orig ww = spec_width_orig
yy = self.call_info[self.current_id]['low_freq']/1000 yy = self.call_info[self.current_id]["low_freq"] / 1000
hh = (self.call_info[self.current_id]['high_freq']-self.call_info[self.current_id]['low_freq'])/1000 hh = (
bbox = patches.Rectangle((xx,yy),ww,hh, edgecolor='r', linewidth=0.5, fill=False) self.call_info[self.current_id]["high_freq"]
- self.call_info[self.current_id]["low_freq"]
) / 1000
bbox = patches.Rectangle(
(xx, yy), ww, hh, edgecolor="r", linewidth=0.5, fill=False
)
self.ax[1].add_patch(bbox) self.ax[1].add_patch(bbox)
# update annotation arrow # update annotation arrow
@ -104,38 +176,52 @@ class InteractivePlotter:
self.annot.set_visible(True) self.annot.set_visible(True)
# write call info # write call info
info_str = self.call_info[self.current_id]['file_name'] + ', time=' \ info_str = (
+ str(round(self.call_info[self.current_id]['start_time'],3)) \ self.call_info[self.current_id]["file_name"]
+ ', prob=' + str(round(self.call_info[self.current_id]['det_prob'],3)) + ", time="
+ str(
round(self.call_info[self.current_id]["start_time"], 3)
)
+ ", prob="
+ str(round(self.call_info[self.current_id]["det_prob"], 3))
)
self.ax[0].set_xlabel(info_str) self.ax[0].set_xlabel(info_str)
# redraw # redraw
self.fig.canvas.draw_idle() self.fig.canvas.draw_idle()
def key_press(self, event): def key_press(self, event):
if event.key.isdigit(): if event.key.isdigit():
self.labels_cols[self.current_id] = colors[int(event.key)] self.labels_cols[self.current_id] = colors[int(event.key)]
self.labels[self.current_id] = int(event.key) self.labels[self.current_id] = int(event.key)
self.annotated[self.current_id] = 1 self.annotated[self.current_id] = 1
elif event.key == 'enter' and self.allow_training: elif event.key == "enter" and self.allow_training:
self.train_classifier() self.train_classifier()
elif event.key == 'x' and self.allow_training: elif event.key == "x" and self.allow_training:
self.get_classifier_params() self.get_classifier_params()
self.ax[0].scatter(self.feats_ds[:, 0], self.feats_ds[:, 1], self.ax[0].scatter(
c=self.labels_cols, s=self.pt_size) self.feats_ds[:, 0],
self.feats_ds[:, 1],
c=self.labels_cols,
s=self.pt_size,
)
self.fig.canvas.draw_idle() self.fig.canvas.draw_idle()
def train_classifier(self): def train_classifier(self):
# TODO maybe it's better to classify in 2D space - but then can't be linear ... # TODO maybe it's better to classify in 2D space - but then can't be linear ...
inds = np.where(self.annotated == 1)[0] inds = np.where(self.annotated == 1)[0]
labs_un, labs_inds = np.unique(self.labels[inds], return_inverse=True) labs_un, labs_inds = np.unique(self.labels[inds], return_inverse=True)
if labs_un.shape[0] > 1: # needs at least 2 classes if labs_un.shape[0] > 1: # needs at least 2 classes
self.clf = LinearSVC(C=1.0, penalty='l2', loss='squared_hinge', tol=0.0001, self.clf = LinearSVC(
intercept_scaling=1.0, max_iter=2000) C=1.0,
penalty="l2",
loss="squared_hinge",
tol=0.0001,
intercept_scaling=1.0,
max_iter=2000,
)
self.clf.fit(self.feats[inds, :], self.labels[inds]) self.clf.fit(self.feats[inds, :], self.labels[inds])
@ -145,14 +231,13 @@ class InteractivePlotter:
for ii in inds_unlab: for ii in inds_unlab:
self.labels_cols[ii] = colors[self.labels[ii]] self.labels_cols[ii] = colors[self.labels[ii]]
else: else:
print('Not enough data - please label more classes.') print("Not enough data - please label more classes.")
def get_classifier_params(self): def get_classifier_params(self):
res = {} res = {}
if self.clf is None: if self.clf is None:
print('Model not trained!') print("Model not trained!")
else: else:
res['weights'] = self.clf.coef_.astype(np.float32) res["weights"] = self.clf.coef_.astype(np.float32)
res['biases'] = self.clf.intercept_.astype(np.float32) res["biases"] = self.clf.intercept_.astype(np.float32)
return res return res

View File

@ -8,23 +8,25 @@ Functions
`write`: Write a numpy array as a WAV file. `write`: Write a numpy array as a WAV file.
""" """
from __future__ import division, print_function, absolute_import from __future__ import absolute_import, division, print_function
import sys
import numpy
import struct
import warnings
import os import os
import struct
import sys
import warnings
import numpy
class WavFileWarning(UserWarning): class WavFileWarning(UserWarning):
pass pass
_big_endian = False _big_endian = False
WAVE_FORMAT_PCM = 0x0001 WAVE_FORMAT_PCM = 0x0001
WAVE_FORMAT_IEEE_FLOAT = 0x0003 WAVE_FORMAT_IEEE_FLOAT = 0x0003
WAVE_FORMAT_EXTENSIBLE = 0xfffe WAVE_FORMAT_EXTENSIBLE = 0xFFFE
KNOWN_WAVE_FORMATS = (WAVE_FORMAT_PCM, WAVE_FORMAT_IEEE_FLOAT) KNOWN_WAVE_FORMATS = (WAVE_FORMAT_PCM, WAVE_FORMAT_IEEE_FLOAT)
# assumes file pointer is immediately # assumes file pointer is immediately
@ -33,10 +35,10 @@ KNOWN_WAVE_FORMATS = (WAVE_FORMAT_PCM, WAVE_FORMAT_IEEE_FLOAT)
def _read_fmt_chunk(fid): def _read_fmt_chunk(fid):
if _big_endian: if _big_endian:
fmt = '>' fmt = ">"
else: else:
fmt = '<' fmt = "<"
res = struct.unpack(fmt+'iHHIIHH',fid.read(20)) res = struct.unpack(fmt + "iHHIIHH", fid.read(20))
size, comp, noc, rate, sbytes, ba, bits = res size, comp, noc, rate, sbytes, ba, bits = res
if comp not in KNOWN_WAVE_FORMATS or size > 16: if comp not in KNOWN_WAVE_FORMATS or size > 16:
comp = WAVE_FORMAT_PCM comp = WAVE_FORMAT_PCM
@ -51,41 +53,42 @@ def _read_fmt_chunk(fid):
# after the 'data' id # after the 'data' id
def _read_data_chunk(fid, comp, noc, bits, mmap=False): def _read_data_chunk(fid, comp, noc, bits, mmap=False):
if _big_endian: if _big_endian:
fmt = '>i' fmt = ">i"
else: else:
fmt = '<i' fmt = "<i"
size = struct.unpack(fmt,fid.read(4))[0] size = struct.unpack(fmt, fid.read(4))[0]
bytes = bits//8 bytes = bits // 8
if bits == 8: if bits == 8:
dtype = 'u1' dtype = "u1"
else: else:
if _big_endian: if _big_endian:
dtype = '>' dtype = ">"
else: else:
dtype = '<' dtype = "<"
if comp == 1: if comp == 1:
dtype += 'i%d' % bytes dtype += "i%d" % bytes
else: else:
dtype += 'f%d' % bytes dtype += "f%d" % bytes
if not mmap: if not mmap:
data = numpy.fromstring(fid.read(size), dtype=dtype) data = numpy.fromstring(fid.read(size), dtype=dtype)
else: else:
start = fid.tell() start = fid.tell()
data = numpy.memmap(fid, dtype=dtype, mode='c', offset=start, data = numpy.memmap(
shape=(size//bytes,)) fid, dtype=dtype, mode="c", offset=start, shape=(size // bytes,)
)
fid.seek(start + size) fid.seek(start + size)
if noc > 1: if noc > 1:
data = data.reshape(-1,noc) data = data.reshape(-1, noc)
return data return data
def _skip_unknown_chunk(fid): def _skip_unknown_chunk(fid):
if _big_endian: if _big_endian:
fmt = '>i' fmt = ">i"
else: else:
fmt = '<i' fmt = "<i"
data = fid.read(4) data = fid.read(4)
size = struct.unpack(fmt, data)[0] size = struct.unpack(fmt, data)[0]
@ -95,22 +98,23 @@ def _skip_unknown_chunk(fid):
def _read_riff_chunk(fid): def _read_riff_chunk(fid):
global _big_endian global _big_endian
str1 = fid.read(4) str1 = fid.read(4)
if str1 == b'RIFX': if str1 == b"RIFX":
_big_endian = True _big_endian = True
elif str1 != b'RIFF': elif str1 != b"RIFF":
raise ValueError("Not a WAV file.") raise ValueError("Not a WAV file.")
if _big_endian: if _big_endian:
fmt = '>I' fmt = ">I"
else: else:
fmt = '<I' fmt = "<I"
fsize = struct.unpack(fmt, fid.read(4))[0] + 8 fsize = struct.unpack(fmt, fid.read(4))[0] + 8
str2 = fid.read(4) str2 = fid.read(4)
if (str2 != b'WAVE'): if str2 != b"WAVE":
raise ValueError("Not a WAV file.") raise ValueError("Not a WAV file.")
if str1 == b'RIFX': if str1 == b"RIFX":
_big_endian = True _big_endian = True
return fsize return fsize
# open a wave-file # open a wave-file
@ -145,11 +149,11 @@ def read(filename, mmap=False):
data-type determined from the file. data-type determined from the file.
""" """
if hasattr(filename,'read'): if hasattr(filename, "read"):
fid = filename fid = filename
mmap = False mmap = False
else: else:
fid = open(filename, 'rb') fid = open(filename, "rb")
try: try:
@ -169,16 +173,16 @@ def read(filename, mmap=False):
noc = 1 noc = 1
bits = 8 bits = 8
comp = WAVE_FORMAT_PCM comp = WAVE_FORMAT_PCM
while (fid.tell() < fsize): while fid.tell() < fsize:
# read the next chunk # read the next chunk
chunk_id = fid.read(4) chunk_id = fid.read(4)
if chunk_id == b'fmt ': if chunk_id == b"fmt ":
size, comp, noc, rate, sbytes, ba, bits = _read_fmt_chunk(fid) size, comp, noc, rate, sbytes, ba, bits = _read_fmt_chunk(fid)
elif chunk_id == b'fact': elif chunk_id == b"fact":
_skip_unknown_chunk(fid) _skip_unknown_chunk(fid)
elif chunk_id == b'data': elif chunk_id == b"data":
data = _read_data_chunk(fid, comp, noc, bits, mmap=mmap) data = _read_data_chunk(fid, comp, noc, bits, mmap=mmap)
elif chunk_id == b'LIST': elif chunk_id == b"LIST":
# Someday this could be handled properly but for now skip it # Someday this could be handled properly but for now skip it
_skip_unknown_chunk(fid) _skip_unknown_chunk(fid)
@ -187,13 +191,14 @@ def read(filename, mmap=False):
# warnings.warn("Chunk (non-data) not understood, skipping it.", WavFileWarning) # warnings.warn("Chunk (non-data) not understood, skipping it.", WavFileWarning)
# _skip_unknown_chunk(fid) # _skip_unknown_chunk(fid)
finally: finally:
if not hasattr(filename,'read'): if not hasattr(filename, "read"):
fid.close() fid.close()
else: else:
fid.seek(0) fid.seek(0)
return rate, data return rate, data
# Write a wave-file # Write a wave-file
# sample rate, data # sample rate, data
@ -221,26 +226,30 @@ def write(filename, rate, data):
(Nsamples, Nchannels). (Nsamples, Nchannels).
""" """
if hasattr(filename, 'write'): if hasattr(filename, "write"):
fid = filename fid = filename
else: else:
fid = open(filename, 'wb') fid = open(filename, "wb")
try: try:
# kind of numeric data in the numpy array # kind of numeric data in the numpy array
dkind = data.dtype.kind dkind = data.dtype.kind
if not (dkind == 'i' or dkind == 'f' or (dkind == 'u' and data.dtype.itemsize == 1)): if not (
dkind == "i"
or dkind == "f"
or (dkind == "u" and data.dtype.itemsize == 1)
):
raise ValueError("Unsupported data type '%s'" % data.dtype) raise ValueError("Unsupported data type '%s'" % data.dtype)
# wav header stuff # wav header stuff
# http://soundfile.sapp.org/doc/WaveFormat/ # http://soundfile.sapp.org/doc/WaveFormat/
fid.write(b'RIFF') fid.write(b"RIFF")
# placeholder for chunk size (updated later) # placeholder for chunk size (updated later)
fid.write(b'\x00\x00\x00\x00') fid.write(b"\x00\x00\x00\x00")
fid.write(b'WAVE') fid.write(b"WAVE")
# fmt chunk # fmt chunk
fid.write(b'fmt ') fid.write(b"fmt ")
if dkind == 'f': if dkind == "f":
# comp stands for compression. PCM = 1 # comp stands for compression. PCM = 1
comp = 3 comp = 3
else: else:
@ -253,7 +262,7 @@ def write(filename, rate, data):
bits = data.dtype.itemsize * 8 bits = data.dtype.itemsize * 8
# number of bytes per second, at the specified sampling rate rate, # number of bytes per second, at the specified sampling rate rate,
# bits per sample and number of channels (just needed for wav header) # bits per sample and number of channels (just needed for wav header)
sbytes = rate*(bits // 8)*noc sbytes = rate * (bits // 8) * noc
# number of bytes per sample # number of bytes per sample
ba = noc * (bits // 8) ba = noc * (bits // 8)
@ -261,11 +270,15 @@ def write(filename, rate, data):
# Write the data (16, comp, noc, etc) in the correct binary format # Write the data (16, comp, noc, etc) in the correct binary format
# for the wav header. the string format (first arg) specifies how many bytes for each # for the wav header. the string format (first arg) specifies how many bytes for each
# value. # value.
fid.write(struct.pack('<ihHIIHH', 16, comp, noc, rate, sbytes, ba, bits)) fid.write(
struct.pack("<ihHIIHH", 16, comp, noc, rate, sbytes, ba, bits)
)
# data chunk: the word 'data' followed by the size followed by the actual data # data chunk: the word 'data' followed by the size followed by the actual data
fid.write(b'data') fid.write(b"data")
fid.write(struct.pack('<i', data.nbytes)) fid.write(struct.pack("<i", data.nbytes))
if data.dtype.byteorder == '>' or (data.dtype.byteorder == '=' and sys.byteorder == 'big'): if data.dtype.byteorder == ">" or (
data.dtype.byteorder == "=" and sys.byteorder == "big"
):
data = data.byteswap() data = data.byteswap()
_array_tofile(fid, data) _array_tofile(fid, data)
@ -273,19 +286,22 @@ def write(filename, rate, data):
# position at start of the file (replacing the 4 bytes of zeros) # position at start of the file (replacing the 4 bytes of zeros)
size = fid.tell() size = fid.tell()
fid.seek(4) fid.seek(4)
fid.write(struct.pack('<i', size-8)) fid.write(struct.pack("<i", size - 8))
finally: finally:
if not hasattr(filename,'write'): if not hasattr(filename, "write"):
fid.close() fid.close()
else: else:
fid.seek(0) fid.seek(0)
if sys.version_info[0] >= 3: if sys.version_info[0] >= 3:
def _array_tofile(fid, data): def _array_tofile(fid, data):
# ravel gives a c-contiguous buffer # ravel gives a c-contiguous buffer
fid.write(data.ravel().view('b').data) fid.write(data.ravel().view("b").data)
else: else:
def _array_tofile(fid, data): def _array_tofile(fid, data):
fid.write(data.tostring()) fid.write(data.tostring())

View File

@ -56,9 +56,9 @@
"source": [ "source": [
"# setup the arguments\n", "# setup the arguments\n",
"args = du.get_default_bd_args()\n", "args = du.get_default_bd_args()\n",
"args['detection_threshold'] = 0.3\n", "args[\"detection_threshold\"] = 0.3\n",
"args['time_expansion_factor'] = 1\n", "args[\"time_expansion_factor\"] = 1\n",
"args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'\n", "args[\"model_path\"] = \"models/Net2DFast_UK_same.pth.tar\"\n",
"max_duration = 2.0" "max_duration = 2.0"
] ]
}, },
@ -69,7 +69,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# load the model\n", "# load the model\n",
"model, params = du.load_model(args['model_path'])" "model, params = du.load_model(args[\"model_path\"])"
] ]
}, },
{ {
@ -86,13 +86,13 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# choose an audio file\n", "# choose an audio file\n",
"audio_file = 'example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav'\n", "audio_file = \"example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav\"\n",
"\n", "\n",
"# the following lines are only needed in Colab\n", "# the following lines are only needed in Colab\n",
"# alternatively you can upload your own file\n", "# alternatively you can upload your own file\n",
"#from google.colab import files\n", "# from google.colab import files\n",
"#uploaded = files.upload()\n", "# uploaded = files.upload()\n",
"#audio_file = list(uploaded.keys())[0]" "# audio_file = list(uploaded.keys())[0]"
] ]
}, },
{ {
@ -102,7 +102,9 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# run the model\n", "# run the model\n",
"results = du.process_file(audio_file, model, params, args, max_duration=max_duration)" "results = du.process_file(\n",
" audio_file, model, params, args, max_duration=max_duration\n",
")"
] ]
}, },
{ {
@ -144,13 +146,17 @@
} }
], ],
"source": [ "source": [
"# print summary info for the individual detections \n", "# print summary info for the individual detections\n",
"print('Results for ' + results['pred_dict']['id'])\n", "print(\"Results for \" + results[\"pred_dict\"][\"id\"])\n",
"print('{} calls detected\\n'.format(len(results['pred_dict']['annotation'])))\n", "print(\"{} calls detected\\n\".format(len(results[\"pred_dict\"][\"annotation\"])))\n",
"\n", "\n",
"print('time\\tprob\\tlfreq\\tspecies_name')\n", "print(\"time\\tprob\\tlfreq\\tspecies_name\")\n",
"for ann in results['pred_dict']['annotation']:\n", "for ann in results[\"pred_dict\"][\"annotation\"]:\n",
" print('{}\\t{}\\t{}\\t{}'.format(ann['start_time'], ann['class_prob'], ann['low_freq'], ann['class']))" " print(\n",
" \"{}\\t{}\\t{}\\t{}\".format(\n",
" ann[\"start_time\"], ann[\"class_prob\"], ann[\"low_freq\"], ann[\"class\"]\n",
" )\n",
" )"
] ]
}, },
{ {
@ -174,10 +180,16 @@
} }
], ],
"source": [ "source": [
"# read the audio file \n", "# read the audio file\n",
"sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'], params['target_samp_rate'], params['scale_raw_audio'], max_duration=max_duration)\n", "sampling_rate, audio = au.load_audio_file(\n",
" audio_file,\n",
" args[\"time_expansion_factor\"],\n",
" params[\"target_samp_rate\"],\n",
" params[\"scale_raw_audio\"],\n",
" max_duration=max_duration,\n",
")\n",
"duration = audio.shape[0] / sampling_rate\n", "duration = audio.shape[0] / sampling_rate\n",
"print('File duration: {} seconds'.format(duration))" "print(\"File duration: {} seconds\".format(duration))"
] ]
}, },
{ {
@ -187,7 +199,9 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# generate spectrogram for visualization\n", "# generate spectrogram for visualization\n",
"spec, spec_viz = au.generate_spectrogram(audio, sampling_rate, params, True, False)" "spec, spec_viz = au.generate_spectrogram(\n",
" audio, sampling_rate, params, True, False\n",
")"
] ]
}, },
{ {
@ -210,12 +224,33 @@
"# display the detections on top of the spectrogram\n", "# display the detections on top of the spectrogram\n",
"# note, if the audio file is very long, this image will be very large - best to crop the audio first\n", "# note, if the audio file is very long, this image will be very large - best to crop the audio first\n",
"start_time = 0.0\n", "start_time = 0.0\n",
"detections = [ann for ann in results['pred_dict']['annotation']]\n", "detections = [ann for ann in results[\"pred_dict\"][\"annotation\"]]\n",
"fig = plt.figure(1, figsize=(spec.shape[1]/100, spec.shape[0]/100), dpi=100, frameon=False)\n", "fig = plt.figure(\n",
"spec_duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])\n", " 1,\n",
"viz.create_box_image(spec, fig, detections, start_time, start_time+spec_duration, spec_duration, params, spec.max()*1.1, False, True)\n", " figsize=(spec.shape[1] / 100, spec.shape[0] / 100),\n",
"plt.ylabel('Freq - kHz')\n", " dpi=100,\n",
"plt.xlabel('Time - secs')\n", " frameon=False,\n",
")\n",
"spec_duration = au.x_coords_to_time(\n",
" spec.shape[1],\n",
" sampling_rate,\n",
" params[\"fft_win_length\"],\n",
" params[\"fft_overlap\"],\n",
")\n",
"viz.create_box_image(\n",
" spec,\n",
" fig,\n",
" detections,\n",
" start_time,\n",
" start_time + spec_duration,\n",
" spec_duration,\n",
" params,\n",
" spec.max() * 1.1,\n",
" False,\n",
" True,\n",
")\n",
"plt.ylabel(\"Freq - kHz\")\n",
"plt.xlabel(\"Time - secs\")\n",
"plt.title(os.path.basename(audio_file))\n", "plt.title(os.path.basename(audio_file))\n",
"plt.show()" "plt.show()"
] ]

1337
pdm.lock generated Normal file

File diff suppressed because it is too large Load Diff

79
pyproject.toml Normal file
View File

@ -0,0 +1,79 @@
[tool.pdm]
[tool.pdm.dev-dependencies]
dev = [
"pytest>=7.2.2",
]
[project]
name = "batdetect2"
version = "0.2.0"
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
authors = [
{ "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
{ "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" }
]
dependencies = [
"librosa",
"matplotlib",
"numpy",
"pandas",
"scikit-learn",
"scipy",
"torch<2",
"torchaudio",
"torchvision",
"click",
]
requires-python = ">=3.8,<3.11"
readme = "README.md"
license = { text = "CC-by-nc-4" }
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Multimedia :: Sound/Audio :: Analysis",
]
keywords = [
"bat",
"echolocation",
"deep learning",
"audio",
"machine learning",
"classification",
"detection",
]
[build-system]
requires = ["pdm-pep517>=1.0.0"]
build-backend = "pdm.pep517.api"
[project.scripts]
batdetect2 = "bat_detect.cli:cli"
[tool.black]
line-length = 80
[[tool.mypy.overrides]]
module = [
"librosa",
"pandas",
]
ignore_missing_imports = true
[tool.pylsp-mypy]
enabled = false
live_mode = true
strict = true
[tool.pyright]
include = [
"bat_detect",
"tests",
]
venvPath = "."
venv = ".venv"

View File

@ -7,3 +7,4 @@ scipy==1.9.3
torch==1.13.0 torch==1.13.0
torchaudio==0.13.0 torchaudio==0.13.0
torchvision==0.14.0 torchvision==0.14.0
click

View File

@ -1,67 +1,5 @@
import os """Run bat_detect.command.main() from the command line."""
import argparse from bat_detect.cli import detect
import bat_detect.utils.detector_utils as du
def main(args):
print('Loading model: ' + args['model_path'])
model, params = du.load_model(args['model_path'])
print('\nInput directory: ' + args['audio_dir'])
files = du.get_audio_files(args['audio_dir'])
print('Number of audio files: {}'.format(len(files)))
print('\nSaving results to: ' + args['ann_dir'])
# process files
error_files = []
for ii, audio_file in enumerate(files):
print('\n' + str(ii).ljust(6) + os.path.basename(audio_file))
try:
results = du.process_file(audio_file, model, params, args)
if args['save_preds_if_empty'] or (len(results['pred_dict']['annotation']) > 0):
results_path = audio_file.replace(args['audio_dir'], args['ann_dir'])
du.save_results_to_file(results, results_path)
except:
error_files.append(audio_file)
print("Error processing file!")
print('\nResults saved to: ' + args['ann_dir'])
if len(error_files) > 0:
print('\nUnable to process the follow files:')
for err in error_files:
print(' ' + err)
if __name__ == "__main__": if __name__ == "__main__":
detect()
info_str = '\nBatDetect2 - Detection and Classification\n' + \
' Assumes audio files are mono, not stereo.\n' + \
' Spaces in the input paths will throw an error. Wrap in quotes "".\n' + \
' Input files should be short in duration e.g. < 30 seconds.\n'
print(info_str)
parser = argparse.ArgumentParser()
parser.add_argument('audio_dir', type=str, help='Input directory for audio')
parser.add_argument('ann_dir', type=str, help='Output directory for where the predictions will be stored')
parser.add_argument('detection_threshold', type=float, help='Cut-off probability for detector e.g. 0.1')
parser.add_argument('--cnn_features', action='store_true', default=False, dest='cnn_features',
help='Extracts CNN call features')
parser.add_argument('--spec_features', action='store_true', default=False, dest='spec_features',
help='Extracts low level call features')
parser.add_argument('--time_expansion_factor', type=int, default=1, dest='time_expansion_factor',
help='The time expansion factor used for all files (default is 1)')
parser.add_argument('--quiet', action='store_true', default=False, dest='quiet',
help='Minimize output printing')
parser.add_argument('--save_preds_if_empty', action='store_true', default=False, dest='save_preds_if_empty',
help='Save empty annotation file if no detections made.')
parser.add_argument('--model_path', type=str, default='models/Net2DFast_UK_same.pth.tar',
help='Path to trained BatDetect2 model')
args = vars(parser.parse_args())
args['spec_slices'] = False # used for visualization
args['chunk_size'] = 2 # if files greater than this amount (seconds) they will be broken down into small chunks
args['ann_dir'] = os.path.join(args['ann_dir'], '')
main(args)

View File

@ -3,62 +3,95 @@ Loads a set of annotations corresponding to a dataset and saves an image which
is the mean spectrogram for each class. is the mean spectrogram for each class.
""" """
import argparse
import os
import sys
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import os
import argparse
import sys
import viz_helpers as vz import viz_helpers as vz
sys.path.append(os.path.join('..')) sys.path.append(os.path.join(".."))
import bat_detect.train.train_utils as tu
import bat_detect.detector.parameters as parameters import bat_detect.detector.parameters as parameters
import bat_detect.utils.audio_utils as au
import bat_detect.train.train_split as ts import bat_detect.train.train_split as ts
import bat_detect.train.train_utils as tu
import bat_detect.utils.audio_utils as au
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('audio_path', type=str, help='Input directory for audio') parser.add_argument(
parser.add_argument('op_dir', type=str, "audio_path", type=str, help="Input directory for audio"
help='Path to where single annotation json file is stored') )
parser.add_argument('--ann_file', type=str, parser.add_argument(
help='Path to where single annotation json file is stored') "op_dir",
parser.add_argument('--uk_split', type=str, default='', type=str,
help='Set as: diff or same') help="Path to where single annotation json file is stored",
parser.add_argument('--file_type', type=str, default='png', )
help='Type of image to save png or pdf') parser.add_argument(
"--ann_file",
type=str,
help="Path to where single annotation json file is stored",
)
parser.add_argument(
"--uk_split", type=str, default="", help="Set as: diff or same"
)
parser.add_argument(
"--file_type",
type=str,
default="png",
help="Type of image to save png or pdf",
)
args = vars(parser.parse_args()) args = vars(parser.parse_args())
if not os.path.isdir(args['op_dir']): if not os.path.isdir(args["op_dir"]):
os.makedirs(args['op_dir']) os.makedirs(args["op_dir"])
params = parameters.get_params(False) params = parameters.get_params(False)
params['smooth_spec'] = False params["smooth_spec"] = False
params['spec_width'] = 48 params["spec_width"] = 48
params['norm_type'] = 'log' # log, pcen params["norm_type"] = "log" # log, pcen
params['aud_pad'] = 0.005 params["aud_pad"] = 0.005
classes_to_ignore = params['classes_to_ignore'] + params['generic_class'] classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
# load train annotations # load train annotations
if args['uk_split'] == '': if args["uk_split"] == "":
print('\nLoading:', args['ann_file'], '\n') print("\nLoading:", args["ann_file"], "\n")
dataset_name = os.path.basename(args['ann_file']).replace('.json', '') dataset_name = os.path.basename(args["ann_file"]).replace(".json", "")
datasets = [] datasets = []
datasets.append(tu.get_blank_dataset_dict(dataset_name, False, args['ann_file'], args['audio_path'])) datasets.append(
tu.get_blank_dataset_dict(
dataset_name, False, args["ann_file"], args["audio_path"]
)
)
else: else:
# load uk data - special case # load uk data - special case
print('\nLoading:', args['uk_split'], '\n') print("\nLoading:", args["uk_split"], "\n")
dataset_name = 'uk_' + args['uk_split'] # should be uk_diff, or uk_same dataset_name = "uk_" + args["uk_split"] # should be uk_diff, or uk_same
datasets, _ = ts.get_train_test_data(args['ann_file'], args['audio_path'], args['uk_split'], load_extra=False) datasets, _ = ts.get_train_test_data(
args["ann_file"],
args["audio_path"],
args["uk_split"],
load_extra=False,
)
anns, class_names, _ = tu.load_set_of_anns(datasets, classes_to_ignore, params['events_of_interest']) anns, class_names, _ = tu.load_set_of_anns(
datasets, classes_to_ignore, params["events_of_interest"]
)
class_names_order = range(len(class_names)) class_names_order = range(len(class_names))
x_train, y_train = vz.load_data(anns, params, class_names, smooth_spec=params['smooth_spec'], norm_type=params['norm_type']) x_train, y_train = vz.load_data(
anns,
params,
class_names,
smooth_spec=params["smooth_spec"],
norm_type=params["norm_type"],
)
op_file_name = os.path.join(args['op_dir'], dataset_name + '.' + args['file_type']) op_file_name = os.path.join(
vz.save_summary_image(x_train, y_train, class_names, params, op_file_name, class_names_order) args["op_dir"], dataset_name + "." + args["file_type"]
print('\nImage saved to:', op_file_name) )
vz.save_summary_image(
x_train, y_train, class_names, params, op_file_name, class_names_order
)
print("\nImage saved to:", op_file_name)

View File

@ -7,24 +7,27 @@ Will save images with:
3) spectrogram with predicted boxes 3) spectrogram with predicted boxes
""" """
import numpy as np
import sys
import os
import argparse import argparse
import matplotlib.pyplot as plt
import json import json
import os
import sys
sys.path.append(os.path.join('..')) import matplotlib.pyplot as plt
import numpy as np
sys.path.append(os.path.join(".."))
import bat_detect.evaluate.evaluate_models as evlm import bat_detect.evaluate.evaluate_models as evlm
import bat_detect.utils.audio_utils as au
import bat_detect.utils.detector_utils as du import bat_detect.utils.detector_utils as du
import bat_detect.utils.plot_utils as viz import bat_detect.utils.plot_utils as viz
import bat_detect.utils.audio_utils as au
def filter_anns(anns, start_time, stop_time): def filter_anns(anns, start_time, stop_time):
anns_op = [] anns_op = []
for aa in anns: for aa in anns:
if (aa['start_time'] >= start_time) and (aa['start_time'] < stop_time-0.02): if (aa["start_time"] >= start_time) and (
aa["start_time"] < stop_time - 0.02
):
anns_op.append(aa) anns_op.append(aa)
return anns_op return anns_op
@ -32,85 +35,175 @@ def filter_anns(anns, start_time, stop_time):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('audio_file', type=str, help='Path to audio file') parser.add_argument("audio_file", type=str, help="Path to audio file")
parser.add_argument('model_path', type=str, help='Path to BatDetect model') parser.add_argument("model_path", type=str, help="Path to BatDetect model")
parser.add_argument('--ann_file', type=str, default='', help='Path to annotation file') parser.add_argument(
parser.add_argument('--op_dir', type=str, default='plots/', "--ann_file", type=str, default="", help="Path to annotation file"
help='Output directory for plots') )
parser.add_argument('--file_type', type=str, default='png', parser.add_argument(
help='Type of image to save png or pdf') "--op_dir",
parser.add_argument('--title_text', type=str, default='', type=str,
help='Text to add as title of plots') default="plots/",
parser.add_argument('--detection_threshold', type=float, default=0.2, help="Output directory for plots",
help='Threshold for output detections') )
parser.add_argument('--start_time', type=float, default=0.0, parser.add_argument(
help='Start time for cropped file') "--file_type",
parser.add_argument('--stop_time', type=float, default=0.5, type=str,
help='End time for cropped file') default="png",
parser.add_argument('--time_expansion_factor', type=int, default=1, help="Type of image to save png or pdf",
help='Time expansion factor') )
parser.add_argument(
"--title_text",
type=str,
default="",
help="Text to add as title of plots",
)
parser.add_argument(
"--detection_threshold",
type=float,
default=0.2,
help="Threshold for output detections",
)
parser.add_argument(
"--start_time",
type=float,
default=0.0,
help="Start time for cropped file",
)
parser.add_argument(
"--stop_time",
type=float,
default=0.5,
help="End time for cropped file",
)
parser.add_argument(
"--time_expansion_factor",
type=int,
default=1,
help="Time expansion factor",
)
args_cmd = vars(parser.parse_args()) args_cmd = vars(parser.parse_args())
# load the model # load the model
bd_args = du.get_default_bd_args() bd_args = du.get_default_run_config()
model, params_bd = du.load_model(args_cmd['model_path']) model, params_bd = du.load_model(args_cmd["model_path"])
bd_args['detection_threshold'] = args_cmd['detection_threshold'] bd_args["detection_threshold"] = args_cmd["detection_threshold"]
bd_args['time_expansion_factor'] = args_cmd['time_expansion_factor'] bd_args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
# load the annotation if it exists # load the annotation if it exists
gt_present = False gt_present = False
if args_cmd['ann_file'] != '': if args_cmd["ann_file"] != "":
if os.path.isfile(args_cmd['ann_file']): if os.path.isfile(args_cmd["ann_file"]):
with open(args_cmd['ann_file']) as da: with open(args_cmd["ann_file"]) as da:
gt_anns = json.load(da) gt_anns = json.load(da)
gt_anns = filter_anns(gt_anns['annotation'], args_cmd['start_time'], args_cmd['stop_time']) gt_anns = filter_anns(
gt_anns["annotation"],
args_cmd["start_time"],
args_cmd["stop_time"],
)
gt_present = True gt_present = True
else: else:
print('Annotation file not found: ', args_cmd['ann_file']) print("Annotation file not found: ", args_cmd["ann_file"])
# load the audio file # load the audio file
if not os.path.isfile(args_cmd['audio_file']): if not os.path.isfile(args_cmd["audio_file"]):
print('Audio file not found: ', args_cmd['audio_file']) print("Audio file not found: ", args_cmd["audio_file"])
sys.exit() sys.exit()
# load audio and crop # load audio and crop
print('\nProcessing: ' + os.path.basename(args_cmd['audio_file'])) print("\nProcessing: " + os.path.basename(args_cmd["audio_file"]))
print('\nOutput directory: ' + args_cmd['op_dir']) print("\nOutput directory: " + args_cmd["op_dir"])
sampling_rate, audio = au.load_audio_file(args_cmd['audio_file'], args_cmd['time_exp'], sampling_rate, audio = au.load_audio(
params_bd['target_samp_rate'], params_bd['scale_raw_audio']) args_cmd["audio_file"],
st_samp = int(sampling_rate*args_cmd['start_time']) args_cmd["time_exp"],
en_samp = int(sampling_rate*args_cmd['stop_time']) params_bd["target_samp_rate"],
params_bd["scale_raw_audio"],
)
st_samp = int(sampling_rate * args_cmd["start_time"])
en_samp = int(sampling_rate * args_cmd["stop_time"])
if en_samp > audio.shape[0]: if en_samp > audio.shape[0]:
audio = np.hstack((audio, np.zeros((en_samp) - audio.shape[0], dtype=audio.dtype))) audio = np.hstack(
(audio, np.zeros((en_samp) - audio.shape[0], dtype=audio.dtype))
)
audio = audio[st_samp:en_samp] audio = audio[st_samp:en_samp]
duration = audio.shape[0] / sampling_rate duration = audio.shape[0] / sampling_rate
print('File duration: {} seconds'.format(duration)) print("File duration: {} seconds".format(duration))
# create spec for viz # create spec for viz
spec, _ = au.generate_spectrogram(audio, sampling_rate, params_bd, True, False) spec, _ = au.generate_spectrogram(
audio, sampling_rate, params_bd, True, False
)
run_config = {
**params_bd,
**bd_args,
}
# run model and filter detections so only keep ones in relevant time range # run model and filter detections so only keep ones in relevant time range
results = du.process_file(args_cmd['audio_file'], model, params_bd, bd_args) results = du.process_file(args_cmd["audio_file"], model, run_config)
pred_anns = filter_anns(results['pred_dict']['annotation'], args_cmd['start_time'], args_cmd['stop_time']) pred_anns = filter_anns(
print(len(pred_anns), 'Detections') results["pred_dict"]["annotation"],
args_cmd["start_time"],
args_cmd["stop_time"],
)
print(len(pred_anns), "Detections")
# save output # save output
if not os.path.isdir(args_cmd['op_dir']): if not os.path.isdir(args_cmd["op_dir"]):
os.makedirs(args_cmd['op_dir']) os.makedirs(args_cmd["op_dir"])
# create output file names # create output file names
op_path_clean = os.path.basename(args_cmd['audio_file'])[:-4] + '_clean.' + args_cmd['file_type'] op_path_clean = (
op_path_clean = os.path.join(args_cmd['op_dir'], op_path_clean) os.path.basename(args_cmd["audio_file"])[:-4]
op_path_pred = os.path.basename(args_cmd['audio_file'])[:-4] + '_pred.' + args_cmd['file_type'] + "_clean."
op_path_pred = os.path.join(args_cmd['op_dir'], op_path_pred) + args_cmd["file_type"]
)
op_path_clean = os.path.join(args_cmd["op_dir"], op_path_clean)
op_path_pred = (
os.path.basename(args_cmd["audio_file"])[:-4]
+ "_pred."
+ args_cmd["file_type"]
)
op_path_pred = os.path.join(args_cmd["op_dir"], op_path_pred)
# create and save iamges # create and save iamges
viz.save_ann_spec(op_path_clean, spec, params_bd['min_freq'], params_bd['max_freq'], duration, args_cmd['start_time'], '', None) viz.save_ann_spec(
viz.save_ann_spec(op_path_pred, spec, params_bd['min_freq'], params_bd['max_freq'], duration, args_cmd['start_time'], '', pred_anns) op_path_clean,
spec,
params_bd["min_freq"],
params_bd["max_freq"],
duration,
args_cmd["start_time"],
"",
None,
)
viz.save_ann_spec(
op_path_pred,
spec,
params_bd["min_freq"],
params_bd["max_freq"],
duration,
args_cmd["start_time"],
"",
pred_anns,
)
if gt_present: if gt_present:
op_path_gt = os.path.basename(args_cmd['audio_file'])[:-4] + '_gt.' + args_cmd['file_type'] op_path_gt = (
op_path_gt = os.path.join(args_cmd['op_dir'], op_path_gt) os.path.basename(args_cmd["audio_file"])[:-4]
viz.save_ann_spec(op_path_gt, spec, params_bd['min_freq'], params_bd['max_freq'], duration, args_cmd['start_time'], '', gt_anns) + "_gt."
+ args_cmd["file_type"]
)
op_path_gt = os.path.join(args_cmd["op_dir"], op_path_gt)
viz.save_ann_spec(
op_path_gt,
spec,
params_bd["min_freq"],
params_bd["max_freq"],
duration,
args_cmd["start_time"],
"",
gt_anns,
)

View File

@ -8,163 +8,263 @@ Notes:
Best to use system one - see ffmpeg_path. Best to use system one - see ffmpeg_path.
""" """
from scipy.io import wavfile import argparse
import os import os
import shutil import shutil
import sys
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import argparse from scipy.io import wavfile
import sys sys.path.append(os.path.join(".."))
sys.path.append(os.path.join('..'))
import bat_detect.detector.parameters as parameters import bat_detect.detector.parameters as parameters
import bat_detect.utils.audio_utils as au import bat_detect.utils.audio_utils as au
import bat_detect.utils.plot_utils as viz
import bat_detect.utils.detector_utils as du import bat_detect.utils.detector_utils as du
import bat_detect.utils.plot_utils as viz
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('audio_file', type=str, help='Path to input audio file') parser.add_argument("audio_file", type=str, help="Path to input audio file")
parser.add_argument('model_path', type=str, help='Path to trained BatDetect model') parser.add_argument(
parser.add_argument('--op_dir', type=str, default='generated_vids/', help='Path to output directory') "model_path", type=str, help="Path to trained BatDetect model"
parser.add_argument('--no_detector', action='store_true', help='Do not run detector') )
parser.add_argument('--plot_class_names_off', action='store_true', help='Do not plot class names') parser.add_argument(
parser.add_argument('--disable_axis', action='store_true', help='Do not plot axis') "--op_dir",
parser.add_argument('--detection_threshold', type=float, default=0.2, help='Cut-off probability for detector') type=str,
parser.add_argument('--time_expansion_factor', type=int, default=1, dest='time_expansion_factor', default="generated_vids/",
help='The time expansion factor used for all files (default is 1)') help="Path to output directory",
)
parser.add_argument(
"--no_detector", action="store_true", help="Do not run detector"
)
parser.add_argument(
"--plot_class_names_off",
action="store_true",
help="Do not plot class names",
)
parser.add_argument(
"--disable_axis", action="store_true", help="Do not plot axis"
)
parser.add_argument(
"--detection_threshold",
type=float,
default=0.2,
help="Cut-off probability for detector",
)
parser.add_argument(
"--time_expansion_factor",
type=int,
default=1,
dest="time_expansion_factor",
help="The time expansion factor used for all files (default is 1)",
)
args_cmd = vars(parser.parse_args()) args_cmd = vars(parser.parse_args())
# file of interest # file of interest
audio_file = args_cmd['audio_file'] audio_file = args_cmd["audio_file"]
op_dir = args_cmd['op_dir'] op_dir = args_cmd["op_dir"]
op_str = '_output' op_str = "_output"
ffmpeg_path = '/usr/bin/' ffmpeg_path = "/usr/bin/"
if not os.path.isfile(audio_file): if not os.path.isfile(audio_file):
print('Audio file not found: ', audio_file) print("Audio file not found: ", audio_file)
sys.exit() sys.exit()
if not os.path.isfile(args_cmd['model_path']): if not os.path.isfile(args_cmd["model_path"]):
print('Model not found: ', model_path) print("Model not found: ", model_path)
sys.exit() sys.exit()
start_time = 0.0 start_time = 0.0
duration = 0.5 duration = 0.5
reveal_boxes = True # makes the boxes appear one at a time reveal_boxes = True # makes the boxes appear one at a time
fps = 24 fps = 24
dpi = 100 dpi = 100
op_dir_tmp = os.path.join(op_dir, 'op_tmp_vids', '') op_dir_tmp = os.path.join(op_dir, "op_tmp_vids", "")
if not os.path.isdir(op_dir_tmp): if not os.path.isdir(op_dir_tmp):
os.makedirs(op_dir_tmp) os.makedirs(op_dir_tmp)
if not os.path.isdir(op_dir): if not os.path.isdir(op_dir):
os.makedirs(op_dir) os.makedirs(op_dir)
params = parameters.get_params(False) params = parameters.get_params(False)
args = du.get_default_bd_args() args = du.get_default_run_config()
args['time_expansion_factor'] = args_cmd['time_expansion_factor'] args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
args['detection_threshold'] = args_cmd['detection_threshold'] args["detection_threshold"] = args_cmd["detection_threshold"]
# load audio file # load audio file
print('\nProcessing: ' + os.path.basename(audio_file)) print("\nProcessing: " + os.path.basename(audio_file))
print('\nOutput directory: ' + op_dir) print("\nOutput directory: " + op_dir)
sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'], params['target_samp_rate']) sampling_rate, audio = au.load_audio(
audio = audio[int(sampling_rate*start_time):int(sampling_rate*start_time + sampling_rate*duration)] audio_file, args["time_expansion_factor"], params["target_samp_rate"]
)
audio = audio[
int(sampling_rate * start_time) : int(
sampling_rate * start_time + sampling_rate * duration
)
]
audio_orig = audio.copy() audio_orig = audio.copy()
audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'], audio = au.pad_audio(
params['fft_overlap'], params['resize_factor'], audio,
params['spec_divide_factor']) sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
params["resize_factor"],
params["spec_divide_factor"],
)
# generate spectrogram # generate spectrogram
spec, _ = au.generate_spectrogram(audio, sampling_rate, params, True) spec, _ = au.generate_spectrogram(audio, sampling_rate, params, True)
max_val = spec.max()*1.1 max_val = spec.max() * 1.1
if not args_cmd["no_detector"]:
print(" Loading model and running detector on entire file ...")
model, det_params = du.load_model(args_cmd["model_path"])
det_params["detection_threshold"] = args["detection_threshold"]
if not args_cmd['no_detector']: run_config = {
print(' Loading model and running detector on entire file ...') **det_params,
model, det_params = du.load_model(args_cmd['model_path']) **args,
det_params['detection_threshold'] = args['detection_threshold'] }
results = du.process_file(audio_file, model, det_params, args) results = du.process_file(audio_file, model, run_config)
print(' Processing detections and plotting ...') print(" Processing detections and plotting ...")
detections = [] detections = []
for bb in results['pred_dict']['annotation']: for bb in results["pred_dict"]["annotation"]:
if (bb['start_time'] >= start_time) and (bb['end_time'] < start_time+duration): if (bb["start_time"] >= start_time) and (
bb["end_time"] < start_time + duration
):
detections.append(bb) detections.append(bb)
# plot boxes # plot boxes
fig = plt.figure(1, figsize=(spec.shape[1]/dpi, spec.shape[0]/dpi), dpi=dpi) fig = plt.figure(
duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap']) 1, figsize=(spec.shape[1] / dpi, spec.shape[0] / dpi), dpi=dpi
viz.create_box_image(spec, fig, detections, start_time, start_time+duration, duration, params, max_val, )
plot_class_names=not args_cmd['plot_class_names_off']) duration = au.x_coords_to_time(
op_im_file_boxes = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '_boxes.png') spec.shape[1],
sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
)
viz.create_box_image(
spec,
fig,
detections,
start_time,
start_time + duration,
duration,
params,
max_val,
plot_class_names=not args_cmd["plot_class_names_off"],
)
op_im_file_boxes = os.path.join(
op_dir, os.path.basename(audio_file)[:-4] + op_str + "_boxes.png"
)
fig.savefig(op_im_file_boxes, dpi=dpi) fig.savefig(op_im_file_boxes, dpi=dpi)
plt.close(1) plt.close(1)
spec_with_boxes = plt.imread(op_im_file_boxes) spec_with_boxes = plt.imread(op_im_file_boxes)
print(" Saving audio file ...")
print(' Saving audio file ...') if args["time_expansion_factor"] == 1:
if args['time_expansion_factor']==1: sampling_rate_op = int(sampling_rate / 10.0)
sampling_rate_op = int(sampling_rate/10.0)
else: else:
sampling_rate_op = sampling_rate sampling_rate_op = sampling_rate
op_audio_file = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '.wav') op_audio_file = os.path.join(
op_dir, os.path.basename(audio_file)[:-4] + op_str + ".wav"
)
wavfile.write(op_audio_file, sampling_rate_op, audio_orig) wavfile.write(op_audio_file, sampling_rate_op, audio_orig)
print(" Saving image ...")
print(' Saving image ...') op_im_file = os.path.join(
op_im_file = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '.png') op_dir, os.path.basename(audio_file)[:-4] + op_str + ".png"
plt.imsave(op_im_file, spec, vmin=0, vmax=max_val, cmap='plasma') )
plt.imsave(op_im_file, spec, vmin=0, vmax=max_val, cmap="plasma")
spec_blank = plt.imread(op_im_file) spec_blank = plt.imread(op_im_file)
# create figure # create figure
freq_scale = 1000 # turn Hz to kHz freq_scale = 1000 # turn Hz to kHz
min_freq = params['min_freq']//freq_scale min_freq = params["min_freq"] // freq_scale
max_freq = params['max_freq']//freq_scale max_freq = params["max_freq"] // freq_scale
y_extent = [0, duration, min_freq, max_freq] y_extent = [0, duration, min_freq, max_freq]
print(' Saving video frames ...') print(" Saving video frames ...")
# save images that will be combined into video # save images that will be combined into video
# will either plot with or without boxes # will either plot with or without boxes
for ii, col in enumerate(np.linspace(0, spec.shape[1]-1, int(fps*duration*10))): for ii, col in enumerate(
if not args_cmd['no_detector']: np.linspace(0, spec.shape[1] - 1, int(fps * duration * 10))
):
if not args_cmd["no_detector"]:
spec_op = spec_with_boxes.copy() spec_op = spec_with_boxes.copy()
if ii > 0: if ii > 0:
spec_op[:, int(col), :] = 1.0 spec_op[:, int(col), :] = 1.0
if reveal_boxes: if reveal_boxes:
spec_op[:, int(col)+1:, :] = spec_blank[:, int(col)+1:, :] spec_op[:, int(col) + 1 :, :] = spec_blank[
:, int(col) + 1 :, :
]
elif ii == 0 and reveal_boxes: elif ii == 0 and reveal_boxes:
spec_op = spec_blank spec_op = spec_blank
if not args_cmd['disable_axis']: if not args_cmd["disable_axis"]:
plt.close('all') plt.close("all")
fig = plt.figure(ii, figsize=(1.2*(spec_op.shape[1]/dpi), 1.5*(spec_op.shape[0]/dpi)), dpi=dpi) fig = plt.figure(
plt.xlabel('Time - seconds') ii,
plt.ylabel('Frequency - kHz') figsize=(
plt.imshow(spec_op, vmin=0, vmax=1.0, cmap='plasma', extent=y_extent, aspect='auto') 1.2 * (spec_op.shape[1] / dpi),
1.5 * (spec_op.shape[0] / dpi),
),
dpi=dpi,
)
plt.xlabel("Time - seconds")
plt.ylabel("Frequency - kHz")
plt.imshow(
spec_op,
vmin=0,
vmax=1.0,
cmap="plasma",
extent=y_extent,
aspect="auto",
)
plt.tight_layout() plt.tight_layout()
fig.savefig(op_dir_tmp + str(ii).zfill(4) + '.png', dpi=dpi) fig.savefig(op_dir_tmp + str(ii).zfill(4) + ".png", dpi=dpi)
else: else:
plt.imsave(op_dir_tmp + str(ii).zfill(4) + '.png', spec_op, vmin=0, vmax=1.0, cmap='plasma') plt.imsave(
op_dir_tmp + str(ii).zfill(4) + ".png",
spec_op,
vmin=0,
vmax=1.0,
cmap="plasma",
)
else: else:
spec_op = spec.copy() spec_op = spec.copy()
if ii > 0: if ii > 0:
spec_op[:, int(col)] = max_val spec_op[:, int(col)] = max_val
plt.imsave(op_dir_tmp + str(ii).zfill(4) + '.png', spec_op, vmin=0, vmax=max_val, cmap='plasma') plt.imsave(
op_dir_tmp + str(ii).zfill(4) + ".png",
spec_op,
vmin=0,
vmax=max_val,
cmap="plasma",
)
print(" Creating video ...")
print(' Creating video ...') op_vid_file = os.path.join(
op_vid_file = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '.avi') op_dir, os.path.basename(audio_file)[:-4] + op_str + ".avi"
ffmpeg_cmd = 'ffmpeg -hide_banner -loglevel panic -y -r {} -f image2 -s {}x{} -i {}%04d.png -i {} -vcodec libx264 ' \ )
'-crf 25 -pix_fmt yuv420p -acodec copy {}'.format(fps, spec.shape[1], spec.shape[0], op_dir_tmp, op_audio_file, op_vid_file) ffmpeg_cmd = (
"ffmpeg -hide_banner -loglevel panic -y -r {} -f image2 -s {}x{} -i {}%04d.png -i {} -vcodec libx264 "
"-crf 25 -pix_fmt yuv420p -acodec copy {}".format(
fps,
spec.shape[1],
spec.shape[0],
op_dir_tmp,
op_audio_file,
op_vid_file,
)
)
ffmpeg_cmd = ffmpeg_path + ffmpeg_cmd ffmpeg_cmd = ffmpeg_path + ffmpeg_cmd
os.system(ffmpeg_cmd) os.system(ffmpeg_cmd)
print(' Deleting temporary files ...') print(" Deleting temporary files ...")
if os.path.isdir(op_dir_tmp): if os.path.isdir(op_dir_tmp):
shutil.rmtree(op_dir_tmp) shutil.rmtree(op_dir_tmp)

View File

@ -1,41 +1,70 @@
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
import os import os
import sys import sys
sys.path.append(os.path.join('..'))
import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage
sys.path.append(os.path.join(".."))
import bat_detect.utils.audio_utils as au import bat_detect.utils.audio_utils as au
def generate_spectrogram_data(audio, sampling_rate, params, norm_type='log', smooth_spec=False): def generate_spectrogram_data(
max_freq = round(params['max_freq']*params['fft_win_length']) audio, sampling_rate, params, norm_type="log", smooth_spec=False
min_freq = round(params['min_freq']*params['fft_win_length']) ):
max_freq = round(params["max_freq"] * params["fft_win_length"])
min_freq = round(params["min_freq"] * params["fft_win_length"])
# create spectrogram - numpy # create spectrogram - numpy
spec = au.gen_mag_spectrogram(audio, sampling_rate, params['fft_win_length'], params['fft_overlap']) spec = au.gen_mag_spectrogram(
#spec = au.gen_mag_spectrogram_pt(audio, sampling_rate, params['fft_win_length'], params['fft_overlap']).numpy() audio, sampling_rate, params["fft_win_length"], params["fft_overlap"]
)
# spec = au.gen_mag_spectrogram_pt(audio, sampling_rate, params['fft_win_length'], params['fft_overlap']).numpy()
if spec.shape[0] < max_freq: if spec.shape[0] < max_freq:
freq_pad = max_freq - spec.shape[0] freq_pad = max_freq - spec.shape[0]
spec = np.vstack((np.zeros((freq_pad, spec.shape[1]), dtype=np.float32), spec)) spec = np.vstack(
spec = spec[-max_freq:spec.shape[0]-min_freq, :] (np.zeros((freq_pad, spec.shape[1]), dtype=np.float32), spec)
)
spec = spec[-max_freq : spec.shape[0] - min_freq, :]
if norm_type == 'log': if norm_type == "log":
log_scaling = 2.0 * (1.0 / sampling_rate) * (1.0/(np.abs(np.hanning(int(params['fft_win_length']*sampling_rate)))**2).sum()) log_scaling = (
2.0
* (1.0 / sampling_rate)
* (
1.0
/ (
np.abs(
np.hanning(
int(params["fft_win_length"] * sampling_rate)
)
)
** 2
).sum()
)
)
##log_scaling = 0.01 ##log_scaling = 0.01
spec = np.log(1.0 + log_scaling*spec).astype(np.float32) spec = np.log(1.0 + log_scaling * spec).astype(np.float32)
elif norm_type == 'pcen': elif norm_type == "pcen":
spec = au.pcen(spec, sampling_rate) spec = au.pcen(spec, sampling_rate)
else: else:
pass pass
if smooth_spec: if smooth_spec:
spec = ndimage.gaussian_filter(spec, 1) spec = ndimage.gaussian_filter(spec, 1)
return spec return spec
def load_data(anns, params, class_names, smooth_spec=False, norm_type='log', extract_bg=False): def load_data(
anns,
params,
class_names,
smooth_spec=False,
norm_type="log",
extract_bg=False,
):
specs = [] specs = []
labels = [] labels = []
coords = [] coords = []
@ -43,67 +72,106 @@ def load_data(anns, params, class_names, smooth_spec=False, norm_type='log', ext
sampling_rates = [] sampling_rates = []
file_names = [] file_names = []
for cur_file in anns: for cur_file in anns:
sampling_rate, audio_orig = au.load_audio_file(cur_file['file_path'], cur_file['time_exp'], sampling_rate, audio_orig = au.load_audio(
params['target_samp_rate'], params['scale_raw_audio']) cur_file["file_path"],
cur_file["time_exp"],
params["target_samp_rate"],
params["scale_raw_audio"],
)
for ann in cur_file['annotation']: for ann in cur_file["annotation"]:
if ann['class'] not in params['classes_to_ignore'] and ann['class'] in class_names: if (
ann["class"] not in params["classes_to_ignore"]
and ann["class"] in class_names
):
# clip out of bounds # clip out of bounds
if ann['low_freq'] < params['min_freq']: if ann["low_freq"] < params["min_freq"]:
ann['low_freq'] = params['min_freq'] ann["low_freq"] = params["min_freq"]
if ann['high_freq'] > params['max_freq']: if ann["high_freq"] > params["max_freq"]:
ann['high_freq'] = params['max_freq'] ann["high_freq"] = params["max_freq"]
# load cropped audio # load cropped audio
start_samp_diff = int(sampling_rate*ann['start_time']) - int(sampling_rate*params['aud_pad']) start_samp_diff = int(sampling_rate * ann["start_time"]) - int(
sampling_rate * params["aud_pad"]
)
start_samp = np.maximum(0, start_samp_diff) start_samp = np.maximum(0, start_samp_diff)
end_samp = np.minimum(audio_orig.shape[0], int(sampling_rate*ann['end_time'])*2 + int(sampling_rate*params['aud_pad'])) end_samp = np.minimum(
audio_orig.shape[0],
int(sampling_rate * ann["end_time"]) * 2
+ int(sampling_rate * params["aud_pad"]),
)
audio = audio_orig[start_samp:end_samp] audio = audio_orig[start_samp:end_samp]
if start_samp_diff < 0: if start_samp_diff < 0:
# need to pad at start if the call is at the very begining # need to pad at start if the call is at the very begining
audio = np.hstack((np.zeros(-start_samp_diff, dtype=np.float32), audio)) audio = np.hstack(
(np.zeros(-start_samp_diff, dtype=np.float32), audio)
)
nfft = int(params['fft_win_length']*sampling_rate) nfft = int(params["fft_win_length"] * sampling_rate)
noverlap = int(params['fft_overlap']*nfft) noverlap = int(params["fft_overlap"] * nfft)
max_samps = params['spec_width']*(nfft - noverlap) + noverlap max_samps = params["spec_width"] * (nfft - noverlap) + noverlap
if max_samps > audio.shape[0]: if max_samps > audio.shape[0]:
audio = np.hstack((audio, np.zeros(max_samps - audio.shape[0]))) audio = np.hstack(
(audio, np.zeros(max_samps - audio.shape[0]))
)
audio = audio[:max_samps].astype(np.float32) audio = audio[:max_samps].astype(np.float32)
audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'], audio = au.pad_audio(
params['fft_overlap'], params['resize_factor'], audio,
params['spec_divide_factor']) sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
params["resize_factor"],
params["spec_divide_factor"],
)
# generate spectrogram # generate spectrogram
spec = generate_spectrogram_data(audio, sampling_rate, params, norm_type, smooth_spec)[:, :params['spec_width']] spec = generate_spectrogram_data(
audio, sampling_rate, params, norm_type, smooth_spec
)[:, : params["spec_width"]]
specs.append(spec[np.newaxis, ...]) specs.append(spec[np.newaxis, ...])
labels.append(ann['class']) labels.append(ann["class"])
audios.append(audio) audios.append(audio)
sampling_rates.append(sampling_rate) sampling_rates.append(sampling_rate)
file_names.append(cur_file['file_path']) file_names.append(cur_file["file_path"])
# position in crop # position in crop
x1 = int(au.time_to_x_coords(np.array(params['aud_pad']), sampling_rate, params['fft_win_length'], params['fft_overlap'])) x1 = int(
y1 = (ann['low_freq'] - params['min_freq']) * params['fft_win_length'] au.time_to_x_coords(
np.array(params["aud_pad"]),
sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
)
)
y1 = (ann["low_freq"] - params["min_freq"]) * params[
"fft_win_length"
]
coords.append((y1, x1)) coords.append((y1, x1))
_, file_ids = np.unique(file_names, return_inverse=True) _, file_ids = np.unique(file_names, return_inverse=True)
labels = np.array([class_names.index(ll) for ll in labels]) labels = np.array([class_names.index(ll) for ll in labels])
#return np.vstack(specs), labels, coords, audios, sampling_rates, file_ids, file_names # return np.vstack(specs), labels, coords, audios, sampling_rates, file_ids, file_names
return np.vstack(specs), labels return np.vstack(specs), labels
def save_summary_image(specs, labels, species_names, params, op_file_name='plots/all_species.png', order=None): def save_summary_image(
specs,
labels,
species_names,
params,
op_file_name="plots/all_species.png",
order=None,
):
# takes the mean for each class and plots it on a grid # takes the mean for each class and plots it on a grid
mean_specs = [] mean_specs = []
max_band = [] max_band = []
for ii in range(len(species_names)): for ii in range(len(species_names)):
inds = np.where(labels==ii)[0] inds = np.where(labels == ii)[0]
mu = specs[inds, :].mean(0) mu = specs[inds, :].mean(0)
max_band.append(np.argmax(mu.sum(1))) max_band.append(np.argmax(mu.sum(1)))
mean_specs.append(mu) mean_specs.append(mu)
@ -113,11 +181,21 @@ def save_summary_image(specs, labels, species_names, params, op_file_name='plots
order = np.arange(len(species_names)) order = np.arange(len(species_names))
max_cols = 6 max_cols = 6
nrows = int(np.ceil(len(species_names)/max_cols)) nrows = int(np.ceil(len(species_names) / max_cols))
ncols = np.minimum(len(species_names), max_cols) ncols = np.minimum(len(species_names), max_cols)
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3.3, nrows*6), gridspec_kw = {'wspace':0, 'hspace':0.2}) fig, ax = plt.subplots(
spec_min_max = (0, mean_specs[0].shape[1], params['min_freq']/1000, params['max_freq']/1000) nrows=nrows,
ncols=ncols,
figsize=(ncols * 3.3, nrows * 6),
gridspec_kw={"wspace": 0, "hspace": 0.2},
)
spec_min_max = (
0,
mean_specs[0].shape[1],
params["min_freq"] / 1000,
params["max_freq"] / 1000,
)
ii = 0 ii = 0
for row in ax: for row in ax:
@ -126,17 +204,22 @@ def save_summary_image(specs, labels, species_names, params, op_file_name='plots
for col in row: for col in row:
if ii >= len(species_names): if ii >= len(species_names):
col.axis('off') col.axis("off")
else: else:
inds = np.where(labels==order[ii])[0] inds = np.where(labels == order[ii])[0]
col.imshow(mean_specs[order[ii]], extent=spec_min_max, cmap='plasma', aspect='equal') col.imshow(
col.grid(color='w', alpha=0.3, linewidth=0.3) mean_specs[order[ii]],
extent=spec_min_max,
cmap="plasma",
aspect="equal",
)
col.grid(color="w", alpha=0.3, linewidth=0.3)
col.set_xticks([]) col.set_xticks([])
col.title.set_text(str(ii+1) + ' ' + species_names[order[ii]]) col.title.set_text(str(ii + 1) + " " + species_names[order[ii]])
col.tick_params(axis='both', which='major', labelsize=7) col.tick_params(axis="both", which="major", labelsize=7)
ii += 1 ii += 1
#plt.tight_layout() # plt.tight_layout()
#plt.show() # plt.show()
plt.savefig(op_file_name) plt.savefig(op_file_name)
plt.close('all') plt.close("all")

0
tests/__init__.py Normal file
View File

253
tests/test_api.py Normal file
View File

@ -0,0 +1,253 @@
"""Test bat detect module API."""
import os
from glob import glob
import numpy as np
import torch
from torch import nn
from bat_detect import api
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))
def test_load_model_with_default_params():
"""Test loading model with default parameters."""
model, params = api.load_model()
assert model is not None
assert isinstance(model, nn.Module)
assert params is not None
assert isinstance(params, dict)
assert "model_name" in params
assert "num_filters" in params
assert "emb_dim" in params
assert "ip_height" in params
assert params["model_name"] == "Net2DFast"
assert params["num_filters"] == 128
assert params["emb_dim"] == 0
assert params["ip_height"] == 128
assert params["resize_factor"] == 0.5
assert len(params["class_names"]) == 17
def test_list_audio_files():
"""Test listing audio files."""
audio_files = api.list_audio_files(TEST_DATA_DIR)
assert len(audio_files) == 3
assert all(path.endswith((".wav", ".WAV")) for path in audio_files)
def test_load_audio():
"""Test loading audio."""
audio = api.load_audio(TEST_DATA[0])
assert audio is not None
assert isinstance(audio, np.ndarray)
assert audio.shape == (128000,)
def test_generate_spectrogram():
"""Test generating spectrogram."""
audio = api.load_audio(TEST_DATA[0])
spectrogram = api.generate_spectrogram(audio)
assert spectrogram is not None
assert isinstance(spectrogram, torch.Tensor)
assert spectrogram.shape == (1, 1, 128, 512)
def test_get_default_config():
"""Test getting default configuration."""
config = api.get_config()
assert config is not None
assert isinstance(config, dict)
assert config["target_samp_rate"] == 256000
assert config["fft_win_length"] == 0.002
assert config["fft_overlap"] == 0.75
assert config["resize_factor"] == 0.5
assert config["spec_divide_factor"] == 32
assert config["spec_height"] == 256
assert config["spec_scale"] == "pcen"
assert config["denoise_spec_avg"] is True
assert config["max_scale_spec"] is False
assert config["scale_raw_audio"] is False
assert len(config["class_names"]) == 0
assert config["detection_threshold"] == 0.01
assert config["time_expansion"] == 1
assert config["top_n"] == 3
assert config["return_raw_preds"] is False
assert config["max_duration"] is None
assert config["nms_kernel_size"] == 9
assert config["max_freq"] == 120000
assert config["min_freq"] == 10000
assert config["nms_top_k_per_sec"] == 200
assert config["quiet"] is True
assert config["chunk_size"] == 3
assert config["cnn_features"] is False
assert config["spec_features"] is False
assert config["spec_slices"] is False
def test_api_exposes_default_model():
"""Test that API exposes default model."""
assert hasattr(api, "model")
assert isinstance(api.model, nn.Module)
assert type(api.model).__name__ == "Net2DFast"
# Check that model has expected attributes
assert api.model.num_classes == 17
assert api.model.num_filts == 128
assert api.model.emb_dim == 0
assert api.model.ip_height_rs == 128
assert api.model.resize_factor == 0.5
def test_api_exposes_default_config():
"""Test that API exposes default configuration."""
assert hasattr(api, "config")
assert isinstance(api.config, dict)
assert api.config["target_samp_rate"] == 256000
assert api.config["fft_win_length"] == 0.002
assert api.config["fft_overlap"] == 0.75
assert api.config["resize_factor"] == 0.5
assert api.config["spec_divide_factor"] == 32
assert api.config["spec_height"] == 256
assert api.config["spec_scale"] == "pcen"
assert api.config["denoise_spec_avg"] is True
assert api.config["max_scale_spec"] is False
assert api.config["scale_raw_audio"] is False
assert len(api.config["class_names"]) == 17
assert api.config["detection_threshold"] == 0.01
assert api.config["time_expansion"] == 1
assert api.config["top_n"] == 3
assert api.config["return_raw_preds"] is False
assert api.config["max_duration"] is None
assert api.config["nms_kernel_size"] == 9
assert api.config["max_freq"] == 120000
assert api.config["min_freq"] == 10000
assert api.config["nms_top_k_per_sec"] == 200
assert api.config["quiet"] is True
assert api.config["chunk_size"] == 3
assert api.config["cnn_features"] is False
assert api.config["spec_features"] is False
assert api.config["spec_slices"] is False
def test_process_file_with_default_model():
"""Test processing file with model."""
predictions = api.process_file(TEST_DATA[0])
assert predictions is not None
assert isinstance(predictions, dict)
assert "pred_dict" in predictions
# By default will not return other features
assert "spec_feats" not in predictions
assert "spec_feat_names" not in predictions
assert "cnn_feats" not in predictions
assert "cnn_feat_names" not in predictions
assert "spec_slices" not in predictions
# Check that predictions are returned
assert isinstance(predictions["pred_dict"], dict)
pred_dict = predictions["pred_dict"]
assert pred_dict["id"] == os.path.basename(TEST_DATA[0])
assert pred_dict["annotated"] is False
assert pred_dict["issues"] is False
assert pred_dict["notes"] == "Automatically generated."
assert pred_dict["time_exp"] == 1
assert pred_dict["duration"] == 0.5
assert pred_dict["class_name"] is not None
assert len(pred_dict["annotation"]) > 0
def test_process_spectrogram_with_default_model():
"""Test processing spectrogram with model."""
audio = api.load_audio(TEST_DATA[0])
spectrogram = api.generate_spectrogram(audio)
predictions, features = api.process_spectrogram(spectrogram)
assert predictions is not None
assert isinstance(predictions, list)
assert len(predictions) > 0
sample_pred = predictions[0]
assert isinstance(sample_pred, dict)
assert "class" in sample_pred
assert "class_prob" in sample_pred
assert "det_prob" in sample_pred
assert "start_time" in sample_pred
assert "end_time" in sample_pred
assert "low_freq" in sample_pred
assert "high_freq" in sample_pred
assert features is not None
assert isinstance(features, list)
assert len(features) == 1
def test_process_audio_with_default_model():
"""Test processing audio with model."""
audio = api.load_audio(TEST_DATA[0])
predictions, features, spec = api.process_audio(audio)
assert predictions is not None
assert isinstance(predictions, list)
assert len(predictions) > 0
sample_pred = predictions[0]
assert isinstance(sample_pred, dict)
assert "class" in sample_pred
assert "class_prob" in sample_pred
assert "det_prob" in sample_pred
assert "start_time" in sample_pred
assert "end_time" in sample_pred
assert "low_freq" in sample_pred
assert "high_freq" in sample_pred
assert features is not None
assert isinstance(features, list)
assert len(features) == 1
assert spec is not None
assert isinstance(spec, torch.Tensor)
assert spec.shape == (1, 1, 128, 512)
def test_postprocess_model_outputs():
"""Test postprocessing model outputs."""
# Load model outputs
audio = api.load_audio(TEST_DATA[1])
spec = api.generate_spectrogram(audio)
model_outputs = api.model(spec)
# Postprocess outputs
predictions, features = api.postprocess(model_outputs)
assert predictions is not None
assert isinstance(predictions, list)
assert len(predictions) > 0
sample_pred = predictions[0]
assert isinstance(sample_pred, dict)
assert "class" in sample_pred
assert "class_prob" in sample_pred
assert "det_prob" in sample_pred
assert "start_time" in sample_pred
assert "end_time" in sample_pred
assert "low_freq" in sample_pred
assert "high_freq" in sample_pred
assert features is not None
assert isinstance(features, np.ndarray)
assert features.shape[0] == len(predictions)
assert features.shape[1] == 32

41
tests/test_cli.py Normal file
View File

@ -0,0 +1,41 @@
"""Test the command line interface."""
from click.testing import CliRunner
from bat_detect.cli import cli
def test_cli_base_command():
runner = CliRunner()
result = runner.invoke(cli, ["--help"])
assert result.exit_code == 0
assert "BatDetect2 - Bat Call Detection and Classification" in result.output
def test_cli_detect_command_help():
runner = CliRunner()
result = runner.invoke(cli, ["detect", "--help"])
assert result.exit_code == 0
assert "Detect bat calls in files in AUDIO_DIR" in result.output
def test_cli_detect_command_on_test_audio(tmp_path):
results_dir = tmp_path / "results"
# Remove results dir if it exists
if results_dir.exists():
results_dir.rmdir()
runner = CliRunner()
result = runner.invoke(
cli,
[
"detect",
"example_data/audio",
str(results_dir),
"0.3",
],
)
assert result.exit_code == 0
assert results_dir.exists()
assert len(list(results_dir.glob("*.csv"))) == 3
assert len(list(results_dir.glob("*.json"))) == 3