mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
commit 7c80441d60

.git-blame-ignore-revs (Normal file, 4 lines)
@@ -0,0 +1,4 @@
# Format code with Black and isort
3c17a2337166245de8df778fe174aad997e14e8f
9cb6b20949c7c31ee21ed2b800e8b691f1be32a7
53100f51e083cf4d900ed325ae0543cc754a26cc

.gitignore (vendored, 104 lines)
@@ -1,10 +1,108 @@
*.pyc
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# IPython
profile_default/
ipython_config.py

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Rope project settings
.ropeproject/

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Model artifacts
*.png
*.jpg
*.wav
*.tar
*.json
*.ipynb_checkpoints/
experiments/*
plots/*

# Batdetect Models [Include]
!bat_detect/models/*.pth.tar

# Model experiments
experiments/*

# Jupyter notebooks
.virtual_documents
.ipynb_checkpoints
*.ipynb
!batdetect2_notebook.ipynb

.pylintrc (Normal file, 5 lines)
@@ -0,0 +1,5 @@
[TYPECHECK]

# List of members which are set dynamically and missed by Pylint inference
# system, and so shouldn't trigger E1101 when accessed.
generated-members=torch.*

README.md (129 lines)
@@ -1,61 +1,122 @@
# BatDetect2
<img align="left" width="64" height="64" src="ims/bat_icon.png">
<img style="display: block-inline;" width="64" height="64" src="ims/bat_icon.png"> Code for detecting and classifying bat echolocation calls in high frequency audio recordings.

Code for detecting and classifying bat echolocation calls in high frequency audio recordings.

## Getting started
### Python Environment

We recommend using an isolated Python environment to avoid dependency issues. Choose one
of the following options:

* Install the Anaconda Python 3.10 distribution for your operating system from [here](https://www.continuum.io/downloads). Create a new environment and activate it:

```bash
conda create -y --name batdetect2 python==3.10
conda activate batdetect2
```

* If you already have Python installed (version >= 3.8, < 3.11) and prefer using virtual environments, then:

```bash
python -m venv .venv
source .venv/bin/activate
```

### Installing BatDetect2
You can use pip to install `batdetect2`:

```bash
pip install batdetect2
```

Alternatively, download this code from the repository (by clicking on the green button on the top right) and unzip it.
Once unzipped, run this from the extracted folder:

```bash
pip install .
```

Make sure you have the environment activated before installing `batdetect2`.

### Getting started
1) Install the Anaconda Python 3.10 distribution for your operating system from [here](https://www.continuum.io/downloads).
2) Download this code from the repository (by clicking on the green button on the top right) and unzip it.
3) Create a new environment and install the required packages:
`conda env create -f environment.yml`
`conda activate batdetect2`
## Try the model
1) You can try a demo of the model (for UK species) on [huggingface](https://huggingface.co/spaces/macaodha/batdetect2).

2) Alternatively, click [here](https://colab.research.google.com/github/macaodha/batdetect2/blob/master/batdetect2_notebook.ipynb) to run the model using Google Colab. You can also run this notebook locally.

### Try the model
1) You can try a demo of the model (for UK species) on [huggingface](https://huggingface.co/spaces/macaodha/batdetect2).
## Running the model on your own data

2) Alternatively, click [here](https://colab.research.google.com/github/macaodha/batdetect2/blob/master/batdetect2_notebook.ipynb) to run the model using Google Colab. You can also run this notebook locally.
After following the above steps to install the code you can run the model on your own data.

### Running the model on your own data
After following the above steps to install the code you can run the model on your own data by opening the command line where the code is located and typing:
`python run_batdetect.py AUDIO_DIR ANN_DIR DETECTION_THRESHOLD`
e.g.
`python run_batdetect.py example_data/audio/ example_data/anns/ 0.3`
### Using the command line

You can run the model by opening the command line and typing:
```bash
batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD
```
e.g.
```bash
batdetect2 detect example_data/audio/ example_data/anns/ 0.3
```

`AUDIO_DIR` is the path on your computer to the audio wav files of interest.
`ANN_DIR` is the path on your computer where the model predictions will be saved. The model will output both `.csv` and `.json` results for each audio file.
`DETECTION_THRESHOLD` is a number between 0 and 1 specifying the cut-off threshold applied to the calls. A smaller number will result in more calls detected, but with the chance of introducing more mistakes.

There are also optional arguments, e.g. you can request that the model outputs features (i.e. estimated call parameters) such as duration, max_frequency, etc. by setting the flag `--spec_features`. These will be saved as `*_spec_features.csv` files:
`batdetect2 detect example_data/audio/ example_data/anns/ 0.3 --spec_features`

You can also specify which model to use by setting the `--model_path` argument. If not specified, it will default to using a model trained on UK data, e.g.
`batdetect2 detect example_data/audio/ example_data/anns/ 0.3 --model_path models/Net2DFast_UK_same.pth.tar`

`AUDIO_DIR` is the path on your computer to the audio wav files of interest.
`ANN_DIR` is the path on your computer where the model predictions will be saved. The model will output both `.csv` and `.json` results for each audio file.
`DETECTION_THRESHOLD` is a number between 0 and 1 specifying the cut-off threshold applied to the calls. A smaller number will result in more calls detected, but with the chance of introducing more mistakes.
### Using the Python API

There are also optional arguments, e.g. you can request that the model outputs features (i.e. estimated call parameters) such as duration, max_frequency, etc. by setting the flag `--spec_features`. These will be saved as `*_spec_features.csv` files:
`python run_batdetect.py example_data/audio/ example_data/anns/ 0.3 --spec_features`
If you prefer to process your data within a Python script then you can use the `batdetect2` Python API.

You can also specify which model to use by setting the `--model_path` argument. If not specified, it will default to using a model trained on UK data, e.g.
`python run_batdetect.py example_data/audio/ example_data/anns/ 0.3 --model_path models/Net2DFast_UK_same.pth.tar`
```python
from batdetect2 import api

AUDIO_FILE = "example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav"

# Process a whole file
results = api.process_file(AUDIO_FILE)

# Or, load audio and compute spectrograms
audio = api.load_audio(AUDIO_FILE)
spec = api.generate_spectrogram(audio)

# And process the audio or the spectrogram with the model
detections, features, spec = api.process_audio(audio)
detections, features = api.process_spectrogram(spec)

# Do something else ...
```

You can integrate the detections or the extracted features into your custom analysis pipeline.
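For instance, a minimal sketch of such a pipeline step, tabulating the detections with pandas (the field names follow the annotation format documented in `bat_detect/api.py` below; the `pandas` dependency is an assumption of this sketch):

```python
import pandas as pd

from batdetect2 import api

audio = api.load_audio("example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav")
detections, features, spec = api.process_audio(audio)

# each detection is a dict with keys such as start_time, end_time,
# low_freq, high_freq, class and det_prob
df = pd.DataFrame(detections)
print(df[["start_time", "class", "det_prob"]])
```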
### Training the model on your own data
Take a look at the steps outlined in the finetuning readme [here](bat_detect/finetune/readme.md) for a description of how to train your own model.
## Training the model on your own data
Take a look at the steps outlined in the finetuning readme [here](bat_detect/finetune/readme.md) for a description of how to train your own model.

### Data and annotations
The raw audio data and annotations used to train the models in the paper will be added soon.
The audio interface used to annotate audio data for training and evaluation is available [here](https://github.com/macaodha/batdetect2_GUI).
## Data and annotations
The raw audio data and annotations used to train the models in the paper will be added soon.
The audio interface used to annotate audio data for training and evaluation is available [here](https://github.com/macaodha/batdetect2_GUI).

### Warning
## Warning
The models developed and shared as part of this repository should be used with caution.
While they have been evaluated on held out audio data, great care should be taken when using the model outputs for any form of biodiversity assessment.
Your data may differ, and as a result it is very strongly recommended that you validate the model first using data with known species to ensure that the outputs can be trusted.

### FAQ
For more information please consult our [FAQ](faq.md).
## FAQ
For more information please consult our [FAQ](faq.md).

### Reference
## Reference
If you find our work useful in your research please consider citing our paper which you can find [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1):
```
@article{batdetect2_2022,
@@ -66,8 +127,8 @@ If you find our work useful in your research please consider citing our paper wh
}
```

### Acknowledgements
Thanks to all the contributors who spent time collecting and annotating audio data.
## Acknowledgements
Thanks to all the contributors who spent time collecting and annotating audio data.

### TODOs

app.py (170 lines)
@@ -1,84 +1,126 @@
import gradio as gr
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import pandas as pd

import bat_detect.utils.detector_utils as du
import bat_detect.utils.audio_utils as au
import bat_detect.utils.detector_utils as du
import bat_detect.utils.plot_utils as viz


# setup the arguments
args = {}
args = du.get_default_bd_args()
args['detection_threshold'] = 0.3
args['time_expansion_factor'] = 1
args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'
args = du.get_default_run_config()
args["detection_threshold"] = 0.3
args["time_expansion_factor"] = 1
args["model_path"] = "models/Net2DFast_UK_same.pth.tar"
max_duration = 2.0

# load the model
model, params = du.load_model(args['model_path'])
model, params = du.load_model(args["model_path"])


df = gr.Dataframe(
    headers=["species", "time", "detection_prob", "species_prob"],
    datatype=["str", "str", "str", "str"],
    row_count=1,
    col_count=(4, "fixed"),
    label='Predictions'
)

examples = [['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3],
            ['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3],
            ['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3]]
df = gr.Dataframe(
    headers=["species", "time", "detection_prob", "species_prob"],
    datatype=["str", "str", "str", "str"],
    row_count=1,
    col_count=(4, "fixed"),
    label="Predictions",
)

examples = [
    ["example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav", 0.3],
    ["example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav", 0.3],
    ["example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav", 0.3],
]


def make_prediction(file_name=None, detection_threshold=0.3):

    if file_name is not None:
        audio_file = file_name
    else:
        return "You must provide an input audio file."

    if detection_threshold is not None and detection_threshold != '':
        args['detection_threshold'] = float(detection_threshold)

    if detection_threshold is not None and detection_threshold != "":
        args["detection_threshold"] = float(detection_threshold)

    run_config = {
        **params,
        **args,
        "max_duration": max_duration,
    }

    # process the file to generate predictions
    results = du.process_file(audio_file, model, params, args, max_duration=max_duration)

    anns = [ann for ann in results['pred_dict']['annotation']]
    clss = [aa['class'] for aa in anns]
    st_time = [aa['start_time'] for aa in anns]
    cls_prob = [aa['class_prob'] for aa in anns]
    det_prob = [aa['det_prob'] for aa in anns]
    data = {'species': clss, 'time': st_time, 'detection_prob': det_prob, 'species_prob': cls_prob}

    results = du.process_file(
        audio_file,
        model,
        run_config,
    )

    anns = [ann for ann in results["pred_dict"]["annotation"]]
    clss = [aa["class"] for aa in anns]
    st_time = [aa["start_time"] for aa in anns]
    cls_prob = [aa["class_prob"] for aa in anns]
    det_prob = [aa["det_prob"] for aa in anns]
    data = {
        "species": clss,
        "time": st_time,
        "detection_prob": det_prob,
        "species_prob": cls_prob,
    }

    df = pd.DataFrame(data=data)
    im = generate_results_image(audio_file, anns)

    return [df, im]


def generate_results_image(audio_file, anns):

def generate_results_image(audio_file, anns):

    # load audio
    sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'],
        params['target_samp_rate'], params['scale_raw_audio'], max_duration=max_duration)
    sampling_rate, audio = au.load_audio(
        audio_file,
        args["time_expansion_factor"],
        params["target_samp_rate"],
        params["scale_raw_audio"],
        max_duration=max_duration,
    )
    duration = audio.shape[0] / sampling_rate

    # generate spec
    spec, spec_viz = au.generate_spectrogram(audio, sampling_rate, params, True, False)
    spec, spec_viz = au.generate_spectrogram(
        audio, sampling_rate, params, True, False
    )

    # create fig
    plt.close('all')
    fig = plt.figure(1, figsize=(spec.shape[1]/100, spec.shape[0]/100), dpi=100, frameon=False)
    spec_duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])
    viz.create_box_image(spec, fig, anns, 0, spec_duration, spec_duration, params, spec.max()*1.1, False, True)
    plt.ylabel('Freq - kHz')
    plt.xlabel('Time - secs')
    plt.close("all")
    fig = plt.figure(
        1,
        figsize=(spec.shape[1] / 100, spec.shape[0] / 100),
        dpi=100,
        frameon=False,
    )
    spec_duration = au.x_coords_to_time(
        spec.shape[1],
        sampling_rate,
        params["fft_win_length"],
        params["fft_overlap"],
    )
    viz.create_box_image(
        spec,
        fig,
        anns,
        0,
        spec_duration,
        spec_duration,
        params,
        spec.max() * 1.1,
        False,
        True,
    )
    plt.ylabel("Freq - kHz")
    plt.xlabel("Time - secs")
    plt.tight_layout()

    # convert fig to image
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
@@ -88,21 +130,23 @@ def generate_results_image(audio_file, anns):
    return im


descr_txt = "Demo of BatDetect2 deep learning-based bat echolocation call detection. " \
            "<br>This model is only trained on bat species from the UK. If the input " \
            "file is longer than 2 seconds, only the first 2 seconds will be processed." \
            "<br>Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."
descr_txt = (
    "Demo of BatDetect2 deep learning-based bat echolocation call detection. "
    "<br>This model is only trained on bat species from the UK. If the input "
    "file is longer than 2 seconds, only the first 2 seconds will be processed."
    "<br>Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."
)

gr.Interface(
    fn = make_prediction,
    inputs = [gr.Audio(source="upload", type="filepath", optional=True),
              gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])],
    outputs = [df, gr.Image(label="Visualisation")],
    theme = "huggingface",
    title = "BatDetect2 Demo",
    description = descr_txt,
    examples = examples,
    allow_flagging = 'never',
    fn=make_prediction,
    inputs=[
        gr.Audio(source="upload", type="filepath", optional=True),
        gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
    ],
    outputs=[df, gr.Image(label="Visualisation")],
    theme="huggingface",
    title="BatDetect2 Demo",
    description=descr_txt,
    examples=examples,
    allow_flagging="never",
).launch()
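The figure-to-image conversion in `generate_results_image` relies on Matplotlib's Agg canvas exposing its rendered buffer; the reshape step is elided by the diff above, but a self-contained sketch of the idea (figure content and dimensions are illustrative assumptions) looks like this. Note that `tostring_rgb` is deprecated in recent Matplotlib releases in favour of `buffer_rgba`:

```python
import matplotlib

matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(4, 2), dpi=100, frameon=False)
plt.plot([0, 1], [0, 1])
fig.canvas.draw()

# raw RGB bytes of the rendered canvas, reshaped to (height, width, 3)
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
width, height = fig.canvas.get_width_height()
im = data.reshape(height, width, 3)
print(im.shape)  # (200, 400, 3)
```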

bat_detect/api.py (Normal file, 397 lines)
@@ -0,0 +1,397 @@
"""Python API for bat_detect.
|
||||
|
||||
This module provides a Python API for bat_detect. It can be used to
|
||||
process audio files or spectrograms with the default model or a custom
|
||||
model.
|
||||
|
||||
Example
|
||||
-------
|
||||
You can use the default model to process audio files. To process a single
|
||||
file, use the `process_file` function.
|
||||
>>> import bat_detect.api as api
|
||||
>>> # Process audio file
|
||||
>>> results = api.process_file("audio_file.wav")
|
||||
|
||||
To process multiple files, use the `list_audio_files` function to get a list
|
||||
of audio files in a directory. Then use the `process_file` function to
|
||||
process each file.
|
||||
|
||||
>>> import bat_detect.api as api
|
||||
>>> # Get list of audio files
|
||||
>>> audio_files = api.list_audio_files("audio_directory")
|
||||
>>> # Process audio files
|
||||
>>> results = [api.process_file(f) for f in audio_files]
|
||||
|
||||
The `process_file` function will slice the recording into 3 second chunks
|
||||
and process each chunk separately, in case the recording is longer. The
|
||||
results will be combined into a dictionary with the following keys:
|
||||
|
||||
- `pred_dict`: All the predictions from the model in the format
|
||||
expected by the annotation tool.
|
||||
- `cnn_feats`: Optional. A list of `numpy` arrays containing the CNN features
|
||||
for each detection. The CNN features are the output of the CNN before
|
||||
the final classification layer. You can use these features to train
|
||||
your own classifier, or to do other processing on the detections.
|
||||
They are in the same order as the detections in
|
||||
`results['pred_dict']['annotation']`. Will only be returned if the
|
||||
`cnn_feats` parameter in the config is set to `True`.
|
||||
- `spec_slices`: Optional. A list of `numpy` arrays containing the spectrogram
|
||||
for each of the processed chunks. Will only be returned if the
|
||||
`spec_slices` parameter in the config is set to `True`.
|
||||
|
||||
Alternatively, you can use the `process_audio` function to process an audio
|
||||
array directly, or `process_spectrogram` to process spectrograms. This
|
||||
allows you to do other preprocessing steps before running the model for
|
||||
predictions.
|
||||
|
||||
>>> import bat_detect.api as api
|
||||
>>> # Load audio
|
||||
>>> audio = api.load_audio("audio_file.wav")
|
||||
>>> # Process the audio array
|
||||
>>> detections, features, spec = api.process_audio(audio)
|
||||
>>> # Or compute and process the spectrogram
|
||||
>>> spec = api.generate_spectrogram(audio)
|
||||
>>> detections, features = api.process_spectrogram(spec)
|
||||
|
||||
Here `detections` is the list of detected calls, `features` is the list of
|
||||
CNN features for each detection, and `spec` is the spectrogram of the
|
||||
processed audio. Each detection is a dictionary similary to the
|
||||
following:
|
||||
|
||||
{
|
||||
'start_time': 0.0,
|
||||
'end_time': 0.1,
|
||||
'low_freq': 10000,
|
||||
'high_freq': 20000,
|
||||
'class': 'Myotis myotis',
|
||||
'class_prob': 0.9,
|
||||
'det_prob': 0.9,
|
||||
'individual': 0,
|
||||
'event': 'Echolocation'
|
||||
}
|
||||
|
||||
If you wish to interact directly with the model, you can use the `model`
|
||||
attribute to get the default model.
|
||||
|
||||
>>> import bat_detect.api as api
|
||||
>>> # Get the default model
|
||||
>>> model = api.model
|
||||
>>> # Process the spectrogram
|
||||
>>> outputs = model(spec)
|
||||
|
||||
However, you will need to do the postprocessing yourself. The
|
||||
model outputs are a collection of raw tensors. The `postprocess`
|
||||
function can be used to convert the model outputs into a list of
|
||||
detections and a list of CNN features.
|
||||
|
||||
>>> import bat_detect.api as api
|
||||
>>> # Get the default model
|
||||
>>> model = api.model
|
||||
>>> # Process the spectrogram
|
||||
>>> outputs = model(spec)
|
||||
>>> # Postprocess the outputs
|
||||
>>> detections, features = api.postprocess(outputs)
|
||||
|
||||
If you wish to use a custom model or change the default parameters, please
|
||||
consult the API documentation in the code.
|
||||
|
||||
"""
|
||||
import warnings
from typing import List, Optional, Tuple

import numpy as np
import torch

import bat_detect.utils.audio_utils as au
import bat_detect.utils.detector_utils as du
from bat_detect.detector.parameters import (
    DEFAULT_MODEL_PATH,
    DEFAULT_PROCESSING_CONFIGURATIONS,
    DEFAULT_SPECTROGRAM_PARAMETERS,
    TARGET_SAMPLERATE_HZ,
)
from bat_detect.types import (
    Annotation,
    DetectionModel,
    ModelOutput,
    ProcessingConfiguration,
    SpectrogramParameters,
)
from bat_detect.utils.detector_utils import list_audio_files, load_model

# Remove warnings from torch
warnings.filterwarnings("ignore", category=UserWarning, module="torch")

__all__ = [
    "config",
    "generate_spectrogram",
    "get_config",
    "list_audio_files",
    "load_audio",
    "load_model",
    "model",
    "postprocess",
    "process_audio",
    "process_file",
    "process_spectrogram",
]


# Use GPU if available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Default model
MODEL, PARAMS = load_model(DEFAULT_MODEL_PATH, device=DEVICE)


def get_config(**kwargs) -> ProcessingConfiguration:
    """Get default processing configuration.

    Can be used to override default parameters by passing keyword arguments.
    """
    return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs}  # type: ignore


# Default processing configuration
CONFIG = get_config(**PARAMS)
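For example, a minimal sketch of overriding one default while keeping everything else from `DEFAULT_PROCESSING_CONFIGURATIONS`; the `detection_threshold` key is the same one the CLI exposes:

```python
from bat_detect import api

# start from the defaults but lower the detection threshold
config = api.get_config(detection_threshold=0.2)
results = api.process_file("audio_file.wav", config=config)
```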

def load_audio(
    path: str,
    time_exp_fact: float = 1,
    target_samp_rate: int = TARGET_SAMPLERATE_HZ,
    scale: bool = False,
    max_duration: Optional[float] = None,
) -> np.ndarray:
    """Load audio from file.

    All audio will be resampled to the target sample rate. If the audio is
    longer than max_duration, it will be truncated to max_duration.

    Parameters
    ----------
    path : str
        Path to audio file.
    time_exp_fact : float, optional
        Time expansion factor, by default 1
    target_samp_rate : int, optional
        Target sample rate, by default 256000
    scale : bool, optional
        Scale audio to [-1, 1], by default False
    max_duration : float, optional
        Maximum duration of audio in seconds, by default None

    Returns
    -------
    np.ndarray
        Audio data.
    """
    _, audio = au.load_audio(
        path,
        time_exp_fact,
        target_samp_rate,
        scale,
        max_duration,
    )
    return audio
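As an illustration of the `time_exp_fact` parameter: a recording saved with 10x time expansion (a common bat-detector setting, where playback is slowed so the calls become audible) would be loaded like this; the file name is hypothetical:

```python
from bat_detect import api

# hypothetical 10x time-expanded file; the effective sample rate is
# rescaled by the expansion factor before resampling to the target rate
audio = api.load_audio("TE10_recording.wav", time_exp_fact=10)
```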

def generate_spectrogram(
    audio: np.ndarray,
    samp_rate: int = TARGET_SAMPLERATE_HZ,
    config: Optional[SpectrogramParameters] = None,
    device: torch.device = DEVICE,
) -> torch.Tensor:
    """Generate spectrogram from audio array.

    Parameters
    ----------
    audio : np.ndarray
        Audio data.
    samp_rate : int, optional
        Sample rate. Defaults to 256000 which is the target sample rate of
        the default model. Only change if you loaded the audio with a
        different sample rate.
    config : SpectrogramParameters, optional
        Spectrogram parameters, by default None (uses default parameters).

    Returns
    -------
    torch.Tensor
        Spectrogram.
    """
    if config is None:
        config = DEFAULT_SPECTROGRAM_PARAMETERS

    _, spec, _ = du.compute_spectrogram(
        audio,
        samp_rate,
        config,
        return_np=False,
        device=device,
    )

    return spec

def process_file(
    audio_file: str,
    model: DetectionModel = MODEL,
    config: Optional[ProcessingConfiguration] = None,
    device: torch.device = DEVICE,
) -> du.RunResults:
    """Process audio file with model.

    Parameters
    ----------
    audio_file : str
        Path to audio file.
    model : DetectionModel, optional
        Detection model. Uses default model if not specified.
    config : Optional[ProcessingConfiguration], optional
        Processing configuration, by default None (uses default parameters).
    device : torch.device, optional
        Device to use, by default tries to use GPU if available.
    """
    if config is None:
        config = CONFIG

    return du.process_file(
        audio_file,
        model,
        config,
        device,
    )


def process_spectrogram(
    spec: torch.Tensor,
    samp_rate: int = TARGET_SAMPLERATE_HZ,
    model: DetectionModel = MODEL,
    config: Optional[ProcessingConfiguration] = None,
) -> Tuple[List[Annotation], List[np.ndarray]]:
    """Process spectrogram with model.

    Parameters
    ----------
    spec : torch.Tensor
        Spectrogram.
    samp_rate : int, optional
        Sample rate of the audio from which the spectrogram was generated.
        Defaults to 256000 which is the target sample rate of the default
        model. Only change if you generated the spectrogram with a different
        sample rate.
    model : DetectionModel, optional
        Detection model. Uses default model if not specified.
    config : Optional[ProcessingConfiguration], optional
        Processing configuration, by default None (uses default parameters).

    Returns
    -------
    annotations : List[Annotation]
        List of predicted annotations.
    features : List[np.ndarray]
        List of extracted features for each annotation.
    """
    if config is None:
        config = CONFIG

    return du.process_spectrogram(
        spec,
        samp_rate,
        model,
        config,
    )

def process_audio(
    audio: np.ndarray,
    samp_rate: int = TARGET_SAMPLERATE_HZ,
    model: DetectionModel = MODEL,
    config: Optional[ProcessingConfiguration] = None,
    device: torch.device = DEVICE,
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
    """Process audio array with model.

    Parameters
    ----------
    audio : np.ndarray
        Audio data.
    samp_rate : int, optional
        Sample rate, by default 256000. Only change if you loaded the audio
        with a different sample rate.
    model : DetectionModel, optional
        Detection model. Uses default model if not specified.
    config : Optional[ProcessingConfiguration], optional
        Processing configuration, by default None (uses default parameters).
    device : torch.device, optional
        Device to use, by default tries to use GPU if available.

    Returns
    -------
    annotations : List[Annotation]
        List of predicted annotations.

    features : List[np.ndarray]
        List of extracted features for each annotation.

    spec : torch.Tensor
        Spectrogram of the audio used for prediction.
    """
    if config is None:
        config = CONFIG

    return du.process_audio_array(
        audio,
        samp_rate,
        model,
        config,
        device,
    )

def postprocess(
    outputs: ModelOutput,
    samp_rate: int = TARGET_SAMPLERATE_HZ,
    config: Optional[ProcessingConfiguration] = None,
) -> Tuple[List[Annotation], np.ndarray]:
    """Postprocess model outputs.

    Convert model tensor outputs to predicted bounding boxes and
    extracted features.

    Will run non-maximum suppression and remove overlapping annotations.

    Parameters
    ----------
    outputs : ModelOutput
        Model raw outputs.
    samp_rate : int, optional
        Sample rate of the audio from which the spectrogram was generated.
        Defaults to 256000 which is the target sample rate of the default
        model. Only change if you generated outputs from a spectrogram with
        a different sample rate.
    config : Optional[ProcessingConfiguration], optional
        Processing configuration, by default None (uses default parameters).

    Returns
    -------
    annotations : List[Annotation]
        List of predicted annotations.
    features : np.ndarray
        An array of extracted features for each annotation. The shape of the
        array is (n_annotations, n_features).
    """
    if config is None:
        config = CONFIG

    return du.postprocess_model_outputs(
        outputs,
        samp_rate,
        config,
    )


model: DetectionModel = MODEL
"""Base detection model."""

config: ProcessingConfiguration = CONFIG
"""Default processing configuration."""

bat_detect/cli.py (Normal file, 137 lines)
@@ -0,0 +1,137 @@
"""BatDetect2 command line interface."""
|
||||
import os
|
||||
|
||||
import click
|
||||
|
||||
from bat_detect import api
|
||||
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH
|
||||
from bat_detect.utils.detector_utils import save_results_to_file
|
||||
|
||||
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
INFO_STR = """
|
||||
BatDetect2 - Detection and Classification
|
||||
Assumes audio files are mono, not stereo.
|
||||
Spaces in the input paths will throw an error. Wrap in quotes.
|
||||
Input files should be short in duration e.g. < 30 seconds.
|
||||
"""
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
"""BatDetect2 - Bat Call Detection and Classification."""
|
||||
click.echo(INFO_STR)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument(
|
||||
"audio_dir",
|
||||
type=click.Path(exists=True),
|
||||
)
|
||||
@click.argument(
|
||||
"ann_dir",
|
||||
type=click.Path(exists=False),
|
||||
)
|
||||
@click.argument(
|
||||
"detection_threshold",
|
||||
type=float,
|
||||
)
|
||||
@click.option(
|
||||
"--cnn_features",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Extracts CNN call features",
|
||||
)
|
||||
@click.option(
|
||||
"--spec_features",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Extracts low level call features",
|
||||
)
|
||||
@click.option(
|
||||
"--time_expansion_factor",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The time expansion factor used for all files (default is 1)",
|
||||
)
|
||||
@click.option(
|
||||
"--quiet",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Minimize output printing",
|
||||
)
|
||||
@click.option(
|
||||
"--save_preds_if_empty",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Save empty annotation file if no detections made.",
|
||||
)
|
||||
@click.option(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default=DEFAULT_MODEL_PATH,
|
||||
help="Path to trained BatDetect2 model",
|
||||
)
|
||||
def detect(
|
||||
audio_dir: str,
|
||||
ann_dir: str,
|
||||
detection_threshold: float,
|
||||
**args,
|
||||
):
|
||||
"""Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.
|
||||
|
||||
DETECTION_THRESHOLD is the detection threshold. All predictions with a
|
||||
score below this threshold will be discarded. Values between 0 and 1.
|
||||
|
||||
Assumes audio files are mono, not stereo.
|
||||
|
||||
Spaces in the input paths will throw an error. Wrap in quotes.
|
||||
|
||||
Input files should be short in duration e.g. < 30 seconds.
|
||||
"""
|
||||
click.echo(f"Loading model: {args['model_path']}")
|
||||
model, params = api.load_model(args["model_path"])
|
||||
|
||||
click.echo(f"\nInput directory: {audio_dir}")
|
||||
files = api.list_audio_files(audio_dir)
|
||||
|
||||
click.echo(f"Number of audio files: {len(files)}")
|
||||
click.echo(f"\nSaving results to: {ann_dir}")
|
||||
|
||||
config = api.get_config(
|
||||
**{
|
||||
**params,
|
||||
**args,
|
||||
"spec_slices": False,
|
||||
"chunk_size": 2,
|
||||
"detection_threshold": detection_threshold,
|
||||
}
|
||||
)
|
||||
|
||||
# process files
|
||||
error_files = []
|
||||
for audio_file in files:
|
||||
try:
|
||||
results = api.process_file(audio_file, model, config=config)
|
||||
|
||||
if args["save_preds_if_empty"] or (
|
||||
len(results["pred_dict"]["annotation"]) > 0
|
||||
):
|
||||
results_path = audio_file.replace(audio_dir, ann_dir)
|
||||
save_results_to_file(results, results_path)
|
||||
except (RuntimeError, ValueError, LookupError) as err:
|
||||
error_files.append(audio_file)
|
||||
click.secho(f"Error processing file!: {err}", fg="red")
|
||||
raise err
|
||||
|
||||
click.echo(f"\nResults saved to: {ann_dir}")
|
||||
|
||||
if len(error_files) > 0:
|
||||
click.secho("\nUnable to process the follow files:", fg="red")
|
||||
for err in error_files:
|
||||
click.echo(f" {err}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
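The same `detect` command can also be driven from Python, which is handy for testing; a minimal sketch using click's built-in test runner:

```python
from click.testing import CliRunner

from bat_detect.cli import cli

# equivalent to: batdetect2 detect example_data/audio/ example_data/anns/ 0.3
runner = CliRunner()
result = runner.invoke(
    cli, ["detect", "example_data/audio/", "example_data/anns/", "0.3"]
)
print(result.output)
```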

@@ -2,8 +2,10 @@ import numpy as np

def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq):
    spec_ind = spec_height-spec_ind
    return round((spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2)
    spec_ind = spec_height - spec_ind
    return round(
        (spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
    )
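To make the index-to-frequency mapping concrete, a quick worked example with illustrative values (row 0 is the top of the spectrogram, so the index is flipped before scaling):

```python
# 128 spectrogram rows spanning 10 kHz to 120 kHz
spec_ind, spec_height = 32, 128
min_freq, max_freq = 10_000, 120_000

flipped = spec_height - spec_ind             # 96 rows up from the bottom
freq = (flipped / float(spec_height)) * (max_freq - min_freq) + min_freq
print(round(freq, 2))                        # 92500.0
```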

def extract_spec_slices(spec, pred_nms, params):
@@ -11,28 +13,40 @@
    Extracts spectrogram slices from spectrogram based on detected call locations.
    """

    x_pos = pred_nms['x_pos']
    y_pos = pred_nms['y_pos']
    bb_width = pred_nms['bb_width']
    bb_height = pred_nms['bb_height']
    slices = []
    x_pos = pred_nms["x_pos"]
    y_pos = pred_nms["y_pos"]
    bb_width = pred_nms["bb_width"]
    bb_height = pred_nms["bb_height"]
    slices = []

    # add 20% padding either side of call
    pad = bb_width*0.2
    x_pos_pad = x_pos - pad
    bb_width_pad = bb_width + 2*pad
    pad = bb_width * 0.2
    x_pos_pad = x_pos - pad
    bb_width_pad = bb_width + 2 * pad

    for ff in range(len(pred_nms['det_probs'])):
    for ff in range(len(pred_nms["det_probs"])):
        x_start = int(np.maximum(0, x_pos_pad[ff]))
        x_end = int(np.minimum(spec.shape[1]-1, np.round(x_pos_pad[ff] + bb_width_pad[ff])))
        x_end = int(
            np.minimum(
                spec.shape[1] - 1, np.round(x_pos_pad[ff] + bb_width_pad[ff])
            )
        )
        slices.append(spec[:, x_start:x_end].astype(np.float16))
    return slices

def get_feature_names():
    feature_names = ['duration', 'low_freq_bb', 'high_freq_bb', 'bandwidth',
                     'max_power_bb', 'max_power', 'max_power_first',
                     'max_power_second', 'call_interval']
    feature_names = [
        "duration",
        "low_freq_bb",
        "high_freq_bb",
        "bandwidth",
        "max_power_bb",
        "max_power",
        "max_power_first",
        "max_power_second",
        "call_interval",
    ]
    return feature_names

@@ -45,40 +59,76 @@ def get_feats(spec, pred_nms, params):
    https://github.com/YvesBas/Tadarida-D/blob/master/Manual_Tadarida-D.odt
    """

    x_pos = pred_nms['x_pos']
    y_pos = pred_nms['y_pos']
    bb_width = pred_nms['bb_width']
    bb_height = pred_nms['bb_height']
    x_pos = pred_nms["x_pos"]
    y_pos = pred_nms["y_pos"]
    bb_width = pred_nms["bb_width"]
    bb_height = pred_nms["bb_height"]

    feature_names = get_feature_names()
    num_detections = len(pred_nms['det_probs'])
    features = np.ones((num_detections, len(feature_names)), dtype=np.float32)*-1
    feature_names = get_feature_names()
    num_detections = len(pred_nms["det_probs"])
    features = (
        np.ones((num_detections, len(feature_names)), dtype=np.float32) * -1
    )

    for ff in range(num_detections):
        x_start = int(np.maximum(0, x_pos[ff]))
        x_end = int(np.minimum(spec.shape[1]-1, np.round(x_pos[ff] + bb_width[ff])))
        x_end = int(
            np.minimum(spec.shape[1] - 1, np.round(x_pos[ff] + bb_width[ff]))
        )
        # y low is the lowest freq but it will have a higher value due to array starting at 0 at top
        y_low = int(np.minimum(spec.shape[0]-1, y_pos[ff]))
        y_high = int(np.maximum(0, np.round(y_pos[ff] - bb_height[ff])))
        y_low = int(np.minimum(spec.shape[0] - 1, y_pos[ff]))
        y_high = int(np.maximum(0, np.round(y_pos[ff] - bb_height[ff])))
        spec_slice = spec[:, x_start:x_end]

        if spec_slice.shape[1] > 1:
            features[ff, 0] = round(pred_nms['end_times'][ff] - pred_nms['start_times'][ff], 5)
            features[ff, 1] = int(pred_nms['low_freqs'][ff])
            features[ff, 2] = int(pred_nms['high_freqs'][ff])
            features[ff, 3] = int(pred_nms['high_freqs'][ff] - pred_nms['low_freqs'][ff])
            features[ff, 4] = int(convert_int_to_freq(y_high+spec_slice[y_high:y_low, :].sum(1).argmax(),
                                                      spec.shape[0], params['min_freq'], params['max_freq']))
            features[ff, 5] = int(convert_int_to_freq(spec_slice.sum(1).argmax(),
                                                      spec.shape[0], params['min_freq'], params['max_freq']))
            hlf_val = spec_slice.shape[1]//2
            features[ff, 0] = round(
                pred_nms["end_times"][ff] - pred_nms["start_times"][ff], 5
            )
            features[ff, 1] = int(pred_nms["low_freqs"][ff])
            features[ff, 2] = int(pred_nms["high_freqs"][ff])
            features[ff, 3] = int(
                pred_nms["high_freqs"][ff] - pred_nms["low_freqs"][ff]
            )
            features[ff, 4] = int(
                convert_int_to_freq(
                    y_high + spec_slice[y_high:y_low, :].sum(1).argmax(),
                    spec.shape[0],
                    params["min_freq"],
                    params["max_freq"],
                )
            )
            features[ff, 5] = int(
                convert_int_to_freq(
                    spec_slice.sum(1).argmax(),
                    spec.shape[0],
                    params["min_freq"],
                    params["max_freq"],
                )
            )
            hlf_val = spec_slice.shape[1] // 2

            features[ff, 6] = int(convert_int_to_freq(spec_slice[:, :hlf_val].sum(1).argmax(),
                                                      spec.shape[0], params['min_freq'], params['max_freq']))
            features[ff, 7] = int(convert_int_to_freq(spec_slice[:, hlf_val:].sum(1).argmax(),
                                                      spec.shape[0], params['min_freq'], params['max_freq']))
            features[ff, 6] = int(
                convert_int_to_freq(
                    spec_slice[:, :hlf_val].sum(1).argmax(),
                    spec.shape[0],
                    params["min_freq"],
                    params["max_freq"],
                )
            )
            features[ff, 7] = int(
                convert_int_to_freq(
                    spec_slice[:, hlf_val:].sum(1).argmax(),
                    spec.shape[0],
                    params["min_freq"],
                    params["max_freq"],
                )
            )

            if ff > 0:
                features[ff, 8] = round(pred_nms['start_times'][ff] - pred_nms['start_times'][ff-1], 5)
                features[ff, 8] = round(
                    pred_nms["start_times"][ff]
                    - pred_nms["start_times"][ff - 1],
                    5,
                )

    return features
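The rows of the returned array line up with the names from `get_feature_names()`, with -1 marking features that could not be computed. A hedged sketch of how downstream code might tabulate them (the `features` array here is a hypothetical stand-in, since the diff does not show a caller):

```python
import numpy as np
import pandas as pd

# column names as returned by get_feature_names()
names = [
    "duration", "low_freq_bb", "high_freq_bb", "bandwidth", "max_power_bb",
    "max_power", "max_power_first", "max_power_second", "call_interval",
]

# hypothetical stand-in for the (n_detections, n_features) array from get_feats
features = np.full((3, len(names)), -1, dtype=np.float32)
df = pd.DataFrame(features, columns=names)
print(df[["duration", "bandwidth", "call_interval"]])
```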

@@ -1,8 +1,14 @@
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
import math
from torch import nn

__all__ = [
    "SelfAttention",
    "ConvBlockDownCoordF",
    "ConvBlockDownStandard",
    "ConvBlockUpF",
    "ConvBlockUpStandard",
]


class SelfAttention(nn.Module):
@@ -10,38 +16,61 @@ class SelfAttention(nn.Module):
        super(SelfAttention, self).__init__()
        # Note, does not encode position information (absolute or relative)
        self.temperature = 1.0
        self.att_dim = att_dim
        self.att_dim = att_dim
        self.key_fun = nn.Linear(ip_dim, att_dim)
        self.val_fun = nn.Linear(ip_dim, att_dim)
        self.que_fun = nn.Linear(ip_dim, att_dim)
        self.pro_fun = nn.Linear(att_dim, ip_dim)

    def forward(self, x):
        x = x.squeeze(2).permute(0,2,1)
        x = x.squeeze(2).permute(0, 2, 1)

        kk = torch.matmul(x, self.key_fun.weight.T) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
        qq = torch.matmul(x, self.que_fun.weight.T) + self.que_fun.bias.unsqueeze(0).unsqueeze(0)
        vv = torch.matmul(x, self.val_fun.weight.T) + self.val_fun.bias.unsqueeze(0).unsqueeze(0)
        kk = torch.matmul(
            x, self.key_fun.weight.T
        ) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
        qq = torch.matmul(
            x, self.que_fun.weight.T
        ) + self.que_fun.bias.unsqueeze(0).unsqueeze(0)
        vv = torch.matmul(
            x, self.val_fun.weight.T
        ) + self.val_fun.bias.unsqueeze(0).unsqueeze(0)

        kk_qq = torch.bmm(kk, qq.permute(0,2,1)) / (self.temperature*self.att_dim)
        att_weights = F.softmax(kk_qq, 1)  # each col of each attention matrix sums to 1
        att = torch.bmm(vv.permute(0,2,1), att_weights)
        kk_qq = torch.bmm(kk, qq.permute(0, 2, 1)) / (
            self.temperature * self.att_dim
        )
        att_weights = F.softmax(
            kk_qq, 1
        )  # each col of each attention matrix sums to 1
        att = torch.bmm(vv.permute(0, 2, 1), att_weights)

        op = torch.matmul(att.permute(0,2,1), self.pro_fun.weight.T) + self.pro_fun.bias.unsqueeze(0).unsqueeze(0)
        op = op.permute(0,2,1).unsqueeze(2)
        op = torch.matmul(
            att.permute(0, 2, 1), self.pro_fun.weight.T
        ) + self.pro_fun.bias.unsqueeze(0).unsqueeze(0)
        op = op.permute(0, 2, 1).unsqueeze(2)

        return op
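One detail worth noting: the explicit `torch.matmul(x, lin.weight.T) + lin.bias` pattern above is just an unrolled `nn.Linear` call. A quick self-contained check of that equivalence:

```python
import torch
from torch import nn

lin = nn.Linear(8, 4)
x = torch.randn(2, 16, 8)  # (batch, time, features)

# the two forms used in SelfAttention.forward are equivalent
a = torch.matmul(x, lin.weight.T) + lin.bias.unsqueeze(0).unsqueeze(0)
b = lin(x)
print(torch.allclose(a, b, atol=1e-6))  # True
```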

class ConvBlockDownCoordF(nn.Module):
    def __init__(self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, stride=1):
    def __init__(
        self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, stride=1
    ):
        super(ConvBlockDownCoordF, self).__init__()
        self.coords = nn.Parameter(torch.linspace(-1, 1, ip_height)[None, None, ..., None], requires_grad=False)
        self.conv = nn.Conv2d(in_chn+1, out_chn, kernel_size=k_size, padding=pad_size, stride=stride)
        self.coords = nn.Parameter(
            torch.linspace(-1, 1, ip_height)[None, None, ..., None],
            requires_grad=False,
        )
        self.conv = nn.Conv2d(
            in_chn + 1,
            out_chn,
            kernel_size=k_size,
            padding=pad_size,
            stride=stride,
        )
        self.conv_bn = nn.BatchNorm2d(out_chn)

    def forward(self, x):
        freq_info = self.coords.repeat(x.shape[0],1,1,x.shape[3])
        freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
        x = torch.cat((x, freq_info), 1)
        x = F.max_pool2d(self.conv(x), 2, 2)
        x = F.relu(self.conv_bn(x), inplace=True)
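The `coords` buffer gives this block a CoordConv-style frequency channel: a column of values running from -1 to 1 over the frequency axis, broadcast across batch and time and concatenated as one extra input channel (hence `in_chn + 1`). A tiny sketch of the mechanism in isolation:

```python
import torch

batch, chans, height, width = 2, 1, 8, 16
x = torch.randn(batch, chans, height, width)

# per-row frequency coordinate in [-1, 1], broadcast across batch and time
coords = torch.linspace(-1, 1, height)[None, None, :, None]
freq_info = coords.repeat(batch, 1, 1, width)
x = torch.cat((x, freq_info), 1)  # adds one channel
print(x.shape)  # torch.Size([2, 2, 8, 16])
```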
@@ -49,9 +78,17 @@ class ConvBlockDownCoordF(nn.Module):


class ConvBlockDownStandard(nn.Module):
    def __init__(self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, stride=1):
    def __init__(
        self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, stride=1
    ):
        super(ConvBlockDownStandard, self).__init__()
        self.conv = nn.Conv2d(in_chn, out_chn, kernel_size=k_size, padding=pad_size, stride=stride)
        self.conv = nn.Conv2d(
            in_chn,
            out_chn,
            kernel_size=k_size,
            padding=pad_size,
            stride=stride,
        )
        self.conv_bn = nn.BatchNorm2d(out_chn)

    def forward(self, x):
@@ -61,17 +98,41 @@

class ConvBlockUpF(nn.Module):
    def __init__(self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, up_mode='bilinear', up_scale=(2,2)):
    def __init__(
        self,
        in_chn,
        out_chn,
        ip_height,
        k_size=3,
        pad_size=1,
        up_mode="bilinear",
        up_scale=(2, 2),
    ):
        super(ConvBlockUpF, self).__init__()
        self.up_scale = up_scale
        self.up_mode = up_mode
        self.coords = nn.Parameter(torch.linspace(-1, 1, ip_height*up_scale[0])[None, None, ..., None], requires_grad=False)
        self.conv = nn.Conv2d(in_chn+1, out_chn, kernel_size=k_size, padding=pad_size)
        self.coords = nn.Parameter(
            torch.linspace(-1, 1, ip_height * up_scale[0])[
                None, None, ..., None
            ],
            requires_grad=False,
        )
        self.conv = nn.Conv2d(
            in_chn + 1, out_chn, kernel_size=k_size, padding=pad_size
        )
        self.conv_bn = nn.BatchNorm2d(out_chn)

    def forward(self, x):
        op = F.interpolate(x, size=(x.shape[-2]*self.up_scale[0], x.shape[-1]*self.up_scale[1]), mode=self.up_mode, align_corners=False)
        freq_info = self.coords.repeat(op.shape[0],1,1,op.shape[3])
        op = F.interpolate(
            x,
            size=(
                x.shape[-2] * self.up_scale[0],
                x.shape[-1] * self.up_scale[1],
            ),
            mode=self.up_mode,
            align_corners=False,
        )
        freq_info = self.coords.repeat(op.shape[0], 1, 1, op.shape[3])
        op = torch.cat((op, freq_info), 1)
        op = self.conv(op)
        op = F.relu(self.conv_bn(op), inplace=True)
@@ -79,15 +140,34 @@


class ConvBlockUpStandard(nn.Module):
    def __init__(self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, up_mode='bilinear', up_scale=(2,2)):
    def __init__(
        self,
        in_chn,
        out_chn,
        ip_height=None,
        k_size=3,
        pad_size=1,
        up_mode="bilinear",
        up_scale=(2, 2),
    ):
        super(ConvBlockUpStandard, self).__init__()
        self.up_scale = up_scale
        self.up_mode = up_mode
        self.conv = nn.Conv2d(in_chn, out_chn, kernel_size=k_size, padding=pad_size)
        self.conv = nn.Conv2d(
            in_chn, out_chn, kernel_size=k_size, padding=pad_size
        )
        self.conv_bn = nn.BatchNorm2d(out_chn)

    def forward(self, x):
        op = F.interpolate(x, size=(x.shape[-2]*self.up_scale[0], x.shape[-1]*self.up_scale[1]), mode=self.up_mode, align_corners=False)
        op = F.interpolate(
            x,
            size=(
                x.shape[-2] * self.up_scale[0],
                x.shape[-1] * self.up_scale[1],
            ),
            mode=self.up_mode,
            align_corners=False,
        )
        op = self.conv(op)
        op = F.relu(self.conv_bn(op), inplace=True)
        return op
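Both up-blocks double the spatial dimensions with `F.interpolate` before convolving; a quick shape check of that step in isolation:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 8, 16)
# up_scale=(2, 2): double both the frequency and time axes
op = F.interpolate(x, size=(8 * 2, 16 * 2), mode="bilinear", align_corners=False)
print(op.shape)  # torch.Size([1, 4, 16, 32])
```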

@@ -1,54 +1,109 @@
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
from .model_helpers import *

import torchvision

import torch.fft
import torch.nn.functional as F
from torch import nn

from bat_detect.detector.model_helpers import (
    ConvBlockDownCoordF,
    ConvBlockDownStandard,
    ConvBlockUpF,
    ConvBlockUpStandard,
    SelfAttention,
)
from bat_detect.types import ModelOutput

__all__ = [
    "Net2DFast",
    "Net2DFastNoAttn",
    "Net2DFastNoCoordConv",
]


class Net2DFast(nn.Module):
    def __init__(self, num_filts, num_classes=0, emb_dim=0, ip_height=128, resize_factor=0.5):
        super(Net2DFast, self).__init__()
    def __init__(
        self,
        num_filts,
        num_classes=0,
        emb_dim=0,
        ip_height=128,
        resize_factor=0.5,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.emb_dim = emb_dim
        self.num_filts = num_filts
        self.resize_factor = resize_factor
        self.ip_height_rs = ip_height
        self.bneck_height = self.ip_height_rs//32
        self.bneck_height = self.ip_height_rs // 32

        # encoder
        self.conv_dn_0 = ConvBlockDownCoordF(1, num_filts//4, self.ip_height_rs, k_size=3, pad_size=1, stride=1)
        self.conv_dn_1 = ConvBlockDownCoordF(num_filts//4, num_filts//2, self.ip_height_rs//2, k_size=3, pad_size=1, stride=1)
        self.conv_dn_2 = ConvBlockDownCoordF(num_filts//2, num_filts, self.ip_height_rs//4, k_size=3, pad_size=1, stride=1)
        self.conv_dn_3 = nn.Conv2d(num_filts, num_filts*2, 3, padding=1)
        self.conv_dn_3_bn = nn.BatchNorm2d(num_filts*2)
        self.conv_dn_0 = ConvBlockDownCoordF(
            1,
            num_filts // 4,
            self.ip_height_rs,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownCoordF(
            num_filts // 4,
            num_filts // 2,
            self.ip_height_rs // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownCoordF(
            num_filts // 2,
            num_filts,
            self.ip_height_rs // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1)
        self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2)

        # bottleneck
        self.conv_1d = nn.Conv2d(num_filts*2, num_filts*2, (self.ip_height_rs//8,1), padding=0)
        self.conv_1d_bn = nn.BatchNorm2d(num_filts*2)
        self.att = SelfAttention(num_filts*2, num_filts*2)
        self.conv_1d = nn.Conv2d(
            num_filts * 2,
            num_filts * 2,
            (self.ip_height_rs // 8, 1),
            padding=0,
        )
        self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2)
        self.att = SelfAttention(num_filts * 2, num_filts * 2)

        # decoder
        self.conv_up_2 = ConvBlockUpF(num_filts*2, num_filts//2, self.ip_height_rs//8)
        self.conv_up_3 = ConvBlockUpF(num_filts//2, num_filts//4, self.ip_height_rs//4)
        self.conv_up_4 = ConvBlockUpF(num_filts//4, num_filts//4, self.ip_height_rs//2)
        self.conv_up_2 = ConvBlockUpF(
            num_filts * 2, num_filts // 2, self.ip_height_rs // 8
        )
        self.conv_up_3 = ConvBlockUpF(
            num_filts // 2, num_filts // 4, self.ip_height_rs // 4
        )
        self.conv_up_4 = ConvBlockUpF(
            num_filts // 4, num_filts // 4, self.ip_height_rs // 2
        )

        # output
        # +1 to include background class for class output
        self.conv_op = nn.Conv2d(num_filts//4, num_filts//4, kernel_size=3, padding=1)
        self.conv_op_bn = nn.BatchNorm2d(num_filts//4)
        self.conv_size_op = nn.Conv2d(num_filts//4, 2, kernel_size=1, padding=0)
        self.conv_classes_op = nn.Conv2d(num_filts//4, self.num_classes+1, kernel_size=1, padding=0)
        self.conv_op = nn.Conv2d(
            num_filts // 4, num_filts // 4, kernel_size=3, padding=1
        )
        self.conv_op_bn = nn.BatchNorm2d(num_filts // 4)
        self.conv_size_op = nn.Conv2d(
            num_filts // 4, 2, kernel_size=1, padding=0
        )
        self.conv_classes_op = nn.Conv2d(
            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
        )

        if self.emb_dim > 0:
            self.conv_emb = nn.Conv2d(num_filts, self.emb_dim, kernel_size=1, padding=0)
            self.conv_emb = nn.Conv2d(
                num_filts, self.emb_dim, kernel_size=1, padding=0
            )

    def forward(self, ip, return_feats=False):
    def forward(self, ip, return_feats=False) -> ModelOutput:

        # encoder
        x1 = self.conv_dn_0(ip)
@@ -59,134 +114,218 @@ class Net2DFast(nn.Module):
        # bottleneck
        x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
        x = self.att(x)
        x = x.repeat([1,1,self.bneck_height*4,1])
        x = x.repeat([1, 1, self.bneck_height * 4, 1])

        # decoder
        x = self.conv_up_2(x+x3)
        x = self.conv_up_3(x+x2)
        x = self.conv_up_4(x+x1)
        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        # output
        x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
        cls = self.conv_classes_op(x)
        cls = self.conv_classes_op(x)
        comb = torch.softmax(cls, 1)

        op = {}
        op['pred_det'] = comb[:,:-1, :, :].sum(1).unsqueeze(1)
        op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True)
        op['pred_class'] = comb
        op['pred_class_un_norm'] = cls
        if self.emb_dim > 0:
            op['pred_emb'] = self.conv_emb(x)
        if return_feats:
            op['features'] = x

        return op
        return ModelOutput(
            pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
            pred_size=F.relu(self.conv_size_op(x), inplace=True),
            pred_class=comb,
            pred_class_un_norm=cls,
            features=x,
        )
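The detection head here is derived from the class softmax: with the background class stored last, the per-pixel detection probability is the total mass on the non-background classes. A small sketch of that step on a dummy tensor:

```python
import torch

# 4 call classes + 1 background class over a (freq, time) grid
comb = torch.softmax(torch.randn(1, 5, 32, 64), dim=1)
pred_det = comb[:, :-1, :, :].sum(1).unsqueeze(1)
print(pred_det.shape)                 # torch.Size([1, 1, 32, 64])
print(float(pred_det.max()) <= 1.0)   # True: equals 1 - background prob
```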


class Net2DFastNoAttn(nn.Module):
    def __init__(self, num_filts, num_classes=0, emb_dim=0, ip_height=128, resize_factor=0.5):
        super(Net2DFastNoAttn, self).__init__()
    def __init__(
        self,
        num_filts,
        num_classes=0,
        emb_dim=0,
        ip_height=128,
        resize_factor=0.5,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.emb_dim = emb_dim
        self.num_filts = num_filts
        self.resize_factor = resize_factor
        self.ip_height_rs = ip_height
        self.bneck_height = self.ip_height_rs//32
        self.bneck_height = self.ip_height_rs // 32

        self.conv_dn_0 = ConvBlockDownCoordF(1, num_filts//4, self.ip_height_rs, k_size=3, pad_size=1, stride=1)
        self.conv_dn_1 = ConvBlockDownCoordF(num_filts//4, num_filts//2, self.ip_height_rs//2, k_size=3, pad_size=1, stride=1)
        self.conv_dn_2 = ConvBlockDownCoordF(num_filts//2, num_filts, self.ip_height_rs//4, k_size=3, pad_size=1, stride=1)
        self.conv_dn_3 = nn.Conv2d(num_filts, num_filts*2, 3, padding=1)
        self.conv_dn_3_bn = nn.BatchNorm2d(num_filts*2)
        self.conv_dn_0 = ConvBlockDownCoordF(
            1,
            num_filts // 4,
            self.ip_height_rs,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownCoordF(
            num_filts // 4,
            num_filts // 2,
            self.ip_height_rs // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownCoordF(
            num_filts // 2,
            num_filts,
            self.ip_height_rs // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1)
        self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2)

        self.conv_1d = nn.Conv2d(num_filts*2, num_filts*2, (self.ip_height_rs//8,1), padding=0)
        self.conv_1d_bn = nn.BatchNorm2d(num_filts*2)
        self.conv_1d = nn.Conv2d(
            num_filts * 2,
            num_filts * 2,
            (self.ip_height_rs // 8, 1),
            padding=0,
        )
        self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2)

        self.conv_up_2 = ConvBlockUpF(num_filts*2, num_filts//2, self.ip_height_rs//8)
        self.conv_up_3 = ConvBlockUpF(num_filts//2, num_filts//4, self.ip_height_rs//4)
        self.conv_up_4 = ConvBlockUpF(num_filts//4, num_filts//4, self.ip_height_rs//2)
        self.conv_up_2 = ConvBlockUpF(
            num_filts * 2, num_filts // 2, self.ip_height_rs // 8
        )
        self.conv_up_3 = ConvBlockUpF(
            num_filts // 2, num_filts // 4, self.ip_height_rs // 4
        )
        self.conv_up_4 = ConvBlockUpF(
            num_filts // 4, num_filts // 4, self.ip_height_rs // 2
        )

        # output
        # +1 to include background class for class output
        self.conv_op = nn.Conv2d(num_filts//4, num_filts//4, kernel_size=3, padding=1)
        self.conv_op_bn = nn.BatchNorm2d(num_filts//4)
        self.conv_size_op = nn.Conv2d(num_filts//4, 2, kernel_size=1, padding=0)
        self.conv_classes_op = nn.Conv2d(num_filts//4, self.num_classes+1, kernel_size=1, padding=0)
        self.conv_op = nn.Conv2d(
            num_filts // 4, num_filts // 4, kernel_size=3, padding=1
        )
        self.conv_op_bn = nn.BatchNorm2d(num_filts // 4)
        self.conv_size_op = nn.Conv2d(
            num_filts // 4, 2, kernel_size=1, padding=0
        )
        self.conv_classes_op = nn.Conv2d(
            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
        )

        if self.emb_dim > 0:
            self.conv_emb = nn.Conv2d(num_filts, self.emb_dim, kernel_size=1, padding=0)

    def forward(self, ip, return_feats=False):
            self.conv_emb = nn.Conv2d(
                num_filts, self.emb_dim, kernel_size=1, padding=0
            )

    def forward(self, ip, return_feats=False) -> ModelOutput:
        x1 = self.conv_dn_0(ip)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)

        x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
        x = x.repeat([1,1,self.bneck_height*4,1])
        x = x.repeat([1, 1, self.bneck_height * 4, 1])

        x = self.conv_up_2(x+x3)
        x = self.conv_up_3(x+x2)
        x = self.conv_up_4(x+x1)
        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
        cls = self.conv_classes_op(x)
        comb = torch.softmax(cls, 1)

        op = {}
        op['pred_det'] = comb[:,:-1, :, :].sum(1).unsqueeze(1)
        op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True)
        op['pred_class'] = comb
        op['pred_class_un_norm'] = cls
        if self.emb_dim > 0:
            op['pred_emb'] = self.conv_emb(x)
        if return_feats:
            op['features'] = x

        return op
        return ModelOutput(
            pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
            pred_size=F.relu(self.conv_size_op(x), inplace=True),
            pred_class=comb,
            pred_class_un_norm=cls,
            features=x,
        )


class Net2DFastNoCoordConv(nn.Module):
    def __init__(self, num_filts, num_classes=0, emb_dim=0, ip_height=128, resize_factor=0.5):
        super(Net2DFastNoCoordConv, self).__init__()
    def __init__(
        self,
        num_filts,
        num_classes=0,
        emb_dim=0,
        ip_height=128,
        resize_factor=0.5,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.emb_dim = emb_dim
        self.num_filts = num_filts
        self.resize_factor = resize_factor
        self.ip_height_rs = ip_height
        self.bneck_height = self.ip_height_rs//32
        self.bneck_height = self.ip_height_rs // 32

        self.conv_dn_0 = ConvBlockDownStandard(1, num_filts//4, self.ip_height_rs, k_size=3, pad_size=1, stride=1)
        self.conv_dn_1 = ConvBlockDownStandard(num_filts//4, num_filts//2, self.ip_height_rs//2, k_size=3, pad_size=1, stride=1)
        self.conv_dn_2 = ConvBlockDownStandard(num_filts//2, num_filts, self.ip_height_rs//4, k_size=3, pad_size=1, stride=1)
        self.conv_dn_3 = nn.Conv2d(num_filts, num_filts*2, 3, padding=1)
        self.conv_dn_3_bn = nn.BatchNorm2d(num_filts*2)
        self.conv_dn_0 = ConvBlockDownStandard(
            1,
            num_filts // 4,
            self.ip_height_rs,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownStandard(
            num_filts // 4,
            num_filts // 2,
            self.ip_height_rs // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownStandard(
            num_filts // 2,
            num_filts,
            self.ip_height_rs // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1)
        self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2)

        self.conv_1d = nn.Conv2d(num_filts*2, num_filts*2, (self.ip_height_rs//8,1), padding=0)
        self.conv_1d_bn = nn.BatchNorm2d(num_filts*2)
        self.conv_1d = nn.Conv2d(
            num_filts * 2,
            num_filts * 2,
            (self.ip_height_rs // 8, 1),
            padding=0,
        )
        self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2)

        self.att = SelfAttention(num_filts*2, num_filts*2)
        self.att = SelfAttention(num_filts * 2, num_filts * 2)

        self.conv_up_2 = ConvBlockUpStandard(num_filts*2, num_filts//2, self.ip_height_rs//8)
        self.conv_up_3 = ConvBlockUpStandard(num_filts//2, num_filts//4, self.ip_height_rs//4)
        self.conv_up_4 = ConvBlockUpStandard(num_filts//4, num_filts//4, self.ip_height_rs//2)
        self.conv_up_2 = ConvBlockUpStandard(
            num_filts * 2, num_filts // 2, self.ip_height_rs // 8
        )
        self.conv_up_3 = ConvBlockUpStandard(
            num_filts // 2, num_filts // 4, self.ip_height_rs // 4
        )
        self.conv_up_4 = ConvBlockUpStandard(
            num_filts // 4, num_filts // 4, self.ip_height_rs // 2
        )

        # output
        # +1 to include background class for class output
        self.conv_op = nn.Conv2d(num_filts//4, num_filts//4, kernel_size=3, padding=1)
        self.conv_op_bn = nn.BatchNorm2d(num_filts//4)
        self.conv_size_op = nn.Conv2d(num_filts//4, 2, kernel_size=1, padding=0)
        self.conv_classes_op = nn.Conv2d(num_filts//4, self.num_classes+1, kernel_size=1, padding=0)
        self.conv_op = nn.Conv2d(
            num_filts // 4, num_filts // 4, kernel_size=3, padding=1
        )
        self.conv_op_bn = nn.BatchNorm2d(num_filts // 4)
        self.conv_size_op = nn.Conv2d(
            num_filts // 4, 2, kernel_size=1, padding=0
        )
        self.conv_classes_op = nn.Conv2d(
            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
        )

        if self.emb_dim > 0:
            self.conv_emb = nn.Conv2d(num_filts, self.emb_dim, kernel_size=1, padding=0)
            self.conv_emb = nn.Conv2d(
                num_filts, self.emb_dim, kernel_size=1, padding=0
            )

    def forward(self, ip, return_feats=False):
    def forward(self, ip, return_feats=False) -> ModelOutput:

        x1 = self.conv_dn_0(ip)
        x2 = self.conv_dn_1(x1)
@ -195,24 +334,21 @@ class Net2DFastNoCoordConv(nn.Module):

        x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
        x = self.att(x)
        x = x.repeat([1,1,self.bneck_height*4,1])
        x = x.repeat([1, 1, self.bneck_height * 4, 1])

        x = self.conv_up_2(x+x3)
        x = self.conv_up_3(x+x2)
        x = self.conv_up_4(x+x1)
        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
        cls = self.conv_classes_op(x)
        comb = torch.softmax(cls, 1)

        op = {}
        op['pred_det'] = comb[:,:-1, :, :].sum(1).unsqueeze(1)
        op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True)
        op['pred_class'] = comb
        op['pred_class_un_norm'] = cls
        if self.emb_dim > 0:
            op['pred_emb'] = self.conv_emb(x)
        if return_feats:
            op['features'] = x

        return op
        return ModelOutput(
            pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
            pred_size=F.relu(self.conv_size_op(x), inplace=True),
            pred_class=comb,
            pred_class_un_norm=cls,
            pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
            features=x,
        )

@ -1,108 +1,235 @@
import numpy as np
import os
import datetime
import os

from bat_detect.types import (
    ProcessingConfiguration,
    SpectrogramParameters,
)

TARGET_SAMPLERATE_HZ = 256000
FFT_WIN_LENGTH_S = 512 / 256000.0
FFT_OVERLAP = 0.75
MAX_FREQ_HZ = 120000
MIN_FREQ_HZ = 10000
RESIZE_FACTOR = 0.5
SPEC_DIVIDE_FACTOR = 32
SPEC_HEIGHT = 256
SCALE_RAW_AUDIO = False
DETECTION_THRESHOLD = 0.01
NMS_KERNEL_SIZE = 9
NMS_TOP_K_PER_SEC = 200
SPEC_SCALE = "pcen"
DENOISE_SPEC_AVG = True
MAX_SCALE_SPEC = False


DEFAULT_MODEL_PATH = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "models",
    "Net2DFast_UK_same.pth.tar",
)


DEFAULT_SPECTROGRAM_PARAMETERS: SpectrogramParameters = {
    "fft_win_length": FFT_WIN_LENGTH_S,
    "fft_overlap": FFT_OVERLAP,
    "spec_height": SPEC_HEIGHT,
    "resize_factor": RESIZE_FACTOR,
    "spec_divide_factor": SPEC_DIVIDE_FACTOR,
    "max_freq": MAX_FREQ_HZ,
    "min_freq": MIN_FREQ_HZ,
    "spec_scale": SPEC_SCALE,
    "denoise_spec_avg": DENOISE_SPEC_AVG,
    "max_scale_spec": MAX_SCALE_SPEC,
}


DEFAULT_PROCESSING_CONFIGURATIONS: ProcessingConfiguration = {
    "detection_threshold": DETECTION_THRESHOLD,
    "spec_slices": False,
    "chunk_size": 3,
    "spec_features": False,
    "cnn_features": False,
    "quiet": True,
    "target_samp_rate": TARGET_SAMPLERATE_HZ,
    "fft_win_length": FFT_WIN_LENGTH_S,
    "fft_overlap": FFT_OVERLAP,
    "resize_factor": RESIZE_FACTOR,
    "spec_divide_factor": SPEC_DIVIDE_FACTOR,
    "spec_height": SPEC_HEIGHT,
    "scale_raw_audio": SCALE_RAW_AUDIO,
    "class_names": [],
    "time_expansion": 1,
    "top_n": 3,
    "return_raw_preds": False,
    "max_duration": None,
    "nms_kernel_size": NMS_KERNEL_SIZE,
    "max_freq": MAX_FREQ_HZ,
    "min_freq": MIN_FREQ_HZ,
    "nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
    "spec_scale": SPEC_SCALE,
    "denoise_spec_avg": DENOISE_SPEC_AVG,
    "max_scale_spec": MAX_SCALE_SPEC,
}
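
The two dictionaries above are plain dicts, so callers can start from the defaults and override only the fields they need. A small sketch (the import path matches where this diff places the module; the override values are arbitrary):

from copy import deepcopy

from bat_detect.detector.parameters import DEFAULT_PROCESSING_CONFIGURATIONS

config = deepcopy(DEFAULT_PROCESSING_CONFIGURATIONS)
config["detection_threshold"] = 0.3   # trade some recall for precision
config["time_expansion"] = 10         # for time-expanded recordings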


def mk_dir(path):
    if not os.path.isdir(path):
        os.makedirs(path)
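
A side note on mk_dir: on Python 3 the same effect is available in a single call, which also skips the separate existence check:

os.makedirs(path, exist_ok=True)  # no error if the directory already exists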


def get_params(make_dirs=False, exps_dir='../../experiments/'):

def get_params(make_dirs=False, exps_dir="../../experiments/"):
    params = {}

    params['model_name'] = 'Net2DFast'  # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN
    params['num_filters'] = 128
    params[
        "model_name"
    ] = "Net2DFast"  # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN
    params["num_filters"] = 128

    now_str = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
    model_name = now_str + '.pth.tar'
    params['experiment'] = os.path.join(exps_dir, now_str, '')
    params['model_file_name'] = os.path.join(params['experiment'], model_name)
    params['op_im_dir'] = os.path.join(params['experiment'], 'op_ims', '')
    params['op_im_dir_test'] = os.path.join(params['experiment'], 'op_ims_test', '')
    #params['notes'] = ''  # can save notes about an experiment here

    model_name = now_str + ".pth.tar"
    params["experiment"] = os.path.join(exps_dir, now_str, "")
    params["model_file_name"] = os.path.join(params["experiment"], model_name)
    params["op_im_dir"] = os.path.join(params["experiment"], "op_ims", "")
    params["op_im_dir_test"] = os.path.join(
        params["experiment"], "op_ims_test", ""
    )
    # params['notes'] = ''  # can save notes about an experiment here

    # spec parameters
    params['target_samp_rate'] = 256000  # resamples all audio so that it is at this rate
    params['fft_win_length'] = 512 / 256000.0  # in seconds, amount of time per stft time step
    params['fft_overlap'] = 0.75  # stft window overlap
    params[
        "target_samp_rate"
    ] = TARGET_SAMPLERATE_HZ  # resamples all audio so that it is at this rate
    params[
        "fft_win_length"
    ] = FFT_WIN_LENGTH_S  # in seconds, amount of time per stft time step
    params["fft_overlap"] = FFT_OVERLAP  # stft window overlap

    params['max_freq'] = 120000  # in Hz, everything above this will be discarded
    params['min_freq'] = 10000  # in Hz, everything below this will be discarded
    params[
        "max_freq"
    ] = MAX_FREQ_HZ  # in Hz, everything above this will be discarded
    params[
        "min_freq"
    ] = MIN_FREQ_HZ  # in Hz, everything below this will be discarded

    params['resize_factor'] = 0.5  # resize so the spectrogram at the input of the network
    params['spec_height'] = 256  # units are number of frequency bins (before resizing is performed)
    params['spec_train_width'] = 512  # units are number of time steps (before resizing is performed)
    params['spec_divide_factor'] = 32  # spectrogram should be divisible by this amount in width and height
    params[
        "resize_factor"
    ] = RESIZE_FACTOR  # resize so the spectrogram at the input of the network
    params[
        "spec_height"
    ] = SPEC_HEIGHT  # units are number of frequency bins (before resizing is performed)
    params[
        "spec_train_width"
    ] = 512  # units are number of time steps (before resizing is performed)
    params[
        "spec_divide_factor"
    ] = SPEC_DIVIDE_FACTOR  # spectrogram should be divisible by this amount in width and height

    # spec processing params
    params['denoise_spec_avg'] = True  # removes the mean for each frequency band
    params['scale_raw_audio'] = False  # scales the raw audio to [-1, 1]
    params['max_scale_spec'] = False  # scales the spectrogram so that it is max 1
    params['spec_scale'] = 'pcen'  # 'log', 'pcen', 'none'
    params[
        "denoise_spec_avg"
    ] = DENOISE_SPEC_AVG  # removes the mean for each frequency band
    params[
        "scale_raw_audio"
    ] = SCALE_RAW_AUDIO  # scales the raw audio to [-1, 1]
    params[
        "max_scale_spec"
    ] = MAX_SCALE_SPEC  # scales the spectrogram so that it is max 1
    params["spec_scale"] = SPEC_SCALE  # 'log', 'pcen', 'none'

    # detection params
    params['detection_overlap'] = 0.01  # has to be within this number of ms to count as detection
    params['ignore_start_end'] = 0.01  # if start of GT calls are within this time from the start/end of file ignore
    params['detection_threshold'] = 0.01  # the smaller this is the better the recall will be
    params['nms_kernel_size'] = 9
    params['nms_top_k_per_sec'] = 200  # keep top K highest predictions per second of audio
    params['target_sigma'] = 2.0
    params[
        "detection_overlap"
    ] = 0.01  # has to be within this number of ms to count as detection
    params[
        "ignore_start_end"
    ] = 0.01  # if start of GT calls are within this time from the start/end of file ignore
    params[
        "detection_threshold"
    ] = DETECTION_THRESHOLD  # the smaller this is the better the recall will be
    params[
        "nms_kernel_size"
    ] = NMS_KERNEL_SIZE  # size of the kernel for non-max suppression
    params[
        "nms_top_k_per_sec"
    ] = NMS_TOP_K_PER_SEC  # keep top K highest predictions per second of audio
    params["target_sigma"] = 2.0

    # augmentation params
    params['aug_prob'] = 0.20  # augmentations will be performed with this probability
    params['augment_at_train'] = True
    params['augment_at_train_combine'] = True
    params['echo_max_delay'] = 0.005  # simulate echo by adding copy of raw audio
    params['stretch_squeeze_delta'] = 0.04  # stretch or squeeze spec
    params['mask_max_time_perc'] = 0.05  # max mask size - here percentage, not ideal
    params['mask_max_freq_perc'] = 0.10  # max mask size - here percentage, not ideal
    params['spec_amp_scaling'] = 2.0  # multiply the "volume" by 0:X times current amount
    params['aug_sampling_rates'] = [220500, 256000, 300000, 312500, 384000, 441000, 500000]
    params[
        "aug_prob"
    ] = 0.20  # augmentations will be performed with this probability
    params["augment_at_train"] = True
    params["augment_at_train_combine"] = True
    params[
        "echo_max_delay"
    ] = 0.005  # simulate echo by adding copy of raw audio
    params["stretch_squeeze_delta"] = 0.04  # stretch or squeeze spec
    params[
        "mask_max_time_perc"
    ] = 0.05  # max mask size - here percentage, not ideal
    params[
        "mask_max_freq_perc"
    ] = 0.10  # max mask size - here percentage, not ideal
    params[
        "spec_amp_scaling"
    ] = 2.0  # multiply the "volume" by 0:X times current amount
    params["aug_sampling_rates"] = [
        220500,
        256000,
        300000,
        312500,
        384000,
        441000,
        500000,
    ]

    # loss params
    params['train_loss'] = 'focal'  # mse or focal
    params['det_loss_weight'] = 1.0  # weight for the detection part of the loss
    params['size_loss_weight'] = 0.1  # weight for the bbox size loss
    params['class_loss_weight'] = 2.0  # weight for the classification loss
    params['individual_loss_weight'] = 0.0  # not used
    if params['individual_loss_weight'] == 0.0:
        params['emb_dim'] = 0  # number of dimensions used for individual id embedding
    params["train_loss"] = "focal"  # mse or focal
    params["det_loss_weight"] = 1.0  # weight for the detection part of the loss
    params["size_loss_weight"] = 0.1  # weight for the bbox size loss
    params["class_loss_weight"] = 2.0  # weight for the classification loss
    params["individual_loss_weight"] = 0.0  # not used
    if params["individual_loss_weight"] == 0.0:
        params[
            "emb_dim"
        ] = 0  # number of dimensions used for individual id embedding
    else:
        params['emb_dim'] = 3
        params["emb_dim"] = 3

    # train params
    params['lr'] = 0.001
    params['batch_size'] = 8
    params['num_workers'] = 4
    params['num_epochs'] = 200
    params['num_eval_epochs'] = 5  # run evaluation every X epochs
    params['device'] = 'cuda'
    params['save_test_image_during_train'] = False
    params['save_test_image_after_train'] = True
    params["lr"] = 0.001
    params["batch_size"] = 8
    params["num_workers"] = 4
    params["num_epochs"] = 200
    params["num_eval_epochs"] = 5  # run evaluation every X epochs
    params["device"] = "cuda"
    params["save_test_image_during_train"] = False
    params["save_test_image_after_train"] = True

    params['convert_to_genus'] = False
    params['genus_mapping'] = []
    params['class_names'] = []
    params['classes_to_ignore'] = ['', ' ', 'Unknown', 'Not Bat']
    params['generic_class'] = ['Bat']
    params['events_of_interest'] = ['Echolocation']  # will ignore all other types of events e.g. social calls
    params["convert_to_genus"] = False
    params["genus_mapping"] = []
    params["class_names"] = []
    params["classes_to_ignore"] = ["", " ", "Unknown", "Not Bat"]
    params["generic_class"] = ["Bat"]
    params["events_of_interest"] = [
        "Echolocation"
    ]  # will ignore all other types of events e.g. social calls

    # the classes in this list are standardized during training so that the same low and high freq are used
    params['standardize_classs_names'] = []
    params["standardize_classs_names"] = []

    # create directories
    if make_dirs:
        print('Model name : ' + params['model_name'])
        print('Model file : ' + params['model_file_name'])
        print('Experiment : ' + params['experiment'])
        print("Model name : " + params["model_name"])
        print("Model file : " + params["model_file_name"])
        print("Experiment : " + params["experiment"])

        mk_dir(params['experiment'])
        if params['save_test_image_during_train']:
            mk_dir(params['op_im_dir'])
        if params['save_test_image_after_train']:
            mk_dir(params['op_im_dir_test'])
        mk_dir(os.path.dirname(params['model_file_name']))
        mk_dir(params["experiment"])
        if params["save_test_image_during_train"]:
            mk_dir(params["op_im_dir"])
        if params["save_test_image_after_train"]:
            mk_dir(params["op_im_dir_test"])
        mk_dir(os.path.dirname(params["model_file_name"]))

    return params

@ -1,88 +1,168 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
"""Post-processing of the output of the model."""
from typing import List, Tuple, Union

import numpy as np
np.seterr(divide='ignore', invalid='ignore')
import torch
from torch import nn

from bat_detect.detector.models import ModelOutput
from bat_detect.types import NonMaximumSuppressionConfig, PredictionResults

np.seterr(divide="ignore", invalid="ignore")


def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
    nfft = int(fft_win_length*sampling_rate)
    noverlap = int(fft_overlap*nfft)
    return ((x_pos*(nfft - noverlap)) + noverlap) / sampling_rate
    #return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5)  # 0.5 is for center of temporal window
def x_coords_to_time(
    x_pos: float,
    sampling_rate: int,
    fft_win_length: float,
    fft_overlap: float,
) -> float:
    """Convert x coordinates of spectrogram to time in seconds.

    Args:
        x_pos: X position of the detection in pixels.
        sampling_rate: Sampling rate of the audio in Hz.
        fft_win_length: Length of the FFT window in seconds.
        fft_overlap: Fraction of overlap between consecutive FFT windows.

    Returns:
        Time in seconds.
    """
    nfft = int(fft_win_length * sampling_rate)
    noverlap = int(fft_overlap * nfft)
    return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
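
As a concrete check of the conversion (using the UK model defaults defined in parameters.py, where fft_win_length is 512 / 256000.0 seconds and fft_overlap is 0.75):

# nfft = int(0.002 * 256000) = 512 samples, noverlap = int(0.75 * 512) = 384,
# so the effective hop between spectrogram columns is 512 - 384 = 128 samples
t = x_coords_to_time(100, 256000, 512 / 256000.0, 0.75)
print(t)  # (100 * 128 + 384) / 256000 = 0.0515 seconds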


def overall_class_pred(det_prob, class_prob):
    weighted_pred = (class_prob*det_prob).sum(1)
    weighted_pred = (class_prob * det_prob).sum(1)
    return weighted_pred / weighted_pred.sum()
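
overall_class_pred pools per-call class votes into a single file-level distribution, weighting each call by its detection confidence before normalising. A small numpy check with made-up probabilities (classes on rows, detections on columns, matching how class_probs is laid out in run_nms below):

import numpy as np

det_prob = np.array([0.9, 0.2, 0.7])       # per-detection confidence
class_prob = np.array([[0.8, 0.1, 0.6],    # votes for class A
                       [0.2, 0.9, 0.4]])   # votes for class B
weighted = (class_prob * det_prob).sum(1)  # confident calls count for more
print(weighted / weighted.sum())           # -> [0.644 0.356], a file-level distribution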


def run_nms(outputs, params, sampling_rate):
def run_nms(
    outputs: ModelOutput,
    params: NonMaximumSuppressionConfig,
    sampling_rate: np.ndarray,
) -> Tuple[List[PredictionResults], List[np.ndarray]]:
    """Run non-maximum suppression on the output of the model.

    pred_det = outputs['pred_det']  # probability of box
    pred_size = outputs['pred_size']  # box size
    Model outputs processed are expected to have a batch dimension.
    Each element of the batch is processed independently. The
    result is a pair of lists, one for the predictions and one for
    the features. Each element of the lists corresponds to one
    element of the batch.
    """
    pred_det, pred_size, pred_class, _, features = outputs

    pred_det_nms = non_max_suppression(pred_det, params['nms_kernel_size'])
    freq_rescale = (params['max_freq'] - params['min_freq']) / pred_det.shape[-2]
    pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"])
    freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[
        -2
    ]

    # NOTE there will be small differences depending on which sampling rate is chosen
    # as we are choosing the same sampling rate for the entire batch
    duration = x_coords_to_time(pred_det.shape[-1], sampling_rate[0].item(),
                                params['fft_win_length'], params['fft_overlap'])
    top_k = int(duration * params['nms_top_k_per_sec'])
    # NOTE: there will be small differences depending on which sampling rate
    # is chosen as we are choosing the same sampling rate for the entire batch
    duration = x_coords_to_time(
        pred_det.shape[-1],
        int(sampling_rate[0].item()),
        params["fft_win_length"],
        params["fft_overlap"],
    )
    top_k = int(duration * params["nms_top_k_per_sec"])
    scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)

    # loop over batch to save outputs
    preds = []
    feats = []
    for ii in range(pred_det_nms.shape[0]):
    preds: List[PredictionResults] = []
    feats: List[np.ndarray] = []
    for num_detection in range(pred_det_nms.shape[0]):
        # get valid indices
        inds_ord = torch.argsort(x_pos[ii, :])
        valid_inds = scores[ii, inds_ord] > params['detection_threshold']
        inds_ord = torch.argsort(x_pos[num_detection, :])
        valid_inds = (
            scores[num_detection, inds_ord] > params["detection_threshold"]
        )
        valid_inds = inds_ord[valid_inds]

        # create result dictionary
        pred = {}
        pred['det_probs'] = scores[ii, valid_inds]
        pred['x_pos'] = x_pos[ii, valid_inds]
        pred['y_pos'] = y_pos[ii, valid_inds]
        pred['bb_width'] = pred_size[ii, 0, pred['y_pos'], pred['x_pos']]
        pred['bb_height'] = pred_size[ii, 1, pred['y_pos'], pred['x_pos']]
        pred['start_times'] = x_coords_to_time(pred['x_pos'].float() / params['resize_factor'],
                                               sampling_rate[ii].item(), params['fft_win_length'], params['fft_overlap'])
        pred['end_times'] = x_coords_to_time((pred['x_pos'].float()+pred['bb_width']) / params['resize_factor'],
                                             sampling_rate[ii].item(), params['fft_win_length'], params['fft_overlap'])
        pred['low_freqs'] = (pred_size[ii].shape[1] - pred['y_pos'].float())*freq_rescale + params['min_freq']
        pred['high_freqs'] = pred['low_freqs'] + pred['bb_height']*freq_rescale
        pred["det_probs"] = scores[num_detection, valid_inds]
        pred["x_pos"] = x_pos[num_detection, valid_inds]
        pred["y_pos"] = y_pos[num_detection, valid_inds]
        pred["bb_width"] = pred_size[
            num_detection,
            0,
            pred["y_pos"],
            pred["x_pos"],
        ]
        pred["bb_height"] = pred_size[
            num_detection,
            1,
            pred["y_pos"],
            pred["x_pos"],
        ]
        pred["start_times"] = x_coords_to_time(
            pred["x_pos"].float() / params["resize_factor"],
            int(sampling_rate[num_detection].item()),
            params["fft_win_length"],
            params["fft_overlap"],
        )
        pred["end_times"] = x_coords_to_time(
            (pred["x_pos"].float() + pred["bb_width"])
            / params["resize_factor"],
            int(sampling_rate[num_detection].item()),
            params["fft_win_length"],
            params["fft_overlap"],
        )
        pred["low_freqs"] = (
            pred_size[num_detection].shape[1] - pred["y_pos"].float()
        ) * freq_rescale + params["min_freq"]
        pred["high_freqs"] = (
            pred["low_freqs"] + pred["bb_height"] * freq_rescale
        )

        # extract the per class votes
        if 'pred_class' in outputs:
            pred['class_probs'] = outputs['pred_class'][ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]]
        if pred_class is not None:
            pred["class_probs"] = pred_class[
                num_detection,
                :,
                y_pos[num_detection, valid_inds],
                x_pos[num_detection, valid_inds],
            ]

        # extract the model features
        if 'features' in outputs:
            feat = outputs['features'][ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]].transpose(0, 1)
            feat = feat.cpu().numpy().astype(np.float32)
        if features is not None:
            feat = features[
                num_detection,
                :,
                y_pos[num_detection, valid_inds],
                x_pos[num_detection, valid_inds],
            ].transpose(0, 1)
            feat = feat.detach().numpy().astype(np.float32)
            feats.append(feat)

        # convert to numpy
        for kk in pred.keys():
            pred[kk] = pred[kk].cpu().numpy().astype(np.float32)
        preds.append(pred)
        for key, value in pred.items():
            pred[key] = value.detach().numpy().astype(np.float32)

        preds.append(pred)  # type: ignore

    return preds, feats
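
A typical call site for run_nms, sketched under stated assumptions: model is an already-loaded detector returning the ModelOutput above, the default processing configuration contains every key the function reads (nms_kernel_size, max_freq, min_freq, fft_win_length, fft_overlap, nms_top_k_per_sec, detection_threshold, resize_factor), and sampling rates are passed one per batch element so the .item() indexing works:

import torch

from bat_detect.detector.parameters import DEFAULT_PROCESSING_CONFIGURATIONS

spec = torch.randn(1, 1, 128, 512)      # [batch, channel, freq, time] spectrogram
outputs = model(spec)                   # assumed: a loaded Net2DFast instance
preds, feats = run_nms(
    outputs,
    DEFAULT_PROCESSING_CONFIGURATIONS,
    torch.tensor([256000]),             # one sampling rate per batch element
)
print(preds[0]["start_times"], preds[0]["det_probs"])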


def non_max_suppression(heat, kernel_size):
def non_max_suppression(
    heat: torch.Tensor,
    kernel_size: Union[int, Tuple[int, int]],
):
    # kernel can be an int or list/tuple
    if type(kernel_size) is int:
    if isinstance(kernel_size, int):
        kernel_size_h = kernel_size
        kernel_size_w = kernel_size
    else:
        kernel_size_h, kernel_size_w = kernel_size

    pad_h = (kernel_size_h - 1) // 2
    pad_w = (kernel_size_w - 1) // 2

    hmax = nn.functional.max_pool2d(heat, (kernel_size_h, kernel_size_w), stride=1, padding=(pad_h, pad_w))
    hmax = nn.functional.max_pool2d(
        heat, (kernel_size_h, kernel_size_w), stride=1, padding=(pad_h, pad_w)
    )
    keep = (hmax == heat).float()

    return heat * keep
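
The non-max suppression above keeps a pixel only if it equals the maximum of its own neighbourhood, which max-pooling with stride 1 computes in one shot. A small check (3x3 kernel on a toy heatmap):

import torch

heat = torch.tensor([[[[0.1, 0.9, 0.8],
                       [0.2, 0.3, 0.7],
                       [0.6, 0.2, 0.1]]]])  # shape [1, 1, 3, 3]
print(non_max_suppression(heat, 3))
# only 0.9 and 0.6 survive: each is the maximum of its padded 3x3
# neighbourhood, every other value is suppressed to zero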

@ -94,7 +174,7 @@ def get_topk_scores(scores, K):

    topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
    topk_inds = topk_inds % (height * width)
    topk_ys = torch.div(topk_inds, width, rounding_mode='floor').long()
    topk_xs = (topk_inds % width).long()
    topk_ys = torch.div(topk_inds, width, rounding_mode="floor").long()
    topk_xs = (topk_inds % width).long()

    return topk_scores, topk_ys, topk_xs

0
bat_detect/evaluate/__init__.py
Normal file

@ -2,67 +2,74 @@
Evaluates trained model on test set and generates plots.
"""

import numpy as np
import sys
import os
import argparse
import copy
import json
import os

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
import argparse

sys.path.append('../../')
import bat_detect.utils.detector_utils as du
import bat_detect.train.train_utils as tu
import bat_detect.detector.parameters as parameters
from bat_detect.detector import parameters
import bat_detect.train.evaluate as evl
import bat_detect.train.train_utils as tu
import bat_detect.utils.detector_utils as du
import bat_detect.utils.plot_utils as pu


def get_blank_annotation(ip_str):

    res = {}
    res['class_name'] = ''
    res['duration'] = -1
    res['id'] = ''# fileName
    res['issues'] = False
    res['notes'] = ip_str
    res['time_exp'] = 1
    res['annotated'] = False
    res['annotation'] = []
    res["class_name"] = ""
    res["duration"] = -1
    res["id"] = ""  # fileName
    res["issues"] = False
    res["notes"] = ip_str
    res["time_exp"] = 1
    res["annotated"] = False
    res["annotation"] = []

    ann = {}
    ann['class'] = ''
    ann['event'] = 'Echolocation'
    ann['individual'] = -1
    ann['start_time'] = -1
    ann['end_time'] = -1
    ann['low_freq'] = -1
    ann['high_freq'] = -1
    ann['confidence'] = -1
    ann["class"] = ""
    ann["event"] = "Echolocation"
    ann["individual"] = -1
    ann["start_time"] = -1
    ann["end_time"] = -1
    ann["low_freq"] = -1
    ann["high_freq"] = -1
    ann["confidence"] = -1

    return copy.deepcopy(res), copy.deepcopy(ann)


def create_genus_mapping(gt_test, preds, class_names):
    # rolls the per class predictions and ground truth back up to genus level
    class_names_genus, cls_to_genus = np.unique([cc.split(' ')[0] for cc in class_names], return_inverse=True)
    genus_to_cls_map = [np.where(np.array(cls_to_genus) == cc)[0] for cc in range(len(class_names_genus))]
    class_names_genus, cls_to_genus = np.unique(
        [cc.split(" ")[0] for cc in class_names], return_inverse=True
    )
    genus_to_cls_map = [
        np.where(np.array(cls_to_genus) == cc)[0]
        for cc in range(len(class_names_genus))
    ]

    gt_test_g = []
    for gg in gt_test:
        gg_g = copy.deepcopy(gg)
        inds = np.where(gg_g['class_ids']!=-1)[0]
        gg_g['class_ids'][inds] = cls_to_genus[gg_g['class_ids'][inds]]
        inds = np.where(gg_g["class_ids"] != -1)[0]
        gg_g["class_ids"][inds] = cls_to_genus[gg_g["class_ids"][inds]]
        gt_test_g.append(gg_g)

    # note, will have entries greater than one as we are summing across the respective classes
    preds_g = []
    for pp in preds:
        pp_g = copy.deepcopy(pp)
        pp_g['class_probs'] = np.zeros((len(class_names_genus), pp_g['class_probs'].shape[1]), dtype=np.float32)
        pp_g["class_probs"] = np.zeros(
            (len(class_names_genus), pp_g["class_probs"].shape[1]),
            dtype=np.float32,
        )
        for cc, inds in enumerate(genus_to_cls_map):
            pp_g['class_probs'][cc, :] = pp['class_probs'][inds, :].sum(0)
            pp_g["class_probs"][cc, :] = pp["class_probs"][inds, :].sum(0)
        preds_g.append(pp_g)

    return class_names_genus, preds_g, gt_test_g
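
The np.unique(..., return_inverse=True) call above is what builds the species-to-genus lookup: the first token of each binomial name is the genus, and the inverse index maps every species back to it. On a toy list (illustrative species, not necessarily the model's actual class list):

import numpy as np

class_names = ["Myotis daubentonii", "Myotis nattereri", "Pipistrellus pipistrellus"]
genus, cls_to_genus = np.unique(
    [cc.split(" ")[0] for cc in class_names], return_inverse=True
)
print(genus)         # ['Myotis' 'Pipistrellus']
print(cls_to_genus)  # [0 0 1] -- species index -> genus index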

@ -70,56 +77,70 @@ def create_genus_mapping(gt_test, preds, class_names):

def load_tadarida_pred(ip_dir, dataset, file_of_interest):

    res, ann = get_blank_annotation('Generated by Tadarida')
    res, ann = get_blank_annotation("Generated by Tadarida")

    # create the annotations in the correct format
    da_c = pd.read_csv(ip_dir + dataset + '/' + file_of_interest.replace('.wav', '.ta').replace('.WAV', '.ta'), sep='\t')
    da_c = pd.read_csv(
        ip_dir
        + dataset
        + "/"
        + file_of_interest.replace(".wav", ".ta").replace(".WAV", ".ta"),
        sep="\t",
    )

    res_c = copy.deepcopy(res)
    res_c['id'] = file_of_interest
    res_c['dataset'] = dataset
    res_c['feats'] = da_c.iloc[:, 6:].values.astype(np.float32)
    res_c["id"] = file_of_interest
    res_c["dataset"] = dataset
    res_c["feats"] = da_c.iloc[:, 6:].values.astype(np.float32)

    if da_c.shape[0] > 0:
        res_c['class_name'] = ''
        res_c['class_prob'] = 0.0
        res_c["class_name"] = ""
        res_c["class_prob"] = 0.0

        for aa in range(da_c.shape[0]):
            ann_c = copy.deepcopy(ann)
            ann_c['class'] = 'Not Bat'  # will assign to class later
            ann_c['start_time'] = np.round(da_c.iloc[aa]['StTime']/1000.0, 5)
            ann_c['end_time'] = np.round((da_c.iloc[aa]['StTime'] + da_c.iloc[aa]['Dur'])/1000.0, 5)
            ann_c['low_freq'] = np.round(da_c.iloc[aa]['Fmin'] * 1000.0, 2)
            ann_c['high_freq'] = np.round(da_c.iloc[aa]['Fmax'] * 1000.0, 2)
            ann_c['det_prob'] = 0.0
            res_c['annotation'].append(ann_c)
            ann_c["class"] = "Not Bat"  # will assign to class later
            ann_c["start_time"] = np.round(da_c.iloc[aa]["StTime"] / 1000.0, 5)
            ann_c["end_time"] = np.round(
                (da_c.iloc[aa]["StTime"] + da_c.iloc[aa]["Dur"]) / 1000.0, 5
            )
            ann_c["low_freq"] = np.round(da_c.iloc[aa]["Fmin"] * 1000.0, 2)
            ann_c["high_freq"] = np.round(da_c.iloc[aa]["Fmax"] * 1000.0, 2)
            ann_c["det_prob"] = 0.0
            res_c["annotation"].append(ann_c)

    return res_c


def load_sonobat_meta(ip_dir, datasets, region_classifier, class_names, only_accepted_species=True):
def load_sonobat_meta(
    ip_dir,
    datasets,
    region_classifier,
    class_names,
    only_accepted_species=True,
):

    sp_dict = {}
    for ss in class_names:
        sp_key = ss.split(' ')[0][:3] + ss.split(' ')[1][:3]
        sp_key = ss.split(" ")[0][:3] + ss.split(" ")[1][:3]
        sp_dict[sp_key] = ss

    sp_dict['x'] = ''  # not bat
    sp_dict['Bat'] = 'Bat'
    sp_dict["x"] = ""  # not bat
    sp_dict["Bat"] = "Bat"

    sonobat_meta = {}
    for tt in datasets:
        dataset = tt['dataset_name']
        sb_ip_dir = ip_dir + dataset + '/' + region_classifier + '/'
        dataset = tt["dataset_name"]
        sb_ip_dir = ip_dir + dataset + "/" + region_classifier + "/"

        # load the call level predictions
        ip_file_p = sb_ip_dir + dataset + '_Parameters_v4.5.0.txt'
        #ip_file_p = sb_ip_dir + 'audio_SonoBatch_v30.0 beta.txt'
        da = pd.read_csv(ip_file_p, sep='\t')
        ip_file_p = sb_ip_dir + dataset + "_Parameters_v4.5.0.txt"
        # ip_file_p = sb_ip_dir + 'audio_SonoBatch_v30.0 beta.txt'
        da = pd.read_csv(ip_file_p, sep="\t")

        # load the file level predictions
        ip_file_b = sb_ip_dir + dataset + '_SonoBatch_v4.5.0.txt'
        #ip_file_b = sb_ip_dir + 'audio_CumulativeParameters_v30.0 beta.txt'
        ip_file_b = sb_ip_dir + dataset + "_SonoBatch_v4.5.0.txt"
        # ip_file_b = sb_ip_dir + 'audio_CumulativeParameters_v30.0 beta.txt'

        with open(ip_file_b) as f:
            lines = f.readlines()
@ -129,7 +150,7 @@ def load_sonobat_meta(ip_dir, datasets, region_classifier, class_names, only_acc
        file_res = {}
        for ll in lines:
            # note this does not seem to parse the file very well
            ll_data = ll.split('\t')
            ll_data = ll.split("\t")

            # there are sometimes many different species names per file
            if only_accepted_species:
@ -137,20 +158,24 @@ def load_sonobat_meta(ip_dir, datasets, region_classifier, class_names, only_acc
                ind = 4
            else:
                # choosing "~Spp" if "SppAccp" does not exist
                if ll_data[4] != 'x':
                    ind = 4  # choosing "SppAccp", along with "Prob" here
                if ll_data[4] != "x":
                    ind = 4  # choosing "SppAccp", along with "Prob" here
                else:
                    ind = 8  # choosing "~Spp", along with "~Prob" here

            sp_name_1 = sp_dict[ll_data[ind]]
            prob_1 = ll_data[ind+1]
            if prob_1 == 'x':
            prob_1 = ll_data[ind + 1]
            if prob_1 == "x":
                prob_1 = 0.0
            file_res[ll_data[1]] = {'id':ll_data[1], 'species_1':sp_name_1, 'prob_1':prob_1}
            file_res[ll_data[1]] = {
                "id": ll_data[1],
                "species_1": sp_name_1,
                "prob_1": prob_1,
            }

        sonobat_meta[dataset] = {}
        sonobat_meta[dataset]['file_res'] = file_res
        sonobat_meta[dataset]['call_info'] = da
        sonobat_meta[dataset]["file_res"] = file_res
        sonobat_meta[dataset]["call_info"] = da

    return sonobat_meta

@ -158,34 +183,38 @@ def load_sonobat_meta(ip_dir, datasets, region_classifier, class_names, only_acc
def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):

    # create the annotations in the correct format
    res, ann = get_blank_annotation('Generated by Sonobat')
    res, ann = get_blank_annotation("Generated by Sonobat")
    res_c = copy.deepcopy(res)
    res_c['id'] = id
    res_c['dataset'] = dataset
    res_c["id"] = id
    res_c["dataset"] = dataset

    da = sb_meta[dataset]['call_info']
    da_c = da[da['Filename'] == id]
    da = sb_meta[dataset]["call_info"]
    da_c = da[da["Filename"] == id]

    file_res = sb_meta[dataset]['file_res']
    res_c['feats'] = np.zeros((0,0))
    file_res = sb_meta[dataset]["file_res"]
    res_c["feats"] = np.zeros((0, 0))

    if da_c.shape[0] > 0:
        res_c['class_name'] = file_res[id]['species_1']
        res_c['class_prob'] = file_res[id]['prob_1']
        res_c['feats'] = da_c.iloc[:, 3:105].values.astype(np.float32)
        res_c["class_name"] = file_res[id]["species_1"]
        res_c["class_prob"] = file_res[id]["prob_1"]
        res_c["feats"] = da_c.iloc[:, 3:105].values.astype(np.float32)

        for aa in range(da_c.shape[0]):
            ann_c = copy.deepcopy(ann)
            if set_class_name is None:
                ann_c['class'] = file_res[id]['species_1']
                ann_c["class"] = file_res[id]["species_1"]
            else:
                ann_c['class'] = set_class_name
            ann_c['start_time'] = np.round(da_c.iloc[aa]['TimeInFile'] / 1000.0, 5)
            ann_c['end_time'] = np.round(ann_c['start_time'] + da_c.iloc[aa]['CallDuration']/1000.0, 5)
            ann_c['low_freq'] = np.round(da_c.iloc[aa]['LowFreq'] * 1000.0, 2)
            ann_c['high_freq'] = np.round(da_c.iloc[aa]['HiFreq'] * 1000.0, 2)
            ann_c['det_prob'] = np.round(da_c.iloc[aa]['Quality'], 3)
            res_c['annotation'].append(ann_c)
                ann_c["class"] = set_class_name
            ann_c["start_time"] = np.round(
                da_c.iloc[aa]["TimeInFile"] / 1000.0, 5
            )
            ann_c["end_time"] = np.round(
                ann_c["start_time"] + da_c.iloc[aa]["CallDuration"] / 1000.0, 5
            )
            ann_c["low_freq"] = np.round(da_c.iloc[aa]["LowFreq"] * 1000.0, 2)
            ann_c["high_freq"] = np.round(da_c.iloc[aa]["HiFreq"] * 1000.0, 2)
            ann_c["det_prob"] = np.round(da_c.iloc[aa]["Quality"], 3)
            res_c["annotation"].append(ann_c)

    return res_c

@ -193,8 +222,18 @@ def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
def bb_overlap(bb_g_in, bb_p_in):

    freq_scale = 10000000.0  # ensure that both axes are roughly the same range
    bb_g = [bb_g_in['start_time'], bb_g_in['low_freq']/freq_scale, bb_g_in['end_time'], bb_g_in['high_freq']/freq_scale]
    bb_p = [bb_p_in['start_time'], bb_p_in['low_freq']/freq_scale, bb_p_in['end_time'], bb_p_in['high_freq']/freq_scale]
    bb_g = [
        bb_g_in["start_time"],
        bb_g_in["low_freq"] / freq_scale,
        bb_g_in["end_time"],
        bb_g_in["high_freq"] / freq_scale,
    ]
    bb_p = [
        bb_p_in["start_time"],
        bb_p_in["low_freq"] / freq_scale,
        bb_p_in["end_time"],
        bb_p_in["high_freq"] / freq_scale,
    ]

    xA = max(bb_g[0], bb_p[0])
    yA = max(bb_g[1], bb_p[1])
@ -220,13 +259,15 @@ def bb_overlap(bb_g_in, bb_p_in):
def assign_to_gt(gt, pred, iou_thresh):
    # this will edit pred in place

    num_preds = len(pred['annotation'])
    num_gts = len(gt['annotation'])
    num_preds = len(pred["annotation"])
    num_gts = len(gt["annotation"])
    if num_preds > 0 and num_gts > 0:
        iou_m = np.zeros((num_preds, num_gts))
        for ii in range(num_preds):
            for jj in range(num_gts):
                iou_m[ii, jj] = bb_overlap(gt['annotation'][jj], pred['annotation'][ii])
                iou_m[ii, jj] = bb_overlap(
                    gt["annotation"][jj], pred["annotation"][ii]
                )

        # greedily assign detections to ground truths
        # needs to be greater than some threshold and we cannot assign GT
@ -235,7 +276,9 @@ def assign_to_gt(gt, pred, iou_thresh):
        for jj in range(num_gts):
            max_iou = np.argmax(iou_m[:, jj])
            if iou_m[max_iou, jj] > iou_thresh:
                pred['annotation'][max_iou]['class'] = gt['annotation'][jj]['class']
                pred["annotation"][max_iou]["class"] = gt["annotation"][jj][
                    "class"
                ]
                iou_m[max_iou, :] = -1.0

    return pred
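
The greedy matching above walks the ground truths in order, grabs the best remaining prediction for each, and then burns the matched row so a prediction can only be used once. A toy trace with a hypothetical IoU matrix:

import numpy as np

iou_m = np.array([[0.60, 0.05],
                  [0.40, 0.30]])  # rows: predictions, columns: ground truths
for jj in range(iou_m.shape[1]):
    best = np.argmax(iou_m[:, jj])
    if iou_m[best, jj] > 0.01:
        print(f"GT {jj} -> prediction {best}")
        iou_m[best, :] = -1.0      # this prediction is now spoken for
# GT 0 -> prediction 0; GT 1 falls back to prediction 1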

@ -244,27 +287,39 @@ def assign_to_gt(gt, pred, iou_thresh):
def parse_data(data, class_names, non_event_classes, is_pred=False):
    class_names_all = class_names + non_event_classes

    data['class_names'] = np.array([aa['class'] for aa in data['annotation']])
    data['start_times'] = np.array([aa['start_time'] for aa in data['annotation']])
    data['end_times'] = np.array([aa['end_time'] for aa in data['annotation']])
    data['high_freqs'] = np.array([float(aa['high_freq']) for aa in data['annotation']])
    data['low_freqs'] = np.array([float(aa['low_freq']) for aa in data['annotation']])
    data["class_names"] = np.array([aa["class"] for aa in data["annotation"]])
    data["start_times"] = np.array(
        [aa["start_time"] for aa in data["annotation"]]
    )
    data["end_times"] = np.array([aa["end_time"] for aa in data["annotation"]])
    data["high_freqs"] = np.array(
        [float(aa["high_freq"]) for aa in data["annotation"]]
    )
    data["low_freqs"] = np.array(
        [float(aa["low_freq"]) for aa in data["annotation"]]
    )

    if is_pred:
        # when loading predictions
        data['det_probs'] = np.array([float(aa['det_prob']) for aa in data['annotation']])
        data['class_probs'] = np.zeros((len(class_names)+1, len(data['annotation'])))
        data['class_ids'] = np.array([class_names_all.index(aa['class']) for aa in data['annotation']]).astype(np.int32)
        data["det_probs"] = np.array(
            [float(aa["det_prob"]) for aa in data["annotation"]]
        )
        data["class_probs"] = np.zeros(
            (len(class_names) + 1, len(data["annotation"]))
        )
        data["class_ids"] = np.array(
            [class_names_all.index(aa["class"]) for aa in data["annotation"]]
        ).astype(np.int32)
    else:
        # when loading ground truth
        # if the class label is not in the set of interest then set to -1
        labels = []
        for aa in data['annotation']:
            if aa['class'] in class_names:
                labels.append(class_names_all.index(aa['class']))
        for aa in data["annotation"]:
            if aa["class"] in class_names:
                labels.append(class_names_all.index(aa["class"]))
            else:
                labels.append(-1)
        data['class_ids'] = np.array(labels).astype(np.int32)
        data["class_ids"] = np.array(labels).astype(np.int32)

    return data

@ -272,12 +327,17 @@ def parse_data(data, class_names, non_event_classes, is_pred=False):
def load_gt_data(datasets, events_of_interest, class_names, classes_to_ignore):
    gt_data = []
    for dd in datasets:
        print('\n' + dd['dataset_name'])
        gt_dataset = tu.load_set_of_anns([dd], events_of_interest=events_of_interest, verbose=True)
        gt_dataset = [parse_data(gg, class_names, classes_to_ignore, False) for gg in gt_dataset]
        print("\n" + dd["dataset_name"])
        gt_dataset = tu.load_set_of_anns(
            [dd], events_of_interest=events_of_interest, verbose=True
        )
        gt_dataset = [
            parse_data(gg, class_names, classes_to_ignore, False)
            for gg in gt_dataset
        ]

        for gt in gt_dataset:
            gt['dataset_name'] = dd['dataset_name']
            gt["dataset_name"] = dd["dataset_name"]

        gt_data.extend(gt_dataset)

@ -300,69 +360,103 @@ def train_rf_model(x_train, y_train, num_classes, seed=2001):
    clf = RandomForestClassifier(random_state=seed, n_jobs=-1)
    clf.fit(x_train, y_train)
    y_pred = clf.predict(x_train)
    tr_acc = (y_pred==y_train).mean()
    #print('Train acc', round(tr_acc*100, 2))
    tr_acc = (y_pred == y_train).mean()
    # print('Train acc', round(tr_acc*100, 2))
    return clf, un_train_class


def eval_rf_model(clf, pred, un_train_class, num_classes):
    # stores the prediction in place
    if pred['feats'].shape[0] > 0:
        pred['class_probs'] = np.zeros((num_classes, pred['feats'].shape[0]))
        pred['class_probs'][un_train_class, :] = clf.predict_proba(pred['feats']).T
        pred['det_probs'] = pred['class_probs'][:-1, :].sum(0)
    if pred["feats"].shape[0] > 0:
        pred["class_probs"] = np.zeros((num_classes, pred["feats"].shape[0]))
        pred["class_probs"][un_train_class, :] = clf.predict_proba(
            pred["feats"]
        ).T
        pred["det_probs"] = pred["class_probs"][:-1, :].sum(0)
    else:
        pred['class_probs'] = np.zeros((num_classes, 0))
        pred['det_probs'] = np.zeros(0)
        pred["class_probs"] = np.zeros((num_classes, 0))
        pred["det_probs"] = np.zeros(0)
    return pred
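
Together the two helpers form the random forest baseline's fit-then-score loop. A hedged sketch, assuming x_train, y_train and the preds list come from the feature-loading code above:

clf, un_train_class = train_rf_model(x_train, y_train, num_classes)
for pred in preds:
    eval_rf_model(clf, pred, un_train_class, num_classes)  # fills class_probs / det_probs in place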
|
||||
|
||||
|
||||
def save_summary_to_json(op_dir, mod_name, results):
|
||||
op = {}
|
||||
op['avg_prec'] = round(results['avg_prec'], 3)
|
||||
op['avg_prec_class'] = round(results['avg_prec_class'], 3)
|
||||
op['top_class'] = round(results['top_class']['avg_prec'], 3)
|
||||
op['file_acc'] = round(results['file_acc'], 3)
|
||||
op['model'] = mod_name
|
||||
op["avg_prec"] = round(results["avg_prec"], 3)
|
||||
op["avg_prec_class"] = round(results["avg_prec_class"], 3)
|
||||
op["top_class"] = round(results["top_class"]["avg_prec"], 3)
|
||||
op["file_acc"] = round(results["file_acc"], 3)
|
||||
op["model"] = mod_name
|
||||
|
||||
op['per_class'] = {}
|
||||
for cc in results['class_pr']:
|
||||
op['per_class'][cc['name']] = cc['avg_prec']
|
||||
op["per_class"] = {}
|
||||
for cc in results["class_pr"]:
|
||||
op["per_class"][cc["name"]] = cc["avg_prec"]
|
||||
|
||||
op_file_name = os.path.join(op_dir, mod_name + '_results.json')
|
||||
with open(op_file_name, 'w') as da:
|
||||
op_file_name = os.path.join(op_dir, mod_name + "_results.json")
|
||||
with open(op_file_name, "w") as da:
|
||||
json.dump(op, da, indent=2)
|
||||
|
||||
|
||||
def print_results(model_name, mod_str, results, op_dir, class_names, file_type, title_text=''):
|
||||
print('\nResults - ' + model_name)
|
||||
print('avg_prec ', round(results['avg_prec'], 3))
|
||||
print('avg_prec_class', round(results['avg_prec_class'], 3))
|
||||
print('top_class ', round(results['top_class']['avg_prec'], 3))
|
||||
print('file_acc ', round(results['file_acc'], 3))
|
||||
def print_results(
|
||||
model_name, mod_str, results, op_dir, class_names, file_type, title_text=""
|
||||
):
|
||||
print("\nResults - " + model_name)
|
||||
print("avg_prec ", round(results["avg_prec"], 3))
|
||||
print("avg_prec_class", round(results["avg_prec_class"], 3))
|
||||
print("top_class ", round(results["top_class"]["avg_prec"], 3))
|
||||
print("file_acc ", round(results["file_acc"], 3))
|
||||
|
||||
print('\nSaving ' + model_name + ' results to: ' + op_dir)
|
||||
print("\nSaving " + model_name + " results to: " + op_dir)
|
||||
save_summary_to_json(op_dir, mod_str, results)
|
||||
|
||||
pu.plot_pr_curve(op_dir, mod_str+'_test_all_det', mod_str+'_test_all_det', results, file_type, title_text + 'Detection PR')
|
||||
pu.plot_pr_curve(op_dir, mod_str+'_test_all_top_class', mod_str+'_test_all_top_class', results['top_class'], file_type, title_text + 'Top Class')
|
||||
pu.plot_pr_curve_class(op_dir, mod_str+'_test_all_class', mod_str+'_test_all_class', results, file_type, title_text + 'Per-Class PR')
|
||||
pu.plot_confusion_matrix(op_dir, mod_str+'_confusion', results['gt_valid_file'], results['pred_valid_file'],
|
||||
results['file_acc'], class_names, True, file_type, title_text + 'Confusion Matrix')
|
||||
pu.plot_pr_curve(
|
||||
op_dir,
|
||||
mod_str + "_test_all_det",
|
||||
mod_str + "_test_all_det",
|
||||
results,
|
||||
file_type,
|
||||
title_text + "Detection PR",
|
||||
)
|
||||
pu.plot_pr_curve(
|
||||
op_dir,
|
||||
mod_str + "_test_all_top_class",
|
||||
mod_str + "_test_all_top_class",
|
||||
results["top_class"],
|
||||
file_type,
|
||||
title_text + "Top Class",
|
||||
)
|
||||
pu.plot_pr_curve_class(
|
||||
op_dir,
|
||||
mod_str + "_test_all_class",
|
||||
mod_str + "_test_all_class",
|
||||
results,
|
||||
file_type,
|
||||
title_text + "Per-Class PR",
|
||||
)
|
||||
pu.plot_confusion_matrix(
|
||||
op_dir,
|
||||
mod_str + "_confusion",
|
||||
results["gt_valid_file"],
|
||||
results["pred_valid_file"],
|
||||
results["file_acc"],
|
||||
class_names,
|
||||
True,
|
||||
file_type,
|
||||
title_text + "Confusion Matrix",
|
||||
)
|
||||
|
||||
|
||||
def add_root_path_back(data_sets, ann_path, wav_path):
|
||||
for dd in data_sets:
|
||||
dd['ann_path'] = os.path.join(ann_path, dd['ann_path'])
|
||||
dd['wav_path'] = os.path.join(wav_path, dd['wav_path'])
|
||||
dd["ann_path"] = os.path.join(ann_path, dd["ann_path"])
|
||||
dd["wav_path"] = os.path.join(wav_path, dd["wav_path"])
|
||||
return data_sets
|
||||
|
||||
|
||||
def check_classes_in_train(gt_list, class_names):
|
||||
num_gt_total = np.sum([gg['start_times'].shape[0] for gg in gt_list])
|
||||
num_gt_total = np.sum([gg["start_times"].shape[0] for gg in gt_list])
|
||||
num_with_no_class = 0
|
||||
for gt in gt_list:
|
||||
for cc in gt['class_names']:
|
||||
for cc in gt["class_names"]:
|
||||
if cc not in class_names:
|
||||
num_with_no_class += 1
|
||||
return num_with_no_class
|
||||
@ -371,195 +465,337 @@ def check_classes_in_train(gt_list, class_names):
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "op_dir",
        type=str,
        default="plots/results_compare/",
        help="Output directory for plots",
    )
    parser.add_argument("data_dir", type=str, help="Path to root of datasets")
    parser.add_argument(
        "ann_dir", type=str, help="Path to extracted annotations"
    )
    parser.add_argument(
        "bd_model_path", type=str, help="Path to BatDetect model"
    )
    parser.add_argument(
        "--test_file",
        type=str,
        default="",
        help="Path to json file used for evaluation.",
    )
    parser.add_argument(
        "--sb_ip_dir", type=str, default="", help="Path to sonobat predictions"
    )
    parser.add_argument(
        "--sb_region_classifier",
        type=str,
        default="south",
        help="Region classifier to use for sonobat predictions",
    )
    parser.add_argument(
        "--td_ip_dir",
        type=str,
        default="",
        help="Path to tadarida_D predictions",
    )
    parser.add_argument(
        "--iou_thresh",
        type=float,
        default=0.01,
        help="IOU threshold for assigning predictions to ground truth",
    )
    parser.add_argument(
        "--file_type",
        type=str,
        default="png",
        help="Type of image to save - png or pdf",
    )
    parser.add_argument(
        "--title_text",
        type=str,
        default="",
        help="Text to add as title of plots",
    )
    parser.add_argument(
        "--rand_seed", type=int, default=2001, help="Random seed"
    )
    args = vars(parser.parse_args())

    np.random.seed(args["rand_seed"])

    if not os.path.isdir(args["op_dir"]):
        os.makedirs(args["op_dir"])

    # load the model
    params_eval = parameters.get_params(False)
    _, params_bd = du.load_model(args["bd_model_path"])

    class_names = params_bd["class_names"]
    num_classes = len(class_names) + 1  # num classes plus background class

    classes_to_ignore = ["Not Bat", "Bat", "Unknown"]
    events_of_interest = ["Echolocation"]

    # load test data
    if args["test_file"] == "":
        # load the test files of interest from the trained model
        test_sets = add_root_path_back(
            params_bd["test_sets"], args["ann_dir"], args["data_dir"]
        )
        test_sets = [
            dd for dd in test_sets if not dd["is_binary"]
        ]  # exclude bat/not datasets
    else:
        # user specified annotation file to evaluate
        test_dict = {}
        test_dict["dataset_name"] = args["test_file"].replace(".json", "")
        test_dict["is_test"] = True
        test_dict["is_binary"] = True
        test_dict["ann_path"] = os.path.join(args["ann_dir"], args["test_file"])
        test_dict["wav_path"] = args["data_dir"]
        test_sets = [test_dict]

    # load the gt for the test set
    gt_test = load_gt_data(
        test_sets, events_of_interest, class_names, classes_to_ignore
    )
    total_num_calls = np.sum([gg["start_times"].shape[0] for gg in gt_test])
    print("\nTotal number of test files:", len(gt_test))
    print(
        "Total number of test calls:",
        np.sum([gg["start_times"].shape[0] for gg in gt_test]),
    )

    # check if test contains classes not in the train set
    num_with_no_class = check_classes_in_train(gt_test, class_names)
    if total_num_calls == num_with_no_class:
        print("Classes from the test set are not in the train set.")
        assert False

    # only need the train data if evaluating Sonobat or Tadarida
    if args["sb_ip_dir"] != "" or args["td_ip_dir"] != "":
        train_sets = add_root_path_back(
            params_bd["train_sets"], args["ann_dir"], args["data_dir"]
        )
        train_sets = [
            dd for dd in train_sets if not dd["is_binary"]
        ]  # exclude bat/not datasets
        gt_train = load_gt_data(
            train_sets, events_of_interest, class_names, classes_to_ignore
        )

    #
    # evaluate Sonobat by training random forest classifier
    #
    # NOTE: Sonobat may only make predictions for a subset of the files
    #
    if args["sb_ip_dir"] != "":
        sb_meta = load_sonobat_meta(
            args["sb_ip_dir"],
            train_sets + test_sets,
            args["sb_region_classifier"],
            class_names,
        )

        preds_sb = []
        keep_inds_sb = []
        for ii, gt in enumerate(gt_test):
            sb_pred = load_sonobat_preds(gt["dataset_name"], gt["id"], sb_meta)
            if sb_pred["class_name"] != "":
                sb_pred = parse_data(
                    sb_pred, class_names, classes_to_ignore, True
                )
                sb_pred["class_probs"][
                    sb_pred["class_ids"],
                    np.arange(sb_pred["class_probs"].shape[1]),
                ] = sb_pred["det_probs"]
                preds_sb.append(sb_pred)
                keep_inds_sb.append(ii)

        results_sb = evl.evaluate_predictions(
            [gt_test[ii] for ii in keep_inds_sb],
            preds_sb,
            class_names,
            params_eval["detection_overlap"],
            params_eval["ignore_start_end"],
        )
        print_results(
            "Sonobat",
            "sb",
            results_sb,
            args["op_dir"],
            class_names,
            args["file_type"],
            args["title_text"] + " - Species - ",
        )
        print(
            "Only reporting results for",
            len(keep_inds_sb),
            "files, out of",
            len(gt_test),
        )

        # train our own random forest on sonobat features
        x_train = []
        y_train = []
        for gt in gt_train:
            pred = load_sonobat_preds(
                gt["dataset_name"], gt["id"], sb_meta, "Not Bat"
            )

            if len(pred["annotation"]) > 0:
                # compute detection overlap with ground truth to determine which are the TP detections
                assign_to_gt(gt, pred, args["iou_thresh"])
                pred = parse_data(pred, class_names, classes_to_ignore, True)
                x_train.append(pred["feats"])
                y_train.append(pred["class_ids"])

        # train random forest on sonobat predictions
        clf_sb, un_train_class = train_rf_model(
            x_train, y_train, num_classes, args["rand_seed"]
        )

        # run the model on the test set
        preds_sb_rf = []
        for gt in gt_test:
            pred = load_sonobat_preds(
                gt["dataset_name"], gt["id"], sb_meta, "Not Bat"
            )
            pred = parse_data(pred, class_names, classes_to_ignore, True)
            pred = eval_rf_model(clf_sb, pred, un_train_class, num_classes)
            preds_sb_rf.append(pred)

        results_sb_rf = evl.evaluate_predictions(
            gt_test,
            preds_sb_rf,
            class_names,
            params_eval["detection_overlap"],
            params_eval["ignore_start_end"],
        )
        print_results(
            "Sonobat RF",
            "sb_rf",
            results_sb_rf,
            args["op_dir"],
            class_names,
            args["file_type"],
            args["title_text"] + " - Species - ",
        )
        print(
            "\n\nWARNING\nThis is evaluating on the full test set, but there are only detections for a subset of files\n\n"
        )

    #
    # evaluate Tadarida-D by training random forest classifier
    #
    if args["td_ip_dir"] != "":
        x_train = []
        y_train = []
        for gt in gt_train:
            pred = load_tadarida_pred(
                args["td_ip_dir"], gt["dataset_name"], gt["id"]
            )
            # compute detection overlap with ground truth to determine which are the TP detections
            assign_to_gt(gt, pred, args["iou_thresh"])
            pred = parse_data(pred, class_names, classes_to_ignore, True)
            x_train.append(pred["feats"])
            y_train.append(pred["class_ids"])

        # train random forest on Tadarida-D predictions
        clf_td, un_train_class = train_rf_model(
            x_train, y_train, num_classes, args["rand_seed"]
        )

        # run the model on the test set
        preds_td = []
        for gt in gt_test:
            pred = load_tadarida_pred(
                args["td_ip_dir"], gt["dataset_name"], gt["id"]
            )
            pred = parse_data(pred, class_names, classes_to_ignore, True)
            pred = eval_rf_model(clf_td, pred, un_train_class, num_classes)
            preds_td.append(pred)

        results_td = evl.evaluate_predictions(
            gt_test,
            preds_td,
            class_names,
            params_eval["detection_overlap"],
            params_eval["ignore_start_end"],
        )
        print_results(
            "Tadarida",
            "td_rf",
            results_td,
            args["op_dir"],
            class_names,
            args["file_type"],
            args["title_text"] + " - Species - ",
        )

    #
    # evaluate BatDetect
    #
    if args["bd_model_path"] != "":
        # load model
        bd_args = du.get_default_run_config()
        model, params_bd = du.load_model(args["bd_model_path"])

        # check if the class names are the same
        if params_bd["class_names"] != class_names:
            print("Warning: Class names are not the same as the trained model")
            assert False

        run_config = {
            **bd_args,
            **params_bd,
            "return_raw_preds": True,
        }

        preds_bd = []
        for ii, gg in enumerate(gt_test):
            pred = du.process_file(
                gg["file_path"],
                model,
                run_config,
            )
            preds_bd.append(pred)

        results_bd = evl.evaluate_predictions(
            gt_test,
            preds_bd,
            class_names,
            params_eval["detection_overlap"],
            params_eval["ignore_start_end"],
        )
        print_results(
            "BatDetect",
            "bd",
            results_bd,
            args["op_dir"],
            class_names,
            args["file_type"],
            args["title_text"] + " - Species - ",
        )

        # evaluate genus level
        class_names_genus, preds_bd_g, gt_test_g = create_genus_mapping(
            gt_test, preds_bd, class_names
        )
        results_bd_genus = evl.evaluate_predictions(
            gt_test_g,
            preds_bd_g,
            class_names_genus,
            params_eval["detection_overlap"],
            params_eval["ignore_start_end"],
        )
        print_results(
            "BatDetect Genus",
            "bd_genus",
            results_bd_genus,
            args["op_dir"],
            class_names_genus,
            args["file_type"],
            args["title_text"] + " - Genus - ",
        )
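# A minimal sketch (an assumption, for illustration only) of what a
# train_rf_model helper matching the call signature used above could look
# like; the actual definition appears earlier in this file, outside the hunks
# shown here. It fits a scikit-learn random forest on the stacked per-file
# feature rows and also reports which class ids never appeared in training.
from sklearn.ensemble import RandomForestClassifier


def train_rf_model_sketch(x_train, y_train, num_classes, rand_seed):
    # stack per-file feature matrices and label vectors into flat arrays
    x = np.vstack(x_train)
    y = np.hstack(y_train)
    clf = RandomForestClassifier(random_state=rand_seed, n_jobs=-1)
    clf.fit(x, y)
    # class ids never seen in training - handled specially at evaluation time
    un_train_class = np.setdiff1d(np.arange(num_classes), np.unique(y))
    return clf, un_train_class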
0
bat_detect/finetune/__init__.py
Normal file
@ -1,183 +1,321 @@
import argparse
import glob
import json
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

sys.path.append(os.path.join("..", ".."))
import bat_detect.detector.models as models
import bat_detect.detector.parameters as parameters
import bat_detect.detector.post_process as pp
import bat_detect.train.audio_dataloader as adl
import bat_detect.train.evaluate as evl
import bat_detect.train.losses as losses
import bat_detect.train.train_model as tm
import bat_detect.train.train_utils as tu
import bat_detect.utils.detector_utils as du
import bat_detect.utils.plot_utils as pu

if __name__ == "__main__":
    info_str = "\nBatDetect - Finetune Model\n"

    print(info_str)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "audio_path", type=str, help="Input directory for audio"
    )
    parser.add_argument(
        "train_ann_path",
        type=str,
        help="Path to where train annotation file is stored",
    )
    parser.add_argument(
        "test_ann_path",
        type=str,
        help="Path to where test annotation file is stored",
    )
    parser.add_argument("model_path", type=str, help="Path to pretrained model")
    parser.add_argument(
        "--op_model_name",
        type=str,
        default="",
        help="Path and name for finetuned model",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=200,
        dest="num_epochs",
        help="Number of finetuning epochs",
    )
    parser.add_argument(
        "--finetune_only_last_layer",
        action="store_true",
        help="Only train final layers",
    )
    parser.add_argument(
        "--train_from_scratch",
        action="store_true",
        help="Do not use pretrained weights",
    )
    parser.add_argument(
        "--do_not_save_images",
        action="store_false",
        help="Do not save images at the end of training",
    )
    parser.add_argument(
        "--notes", type=str, default="", help="Notes to save in text file"
    )
    args = vars(parser.parse_args())

    params = parameters.get_params(True, "../../experiments/")
    if torch.cuda.is_available():
        params["device"] = "cuda"
    else:
        params["device"] = "cpu"
        print(
            "\nNote, this will be a lot faster if you use a computer with a GPU.\n"
        )

    print("\nAudio directory: " + args["audio_path"])
    print("Train file: " + args["train_ann_path"])
    print("Test file: " + args["test_ann_path"])
    print("Loading model: " + args["model_path"])

    dataset_name = (
        os.path.basename(args["train_ann_path"])
        .replace(".json", "")
        .replace("_TRAIN", "")
    )

    if args["train_from_scratch"]:
        print("\nTraining model from scratch i.e. not using pretrained weights")
        model, params_train = du.load_model(args["model_path"], False)
    else:
        model, params_train = du.load_model(args["model_path"], True)
    model.to(params["device"])

    params["num_epochs"] = args["num_epochs"]
    if args["op_model_name"] != "":
        params["model_file_name"] = args["op_model_name"]
    classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]

    # save notes file
    params["notes"] = args["notes"]
    if args["notes"] != "":
        tu.write_notes_file(params["experiment"] + "notes.txt", args["notes"])

    # load train annotations
    train_sets = []
    train_sets.append(
        tu.get_blank_dataset_dict(
            dataset_name, False, args["train_ann_path"], args["audio_path"]
        )
    )
    params["train_sets"] = [
        tu.get_blank_dataset_dict(
            dataset_name,
            False,
            os.path.basename(args["train_ann_path"]),
            args["audio_path"],
        )
    ]

    print("\nTrain set:")
    (
        data_train,
        params["class_names"],
        params["class_inv_freq"],
    ) = tu.load_set_of_anns(
        train_sets, classes_to_ignore, params["events_of_interest"]
    )
    print("Number of files", len(data_train))

    params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
        params["class_names"]
    )
    params["class_names_short"] = tu.get_short_class_names(
        params["class_names"]
    )

    # load test annotations
    test_sets = []
    test_sets.append(
        tu.get_blank_dataset_dict(
            dataset_name, True, args["test_ann_path"], args["audio_path"]
        )
    )
    params["test_sets"] = [
        tu.get_blank_dataset_dict(
            dataset_name,
            True,
            os.path.basename(args["test_ann_path"]),
            args["audio_path"],
        )
    ]

    print("\nTest set:")
    data_test, _, _ = tu.load_set_of_anns(
        test_sets, classes_to_ignore, params["events_of_interest"]
    )
    print("Number of files", len(data_test))

    # train loader
    train_dataset = adl.AudioLoader(data_train, params, is_train=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=params["batch_size"],
        shuffle=True,
        num_workers=params["num_workers"],
        pin_memory=True,
    )

    # test loader - batch size of one because of variable file length
    test_dataset = adl.AudioLoader(data_test, params, is_train=False)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=params["num_workers"],
        pin_memory=True,
    )

    inputs_train = next(iter(train_loader))
    params["ip_height"] = inputs_train["spec"].shape[2]
    print("\ntrain batch size :", inputs_train["spec"].shape)

    assert params_train["model_name"] == "Net2DFast"
    print(
        "\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n"
    )

    # set the number of output classes
    num_filts = model.conv_classes_op.in_channels
    k_size = model.conv_classes_op.kernel_size
    pad = model.conv_classes_op.padding
    model.conv_classes_op = torch.nn.Conv2d(
        num_filts,
        len(params["class_names"]) + 1,
        kernel_size=k_size,
        padding=pad,
    )
    model.conv_classes_op.to(params["device"])

    if args["finetune_only_last_layer"]:
        print("\nOnly finetuning the final layers.\n")
        train_layers_i = [
            "conv_classes",
            "conv_classes_op",
            "conv_size",
            "conv_size_op",
        ]
        train_layers = [tt + ".weight" for tt in train_layers_i] + [
            tt + ".bias" for tt in train_layers_i
        ]
        for name, param in model.named_parameters():
            if name in train_layers:
                param.requires_grad = True
            else:
                param.requires_grad = False
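    # Illustrative sanity check (not part of the original script): count how
    # many parameter tensors remain trainable after freezing the backbone.
    num_trainable = sum(int(p.requires_grad) for p in model.parameters())
    print("trainable parameter tensors:", num_trainable)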
    optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
    scheduler = CosineAnnealingLR(
        optimizer, params["num_epochs"] * len(train_loader)
    )
    if params["train_loss"] == "mse":
        det_criterion = losses.mse_loss
    elif params["train_loss"] == "focal":
        det_criterion = losses.focal_loss

    # plotting
    train_plt_ls = pu.LossPlotter(
        params["experiment"] + "train_loss.png",
        params["num_epochs"] + 1,
        ["train_loss"],
        None,
        None,
        ["epoch", "train_loss"],
        logy=True,
    )
    test_plt_ls = pu.LossPlotter(
        params["experiment"] + "test_loss.png",
        params["num_epochs"] + 1,
        ["test_loss"],
        None,
        None,
        ["epoch", "test_loss"],
        logy=True,
    )
    test_plt = pu.LossPlotter(
        params["experiment"] + "test.png",
        params["num_epochs"] + 1,
        ["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"],
        [0, 1],
        None,
        ["epoch", ""],
    )
    test_plt_class = pu.LossPlotter(
        params["experiment"] + "test_avg_prec.png",
        params["num_epochs"] + 1,
        params["class_names_short"],
        [0, 1],
        params["class_names_short"],
        ["epoch", "avg_prec"],
    )

    # main train loop
    for epoch in range(0, params["num_epochs"] + 1):
        train_loss = tm.train(
            model,
            epoch,
            train_loader,
            det_criterion,
            optimizer,
            scheduler,
            params,
        )
        train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]])

        if epoch % params["num_eval_epochs"] == 0:
            # detection accuracy on test set
            test_res, test_loss = tm.test(
                model, epoch, test_loader, det_criterion, params
            )
            test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]])
            test_plt.update_and_save(
                epoch,
                [
                    test_res["avg_prec"],
                    test_res["rec_at_x"],
                    test_res["avg_prec_class"],
                    test_res["file_acc"],
                    test_res["top_class"]["avg_prec"],
                ],
            )
            test_plt_class.update_and_save(
                epoch, [rs["avg_prec"] for rs in test_res["class_pr"]]
            )
            pu.plot_pr_curve_class(
                params["experiment"], "test_pr", "test_pr", test_res
            )

    # save finetuned model
    print("saving model to: " + params["model_file_name"])
    op_state = {
        "epoch": epoch + 1,
        "state_dict": model.state_dict(),
        "params": params,
    }
    torch.save(op_state, params["model_file_name"])

    # save an image with associated prediction for each batch in the test set
    if not args["do_not_save_images"]:
        tm.save_images_batch(model, test_loader, params)
@ -1,32 +1,33 @@
import argparse
import json
import os
import sys

import numpy as np

sys.path.append(os.path.join("..", ".."))
import bat_detect.train.train_utils as tu


def print_dataset_stats(data, split_name, classes_to_ignore):
    print("\nSplit:", split_name)
    print("Num files:", len(data))

    class_cnts = {}
    for dd in data:
        for aa in dd["annotation"]:
            if aa["class"] not in classes_to_ignore:
                if aa["class"] in class_cnts:
                    class_cnts[aa["class"]] += 1
                else:
                    class_cnts[aa["class"]] = 1

    if len(class_cnts) == 0:
        class_names = []
    else:
        class_names = np.sort([*class_cnts]).tolist()
    print("Class count:")
    str_len = np.max([len(cc) for cc in class_names]) + 5

    for ii, cc in enumerate(class_names):
@ -41,111 +42,165 @@ def load_file_names(file_name):
        with open(file_name) as da:
            files = [line.rstrip() for line in da.readlines()]
        for ff in files:
            if ff.lower()[-3:] != "wav":
                print("Error: Filenames need to end in .wav - ", ff)
                assert False
    else:
        print("Error: Input file not found - ", file_name)
        assert False

    return files


if __name__ == "__main__":
    info_str = "\nBatDetect - Prepare Data for Finetuning\n"

    print(info_str)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "dataset_name", type=str, help="Name to call your dataset"
    )
    parser.add_argument("audio_dir", type=str, help="Input directory for audio")
    parser.add_argument(
        "ann_dir",
        type=str,
        help="Input directory for where the audio annotations are stored",
    )
    parser.add_argument(
        "op_dir",
        type=str,
        help="Path where the train and test splits will be stored",
    )
    parser.add_argument(
        "--percent_val",
        type=float,
        default=0.20,
        help="Hold out this much data for validation. Should be a number between 0 and 1",
    )
    parser.add_argument(
        "--rand_seed",
        type=int,
        default=2001,
        help="Random seed used for creating the validation split",
    )
    parser.add_argument(
        "--train_file",
        type=str,
        default="",
        help="Text file where each line is a wav file in train split",
    )
    parser.add_argument(
        "--test_file",
        type=str,
        default="",
        help="Text file where each line is a wav file in test split",
    )
    parser.add_argument(
        "--input_class_names",
        type=str,
        default="",
        help='Specify names of classes that you want to change. Separate with ";"',
    )
    parser.add_argument(
        "--output_class_names",
        type=str,
        default="",
        help='New class names to use instead. One to one mapping with "--input_class_names". \
        Separate with ";"',
    )
    args = vars(parser.parse_args())

    np.random.seed(args["rand_seed"])

    classes_to_ignore = ["", " ", "Unknown", "Not Bat"]
    generic_class = ["Bat"]
    events_of_interest = ["Echolocation"]

    if args["input_class_names"] != "" and args["output_class_names"] != "":
        # change the names of the classes
        ip_names = args["input_class_names"].split(";")
        op_names = args["output_class_names"].split(";")
        name_dict = dict(zip(ip_names, op_names))
    else:
        name_dict = False

    # load annotations
    data_all, _, _ = tu.load_set_of_anns(
        {"ann_path": args["ann_dir"], "wav_path": args["audio_dir"]},
        classes_to_ignore,
        events_of_interest,
        False,
        False,
        list_of_anns=True,
        filter_issues=True,
        name_replace=name_dict,
    )

    print("Dataset name: " + args["dataset_name"])
    print("Audio directory: " + args["audio_dir"])
    print("Annotation directory: " + args["ann_dir"])
    print("Output directory: " + args["op_dir"])
    print("Num annotated files: " + str(len(data_all)))

    if args["train_file"] != "" and args["test_file"] != "":
        # user has specified the train / test split
        train_files = load_file_names(args["train_file"])
        test_files = load_file_names(args["test_file"])
        file_names_all = [dd["id"] for dd in data_all]
        train_inds = [
            file_names_all.index(ff)
            for ff in train_files
            if ff in file_names_all
        ]
        test_inds = [
            file_names_all.index(ff)
            for ff in test_files
            if ff in file_names_all
        ]

    else:
        # split the data into train and test at the file level
        num_exs = len(data_all)
        test_inds = np.random.choice(
            np.arange(num_exs),
            int(num_exs * args["percent_val"]),
            replace=False,
        )
        test_inds = np.sort(test_inds)
        train_inds = np.setdiff1d(np.arange(num_exs), test_inds)

    data_train = [data_all[ii] for ii in train_inds]
    data_test = [data_all[ii] for ii in test_inds]

    if not os.path.isdir(args["op_dir"]):
        os.makedirs(args["op_dir"])
    op_name = os.path.join(args["op_dir"], args["dataset_name"])
    op_name_train = op_name + "_TRAIN.json"
    op_name_test = op_name + "_TEST.json"

    class_un_train = print_dataset_stats(data_train, "Train", classes_to_ignore)
    class_un_test = print_dataset_stats(data_test, "Test", classes_to_ignore)

    if len(data_train) > 0 and len(data_test) > 0:
        if class_un_train != class_un_test:
            print(
                '\nError: some classes are not in both the training and test sets.\
                \nTry a different random seed "--rand_seed".'
            )
            assert False

    print("\n")
    if len(data_train) == 0:
        print("No train annotations to save")
    else:
        print("Saving: ", op_name_train)
        with open(op_name_train, "w") as da:
            json.dump(data_train, da, indent=2)

    if len(data_test) == 0:
        print("No test annotations to save")
    else:
        print("Saving: ", op_name_test)
        with open(op_name_test, "w") as da:
            json.dump(data_test, da, indent=2)
@ -1,71 +1,144 @@
import copy
import random
from typing import Tuple

import librosa
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio

import bat_detect.utils.audio_utils as au
from bat_detect.types import AnnotationGroup, HeatmapParameters


def generate_gt_heatmaps(
    spec_op_shape: Tuple[int, int],
    sampling_rate: int,
    ann: AnnotationGroup,
    params: HeatmapParameters,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AnnotationGroup]:
    """Generate ground truth heatmaps from annotations.

    Parameters
    ----------
    spec_op_shape : Tuple[int, int]
        Shape of the input spectrogram.
    sampling_rate : int
        Sampling rate of the input audio in Hz.
    ann : AnnotationGroup
        Dictionary containing the annotation information.
    params : HeatmapParameters
        Parameters controlling the generation of the heatmaps.

    Returns
    -------
    y_2d_det : np.ndarray
        2D heatmap of the presence of an event.
    y_2d_size : np.ndarray
        2D heatmap of the size of the bounding box associated to the event.
    y_2d_classes : np.ndarray
        3D array containing the ground-truth class probabilities for each
        pixel.
    ann_aug : AnnotationGroup
        A dictionary containing the annotation information of the
        annotations that are within the input spectrogram, augmented with
        the x and y indices of their pixel location in the input spectrogram.
    """
    # spec may be resized on input into the network
    num_classes = len(params["class_names"])
    op_height = spec_op_shape[0]
    op_width = spec_op_shape[1]
    freq_per_bin = (params["max_freq"] - params["min_freq"]) / op_height

    # start and end times
    x_pos_start = au.time_to_x_coords(
        ann["start_times"],
        sampling_rate,
        params["fft_win_length"],
        params["fft_overlap"],
    )
    x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int)
    x_pos_end = au.time_to_x_coords(
        ann["end_times"],
        sampling_rate,
        params["fft_win_length"],
        params["fft_overlap"],
    )
    x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int)

    # location on y axis i.e. frequency
    y_pos_low = (ann["low_freqs"] - params["min_freq"]) / freq_per_bin
    y_pos_low = (op_height - y_pos_low).astype(np.int)
    y_pos_high = (ann["high_freqs"] - params["min_freq"]) / freq_per_bin
    y_pos_high = (op_height - y_pos_high).astype(np.int)
    bb_widths = x_pos_end - x_pos_start
    bb_heights = y_pos_low - y_pos_high

    # Only include annotations that are within the input spectrogram
    valid_inds = np.where(
        (x_pos_start >= 0)
        & (x_pos_start < op_width)
        & (y_pos_low >= 0)
        & (y_pos_low < (op_height - 1))
    )[0]

    ann_aug: AnnotationGroup = {
        "start_times": ann["start_times"][valid_inds],
        "end_times": ann["end_times"][valid_inds],
        "high_freqs": ann["high_freqs"][valid_inds],
        "low_freqs": ann["low_freqs"][valid_inds],
        "class_ids": ann["class_ids"][valid_inds],
        "individual_ids": ann["individual_ids"][valid_inds],
    }
    ann_aug["x_inds"] = x_pos_start[valid_inds]
    ann_aug["y_inds"] = y_pos_low[valid_inds]

    # if the number of calls is only 1, then it is unique
    # TODO would be better if we found these unique calls at the merging stage
    if len(ann_aug["individual_ids"]) == 1:
        ann_aug["individual_ids"][0] = 0

    y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
    y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)
    # num classes and "background" class
    y_2d_classes: np.ndarray = np.zeros(
        (num_classes + 1, op_height, op_width), dtype=np.float32
    )

    # create 2D ground truth heatmaps
    for ii in valid_inds:
        draw_gaussian(
            y_2d_det[0, :],
            (x_pos_start[ii], y_pos_low[ii]),
            params["target_sigma"],
        )
        y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
        y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]

        cls_id = ann["class_ids"][ii]
        if cls_id > -1:
            draw_gaussian(
                y_2d_classes[cls_id, :],
                (x_pos_start[ii], y_pos_low[ii]),
                params["target_sigma"],
            )

    # be careful as this will have a 1.0 in places where we have an event but
    # don't know the gt class - this will be masked in training anyway
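# Illustrative call of generate_gt_heatmaps (all values here are hypothetical;
# real params come from bat_detect.detector.parameters.get_params). The
# returned heatmaps share the spatial shape of the resized spectrogram:
#
#   y_det, y_size, y_class, ann_aug = generate_gt_heatmaps(
#       spec_op_shape=(128, 512),
#       sampling_rate=256000,
#       ann=ann,
#       params=params,
#   )
#
# y_det: (1, 128, 512), y_size: (2, 128, 512),
# y_class: (num_classes + 1, 128, 512)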
@ -96,20 +169,24 @@ def draw_gaussian(heatmap, center, sigmax, sigmay=None):
    x = np.arange(0, size, 1, np.float32)
    y = x[:, np.newaxis]
    x0 = y0 = size // 2
    # g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
    g = np.exp(
        -((x - x0) ** 2) / (2 * sigmax**2)
        - ((y - y0) ** 2) / (2 * sigmay**2)
    )
    g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
    g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
    img_x = max(0, ul[0]), min(br[0], h)
    img_y = max(0, ul[1]), min(br[1], w)
    heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]] = np.maximum(
        heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]],
        g[g_y[0] : g_y[1], g_x[0] : g_x[1]],
    )
    return True
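# Standalone illustration (not part of the original module) of the kernel
# drawn above: an unnormalised 2D Gaussian of the form
# g(x, y) = exp(-(x - x0)^2 / (2 * sx^2) - (y - y0)^2 / (2 * sy^2)),
# whose peak of exactly 1.0 is max-blended into the heatmap so that
# overlapping calls do not sum above 1.0. The sizes here are made up.
def _gaussian_kernel_demo(sx=2.0, sy=2.0, size=15):
    xs = np.arange(0, size, 1, np.float32)
    ys = xs[:, np.newaxis]
    x0 = y0 = size // 2
    g = np.exp(-((xs - x0) ** 2) / (2 * sx**2) - ((ys - y0) ** 2) / (2 * sy**2))
    assert g.max() == 1.0  # peak sits at the centre pixel
    return g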
def pad_aray(ip_array, pad_size):
    return np.hstack((ip_array, np.ones(pad_size, dtype=np.int) * -1))


def warp_spec_aug(spec, ann, return_spec_for_viz, params):
@ -121,24 +198,37 @@ def warp_spec_aug(spec, ann, return_spec_for_viz, params):
    if return_spec_for_viz:
        assert False

    delta = params["stretch_squeeze_delta"]
    op_size = (spec.shape[1], spec.shape[2])
    resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0
    resize_amt = int(spec.shape[2] * resize_fract_r)
    if resize_amt >= spec.shape[2]:
        spec_r = torch.cat(
            (
                spec,
                torch.zeros(
                    (1, spec.shape[1], resize_amt - spec.shape[2]),
                    dtype=spec.dtype,
                ),
            ),
            2,
        )
    else:
        spec_r = spec[:, :, :resize_amt]
    spec = F.interpolate(
        spec_r.unsqueeze(0), size=op_size, mode="bilinear", align_corners=False
    ).squeeze(0)
    ann["start_times"] *= 1.0 / resize_fract_r
    ann["end_times"] *= 1.0 / resize_fract_r
    return spec


def mask_time_aug(spec, params):
    # Mask out a random block of time - repeat up to 3 times
    # SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
    fm = torchaudio.transforms.TimeMasking(
        int(spec.shape[1] * params["mask_max_time_perc"])
    )
    for ii in range(np.random.randint(1, 4)):
        spec = fm(spec)
    return spec
@ -147,40 +237,65 @@ def mask_time_aug(spec, params):
def mask_freq_aug(spec, params):
    # Mask out a random frequency range - repeat up to 3 times
    # SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
    fm = torchaudio.transforms.FrequencyMasking(
        int(spec.shape[1] * params["mask_max_freq_perc"])
    )
    for ii in range(np.random.randint(1, 4)):
        spec = fm(spec)
    return spec
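# Illustrative round-trip (shapes and percentages are assumptions) through
# the two SpecAugment-style maskers defined above, applied to a dummy
# (channel, freq, time) spectrogram tensor.
def _spec_augment_demo():
    demo_spec = torch.rand(1, 128, 512)
    demo_params = {"mask_max_time_perc": 0.05, "mask_max_freq_perc": 0.10}
    demo_spec = mask_time_aug(demo_spec, demo_params)
    demo_spec = mask_freq_aug(demo_spec, demo_params)
    return demo_spec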
def scale_vol_aug(spec, params):
    return spec * np.random.random() * params["spec_amp_scaling"]


def echo_aug(audio, sampling_rate, params):
    sample_offset = (
        int(params["echo_max_delay"] * np.random.random() * sampling_rate) + 1
    )
    audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
    return audio


def resample_aug(audio, sampling_rate, params):
    sampling_rate_old = sampling_rate
    sampling_rate = np.random.choice(params["aug_sampling_rates"])
    audio = librosa.resample(
        audio,
        orig_sr=sampling_rate_old,
        target_sr=sampling_rate,
        res_type="polyphase",
    )

    audio = au.pad_audio(
        audio,
        sampling_rate,
        params["fft_win_length"],
        params["fft_overlap"],
        params["resize_factor"],
        params["spec_divide_factor"],
        params["spec_train_width"],
    )
    duration = audio.shape[0] / float(sampling_rate)
    return audio, sampling_rate, duration


def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
    if sampling_rate != sampling_rate2:
        audio2 = librosa.resample(
            audio2,
            orig_sr=sampling_rate2,
            target_sr=sampling_rate,
            res_type="polyphase",
        )
        sampling_rate2 = sampling_rate
    if audio2.shape[0] < num_samples:
        audio2 = np.hstack(
            (
                audio2,
                np.zeros((num_samples - audio2.shape[0]), dtype=audio2.dtype),
            )
        )
    elif audio2.shape[0] > num_samples:
        audio2 = audio2[:num_samples]
    return audio2, sampling_rate2
@@ -189,26 +304,32 @@ def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2):

    # resample so they are the same
    audio2, sampling_rate2 = resample_audio(audio.shape[0], sampling_rate, audio2, sampling_rate2)
    audio2, sampling_rate2 = resample_audio(
        audio.shape[0], sampling_rate, audio2, sampling_rate2
    )

    # # set mean and std to be the same
    # audio2 = (audio2 - audio2.mean())
    # audio2 = (audio2/audio2.std())*audio.std()
    # audio2 = audio2 + audio.mean()

    if ann['annotated'] and (ann2['annotated']) and \
       (sampling_rate2 == sampling_rate) and (audio.shape[0] == audio2.shape[0]):
        comb_weight = 0.3 + np.random.random()*0.4
        audio = comb_weight*audio + (1-comb_weight)*audio2
        inds = np.argsort(np.hstack((ann['start_times'], ann2['start_times'])))
    if (
        ann["annotated"]
        and (ann2["annotated"])
        and (sampling_rate2 == sampling_rate)
        and (audio.shape[0] == audio2.shape[0])
    ):
        comb_weight = 0.3 + np.random.random() * 0.4
        audio = comb_weight * audio + (1 - comb_weight) * audio2
        inds = np.argsort(np.hstack((ann["start_times"], ann2["start_times"])))
        for kk in ann.keys():

            # when combining calls from different files, assume they come from different individuals
            if kk == 'individual_ids':
                if (ann[kk]>-1).sum() > 0:
                    ann2[kk][ann2[kk]>-1] += np.max(ann[kk][ann[kk]>-1]) + 1
            if kk == "individual_ids":
                if (ann[kk] > -1).sum() > 0:
                    ann2[kk][ann2[kk] > -1] += np.max(ann[kk][ann[kk] > -1]) + 1

            if (kk != 'class_id_file') and (kk != 'annotated'):
            if (kk != "class_id_file") and (kk != "annotated"):
                ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds]

    return audio, ann
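# combine_audio_aug mixes two files with a random convex weight in [0.3, 0.7]
# and merges their annotations sorted by start time. A sketch of just the
# waveform mixing, under those assumptions:
#
#   import numpy as np
#
#   audio_a = np.random.randn(1000).astype(np.float32)
#   audio_b = np.random.randn(1000).astype(np.float32)
#   comb_weight = 0.3 + np.random.random() * 0.4  # uniform in [0.3, 0.7]
#   mixed = comb_weight * audio_a + (1 - comb_weight) * audio_b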
@@ -227,53 +348,70 @@ class AudioLoader(torch.utils.data.Dataset):

            # filter out unused annotation here
            filtered_annotations = []
            for ii, aa in enumerate(dd['annotation']):
            for ii, aa in enumerate(dd["annotation"]):

                if 'individual' in aa.keys():
                    aa['individual'] = int(aa['individual'])
                if "individual" in aa.keys():
                    aa["individual"] = int(aa["individual"])

                # if only one call labeled it has to be from the same individual
                if len(dd['annotation']) == 1:
                    aa['individual'] = 0
                if len(dd["annotation"]) == 1:
                    aa["individual"] = 0

                # convert class name into class label
                if aa['class'] in self.params['class_names']:
                    aa['class_id'] = self.params['class_names'].index(aa['class'])
                if aa["class"] in self.params["class_names"]:
                    aa["class_id"] = self.params["class_names"].index(
                        aa["class"]
                    )
                else:
                    aa['class_id'] = -1
                    aa["class_id"] = -1

                if aa['class'] not in self.params['classes_to_ignore']:
                if aa["class"] not in self.params["classes_to_ignore"]:
                    filtered_annotations.append(aa)

            dd['annotation'] = filtered_annotations
            dd['start_times'] = np.array([aa['start_time'] for aa in dd['annotation']])
            dd['end_times'] = np.array([aa['end_time'] for aa in dd['annotation']])
            dd['high_freqs'] = np.array([float(aa['high_freq']) for aa in dd['annotation']])
            dd['low_freqs'] = np.array([float(aa['low_freq']) for aa in dd['annotation']])
            dd['class_ids'] = np.array([aa['class_id'] for aa in dd['annotation']]).astype(np.int)
            dd['individual_ids'] = np.array([aa['individual'] for aa in dd['annotation']]).astype(np.int)
            dd["annotation"] = filtered_annotations
            dd["start_times"] = np.array(
                [aa["start_time"] for aa in dd["annotation"]]
            )
            dd["end_times"] = np.array(
                [aa["end_time"] for aa in dd["annotation"]]
            )
            dd["high_freqs"] = np.array(
                [float(aa["high_freq"]) for aa in dd["annotation"]]
            )
            dd["low_freqs"] = np.array(
                [float(aa["low_freq"]) for aa in dd["annotation"]]
            )
            dd["class_ids"] = np.array(
                [aa["class_id"] for aa in dd["annotation"]]
            ).astype(np.int)
            dd["individual_ids"] = np.array(
                [aa["individual"] for aa in dd["annotation"]]
            ).astype(np.int)

            # file level class name
            dd['class_id_file'] = -1
            if 'class_name' in dd.keys():
                if dd['class_name'] in self.params['class_names']:
                    dd['class_id_file'] = self.params['class_names'].index(dd['class_name'])
            dd["class_id_file"] = -1
            if "class_name" in dd.keys():
                if dd["class_name"] in self.params["class_names"]:
                    dd["class_id_file"] = self.params["class_names"].index(
                        dd["class_name"]
                    )

            self.data_anns.append(dd)

        ann_cnt = [len(aa['annotation']) for aa in self.data_anns]
        self.max_num_anns = 2*np.max(ann_cnt)  # x2 because we may be combining files during training
        ann_cnt = [len(aa["annotation"]) for aa in self.data_anns]
        self.max_num_anns = 2 * np.max(
            ann_cnt
        )  # x2 because we may be combining files during training

        print('\n')
        print("\n")
        if dataset_name is not None:
            print('Dataset : ' + dataset_name)
            print("Dataset : " + dataset_name)
        if self.is_train:
            print('Split type : train')
            print("Split type : train")
        else:
            print('Split type : test')
        print('Num files : ' + str(len(self.data_anns)))
        print('Num calls : ' + str(np.sum(ann_cnt)))

            print("Split type : test")
        print("Num files : " + str(len(self.data_anns)))
        print("Num calls : " + str(np.sum(ann_cnt)))

    def get_file_and_anns(self, index=None):

@@ -281,110 +419,169 @@ class AudioLoader(torch.utils.data.Dataset):
        if index == None:
            index = np.random.randint(0, len(self.data_anns))

        audio_file = self.data_anns[index]['file_path']
        sampling_rate, audio_raw = au.load_audio_file(audio_file, self.data_anns[index]['time_exp'],
                                                      self.params['target_samp_rate'], self.params['scale_raw_audio'])
        audio_file = self.data_anns[index]["file_path"]
        sampling_rate, audio_raw = au.load_audio(
            audio_file,
            self.data_anns[index]["time_exp"],
            self.params["target_samp_rate"],
            self.params["scale_raw_audio"],
        )

        # copy annotation
        ann = {}
        ann['annotated'] = self.data_anns[index]['annotated']
        ann['class_id_file'] = self.data_anns[index]['class_id_file']
        keys = ['start_times', 'end_times', 'high_freqs', 'low_freqs', 'class_ids', 'individual_ids']
        ann["annotated"] = self.data_anns[index]["annotated"]
        ann["class_id_file"] = self.data_anns[index]["class_id_file"]
        keys = [
            "start_times",
            "end_times",
            "high_freqs",
            "low_freqs",
            "class_ids",
            "individual_ids",
        ]
        for kk in keys:
            ann[kk] = self.data_anns[index][kk].copy()

        # if train then grab a random crop
        if self.is_train:
            nfft = int(self.params['fft_win_length']*sampling_rate)
            noverlap = int(self.params['fft_overlap']*nfft)
            length_samples = self.params['spec_train_width']*(nfft - noverlap) + noverlap
            nfft = int(self.params["fft_win_length"] * sampling_rate)
            noverlap = int(self.params["fft_overlap"] * nfft)
            length_samples = (
                self.params["spec_train_width"] * (nfft - noverlap) + noverlap
            )

            if audio_raw.shape[0] - length_samples > 0:
                sample_crop = np.random.randint(audio_raw.shape[0] - length_samples)
                sample_crop = np.random.randint(
                    audio_raw.shape[0] - length_samples
                )
            else:
                sample_crop = 0
            audio_raw = audio_raw[sample_crop:sample_crop+length_samples]
            ann['start_times'] = ann['start_times'] - sample_crop/float(sampling_rate)
            ann['end_times'] = ann['end_times'] - sample_crop/float(sampling_rate)
            audio_raw = audio_raw[sample_crop : sample_crop + length_samples]
            ann["start_times"] = ann["start_times"] - sample_crop / float(
                sampling_rate
            )
            ann["end_times"] = ann["end_times"] - sample_crop / float(
                sampling_rate
            )

        # pad audio
        if self.is_train:
            op_spec_target_size = self.params['spec_train_width']
            op_spec_target_size = self.params["spec_train_width"]
        else:
            op_spec_target_size = None
        audio_raw = au.pad_audio(audio_raw, sampling_rate, self.params['fft_win_length'],
                                 self.params['fft_overlap'], self.params['resize_factor'],
                                 self.params['spec_divide_factor'], op_spec_target_size)
        audio_raw = au.pad_audio(
            audio_raw,
            sampling_rate,
            self.params["fft_win_length"],
            self.params["fft_overlap"],
            self.params["resize_factor"],
            self.params["spec_divide_factor"],
            op_spec_target_size,
        )
        duration = audio_raw.shape[0] / float(sampling_rate)

        # sort based on time
        inds = np.argsort(ann['start_times'])
        inds = np.argsort(ann["start_times"])
        for kk in ann.keys():
            if (kk != 'class_id_file') and (kk != 'annotated'):
            if (kk != "class_id_file") and (kk != "annotated"):
                ann[kk] = ann[kk][inds]

        return audio_raw, sampling_rate, duration, ann

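# The random training crop above is sized so that the STFT of the crop yields
# exactly `spec_train_width` frames. A worked check of that arithmetic with
# illustrative values (the parameter names come from the diff; the values here
# are assumptions, not the repository defaults):
#
#   sampling_rate = 256000
#   fft_win_length = 0.02      # seconds
#   fft_overlap = 0.75         # fraction of a window
#   spec_train_width = 512     # target number of STFT frames
#
#   nfft = int(fft_win_length * sampling_rate)   # 5120 samples per window
#   noverlap = int(fft_overlap * nfft)           # 3840 overlapping samples
#   length_samples = spec_train_width * (nfft - noverlap) + noverlap
#   # 512 hops of (nfft - noverlap) samples plus one window's worth of overlap
#   assert length_samples == 512 * 1280 + 3840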
    def __getitem__(self, index):

        # load audio file
        audio, sampling_rate, duration, ann = self.get_file_and_anns(index)

        # augment on raw audio
        if self.is_train and self.params['augment_at_train']:
        if self.is_train and self.params["augment_at_train"]:
            # augment - combine with random audio file
            if self.params['augment_at_train_combine'] and np.random.random() < self.params['aug_prob']:
                audio2, sampling_rate2, duration2, ann2 = self.get_file_and_anns()
                audio, ann = combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2)
            if (
                self.params["augment_at_train_combine"]
                and np.random.random() < self.params["aug_prob"]
            ):
                (
                    audio2,
                    sampling_rate2,
                    duration2,
                    ann2,
                ) = self.get_file_and_anns()
                audio, ann = combine_audio_aug(
                    audio, sampling_rate, ann, audio2, sampling_rate2, ann2
                )

            # simulate echo by adding delayed copy of the file
            if np.random.random() < self.params['aug_prob']:
            if np.random.random() < self.params["aug_prob"]:
                audio = echo_aug(audio, sampling_rate, self.params)

            # resample the audio
            #if np.random.random() < self.params['aug_prob']:
            # if np.random.random() < self.params['aug_prob']:
            #     audio, sampling_rate, duration = resample_aug(audio, sampling_rate, self.params)

        # create spectrogram
        spec, spec_for_viz = au.generate_spectrogram(audio, sampling_rate, self.params, self.return_spec_for_viz)
        rsf = self.params['resize_factor']
        spec_op_shape = (int(self.params['spec_height']*rsf), int(spec.shape[1]*rsf))
        spec, spec_for_viz = au.generate_spectrogram(
            audio, sampling_rate, self.params, self.return_spec_for_viz
        )
        rsf = self.params["resize_factor"]
        spec_op_shape = (
            int(self.params["spec_height"] * rsf),
            int(spec.shape[1] * rsf),
        )

        # resize the spec
        spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0)
        spec = F.interpolate(spec, size=spec_op_shape, mode='bilinear', align_corners=False).squeeze(0)
        spec = F.interpolate(
            spec, size=spec_op_shape, mode="bilinear", align_corners=False
        ).squeeze(0)

        # augment spectrogram
        if self.is_train and self.params['augment_at_train']:
        if self.is_train and self.params["augment_at_train"]:

            if np.random.random() < self.params['aug_prob']:
            if np.random.random() < self.params["aug_prob"]:
                spec = scale_vol_aug(spec, self.params)

            if np.random.random() < self.params['aug_prob']:
                spec = warp_spec_aug(spec, ann, self.return_spec_for_viz, self.params)
            if np.random.random() < self.params["aug_prob"]:
                spec = warp_spec_aug(
                    spec, ann, self.return_spec_for_viz, self.params
                )

            if np.random.random() < self.params['aug_prob']:
            if np.random.random() < self.params["aug_prob"]:
                spec = mask_time_aug(spec, self.params)

            if np.random.random() < self.params['aug_prob']:
            if np.random.random() < self.params["aug_prob"]:
                spec = mask_freq_aug(spec, self.params)

        outputs = {}
        outputs['spec'] = spec
        outputs["spec"] = spec
        if self.return_spec_for_viz:
            outputs['spec_for_viz'] = torch.from_numpy(spec_for_viz).unsqueeze(0)
            outputs["spec_for_viz"] = torch.from_numpy(spec_for_viz).unsqueeze(
                0
            )

        # create ground truth heatmaps
        outputs['y_2d_det'], outputs['y_2d_size'], outputs['y_2d_classes'], ann_aug =\
            generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, self.params)
        (
            outputs["y_2d_det"],
            outputs["y_2d_size"],
            outputs["y_2d_classes"],
            ann_aug,
        ) = generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, self.params)

        # hack to get around requirement that all vectors are the same length in
        # the output batch
        pad_size = self.max_num_anns-len(ann_aug['individual_ids'])
        outputs['is_valid'] = pad_aray(np.ones(len(ann_aug['individual_ids'])), pad_size)
        keys = ['class_ids', 'individual_ids', 'x_inds', 'y_inds',
                'start_times', 'end_times', 'low_freqs', 'high_freqs']
        pad_size = self.max_num_anns - len(ann_aug["individual_ids"])
        outputs["is_valid"] = pad_aray(
            np.ones(len(ann_aug["individual_ids"])), pad_size
        )
        keys = [
            "class_ids",
            "individual_ids",
            "x_inds",
            "y_inds",
            "start_times",
            "end_times",
            "low_freqs",
            "high_freqs",
        ]
        for kk in keys:
            outputs[kk] = pad_aray(ann_aug[kk], pad_size)

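# `pad_aray` is defined elsewhere in this file; from its use here it pads a
# 1-D array so every item in a batch carries `max_num_anns` entries, with
# `is_valid` marking the real ones. A hypothetical equivalent, assuming -1 as
# the fill value (both the name and the sentinel are assumptions):
#
#   import numpy as np
#
#   def pad_array_demo(ip_array, pad_size, fill_value=-1):
#       return np.hstack(
#           (ip_array, np.ones(pad_size, dtype=ip_array.dtype) * fill_value)
#       )
#
#   vals = np.array([3, 7])
#   padded = pad_array_demo(vals, 4)   # -> [3, 7, -1, -1, -1, -1]
#   is_valid = np.hstack((np.ones(len(vals)), np.zeros(4)))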
@@ -394,14 +591,13 @@ class AudioLoader(torch.utils.data.Dataset):
            outputs[kk] = torch.from_numpy(outputs[kk])

        # scalars
        outputs['class_id_file'] = ann['class_id_file']
        outputs['annotated'] = ann['annotated']
        outputs['duration'] = duration
        outputs['sampling_rate'] = sampling_rate
        outputs['file_id'] = index
        outputs["class_id_file"] = ann["class_id_file"]
        outputs["annotated"] = ann["annotated"]
        outputs["duration"] = duration
        outputs["sampling_rate"] = sampling_rate
        outputs["file_id"] = index

        return outputs


    def __len__(self):
        return len(self.data_anns)

@@ -1,6 +1,10 @@
import numpy as np
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from sklearn.metrics import (
    accuracy_score,
    auc,
    balanced_accuracy_score,
    roc_curve,
)


def compute_error_auc(op_str, gt, pred, prob):
@@ -13,8 +17,11 @@ def compute_error_auc(op_str, gt, pred, prob):
    fpr, tpr, thresholds = roc_curve(gt, pred)
    roc_auc = auc(fpr, tpr)

    print(op_str + ", class acc = {:.3f}, ROC AUC = {:.3f}".format(class_acc, roc_auc))
    #return class_acc, roc_auc
    print(
        op_str
        + ", class acc = {:.3f}, ROC AUC = {:.3f}".format(class_acc, roc_auc)
    )
    # return class_acc, roc_auc


def calc_average_precision(recall, precision):
@@ -25,10 +32,10 @@ def calc_average_precision(recall, precision):
    # pascal 12 way
    mprec = np.hstack((0, precision, 0))
    mrec = np.hstack((0, recall, 1))
    for ii in range(mprec.shape[0]-2, -1,-1):
        mprec[ii] = np.maximum(mprec[ii], mprec[ii+1])
    inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0]+1
    ave_prec = ((mrec[inds] - mrec[inds-1])*mprec[inds]).sum()
    for ii in range(mprec.shape[0] - 2, -1, -1):
        mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
    inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
    ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()

    return float(ave_prec)

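# calc_average_precision implements the PASCAL VOC 2012 style interpolated AP.
# A small worked example, assuming two ground truth calls where the top-ranked
# detection is correct, the second is a false positive, and the third recovers
# the remaining call:
#
#   import numpy as np
#
#   recall = np.array([0.5, 0.5, 1.0])
#   precision = np.array([1.0, 0.5, 2.0 / 3.0])
#
#   mprec = np.hstack((0, precision, 0))
#   mrec = np.hstack((0, recall, 1))
#   for ii in range(mprec.shape[0] - 2, -1, -1):
#       mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
#   inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
#   ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
#   assert np.isclose(ave_prec, 0.5 * 1.0 + 0.5 * (2.0 / 3.0))  # = 0.8333...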
@@ -37,7 +44,7 @@ def calc_recall_at_x(recall, precision, x=0.95):
    precision[np.isnan(precision)] = 0
    recall[np.isnan(recall)] = 0

    inds = np.where(precision[::-1]>x)[0]
    inds = np.where(precision[::-1] > x)[0]
    if len(inds) > 0:
        return float(recall[::-1][inds[0]])
    else:
@@ -51,7 +58,15 @@ def compute_affinity_1d(pred_box, gt_boxes, threshold):
    return valid_detection, np.argmin(score)


def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, threshold, ignore_start_end):
def compute_pre_rec(
    gts,
    preds,
    eval_mode,
    class_of_interest,
    num_classes,
    threshold,
    ignore_start_end,
):
    """
    Computes precision and recall. Assumes that each file has been exhaustively
    annotated. Will not count predicted detection with a start time that is within
@@ -78,26 +93,40 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres
    for pid, pp in enumerate(preds):

        # filter predicted calls that are too near the start or end of the file
        file_dur = gts[pid]['duration']
        valid_inds = (pp['start_times'] >= ignore_start_end) & (pp['start_times'] <= (file_dur - ignore_start_end))
        file_dur = gts[pid]["duration"]
        valid_inds = (pp["start_times"] >= ignore_start_end) & (
            pp["start_times"] <= (file_dur - ignore_start_end)
        )

        pred_boxes.append(np.vstack((pp['start_times'][valid_inds], pp['end_times'][valid_inds],
                                     pp['low_freqs'][valid_inds], pp['high_freqs'][valid_inds])).T)
        pred_boxes.append(
            np.vstack(
                (
                    pp["start_times"][valid_inds],
                    pp["end_times"][valid_inds],
                    pp["low_freqs"][valid_inds],
                    pp["high_freqs"][valid_inds],
                )
            ).T
        )

        if eval_mode == 'detection':
        if eval_mode == "detection":
            # overall detection
            confidence.append(pp['det_probs'][valid_inds])
        elif eval_mode == 'per_class':
            confidence.append(pp["det_probs"][valid_inds])
        elif eval_mode == "per_class":
            # per class
            confidence.append(pp['class_probs'].T[valid_inds, class_of_interest])
        elif eval_mode == 'top_class':
            confidence.append(
                pp["class_probs"].T[valid_inds, class_of_interest]
            )
        elif eval_mode == "top_class":
            # per class - note that sometimes 'class_probs' can be num_classes+1 in size
            top_class = np.argmax(pp['class_probs'].T[valid_inds, :num_classes], 1)
            confidence.append(pp['class_probs'].T[valid_inds, top_class])
            top_class = np.argmax(
                pp["class_probs"].T[valid_inds, :num_classes], 1
            )
            confidence.append(pp["class_probs"].T[valid_inds, top_class])
            pred_class.append(top_class)

        # be careful, assuming the order in the list is same as GT
        file_ids.append([pid]*valid_inds.sum())
        file_ids.append([pid] * valid_inds.sum())

    confidence = np.hstack(confidence)
    file_ids = np.hstack(file_ids).astype(np.int)
@@ -105,7 +134,6 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres
    if len(pred_class) > 0:
        pred_class = np.hstack(pred_class)


    # extract relevant ground truth boxes
    gt_boxes = []
    gt_assigned = []
@@ -115,32 +143,42 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres
    for gg in gts:

        # filter ground truth calls that are too near the start or end of the file
        file_dur = gg['duration']
        valid_inds = (gg['start_times'] >= ignore_start_end) & (gg['start_times'] <= (file_dur - ignore_start_end))
        file_dur = gg["duration"]
        valid_inds = (gg["start_times"] >= ignore_start_end) & (
            gg["start_times"] <= (file_dur - ignore_start_end)
        )

        # note, files with the incorrect duration will cause a problem
        if (gg['start_times'] > file_dur).sum() > 0:
            print('Error: file duration incorrect for', gg['id'])
            assert(False)
        if (gg["start_times"] > file_dur).sum() > 0:
            print("Error: file duration incorrect for", gg["id"])
            assert False

        boxes = np.vstack((gg['start_times'][valid_inds], gg['end_times'][valid_inds],
                           gg['low_freqs'][valid_inds], gg['high_freqs'][valid_inds])).T
        gen_class = gg['class_ids'][valid_inds] == -1
        class_ids = gg['class_ids'][valid_inds]
        boxes = np.vstack(
            (
                gg["start_times"][valid_inds],
                gg["end_times"][valid_inds],
                gg["low_freqs"][valid_inds],
                gg["high_freqs"][valid_inds],
            )
        ).T
        gen_class = gg["class_ids"][valid_inds] == -1
        class_ids = gg["class_ids"][valid_inds]

        # keep track of the number of relevant ground truth calls
        if eval_mode == 'detection':
        if eval_mode == "detection":
            # all valid ones
            num_positives += len(gg['start_times'][valid_inds])
        elif eval_mode == 'per_class':
            num_positives += len(gg["start_times"][valid_inds])
        elif eval_mode == "per_class":
            # all valid ones with class of interest
            num_positives += (gg['class_ids'][valid_inds] == class_of_interest).sum()
        elif eval_mode == 'top_class':
            num_positives += (
                gg["class_ids"][valid_inds] == class_of_interest
            ).sum()
        elif eval_mode == "top_class":
            # all valid ones with non generic class
            num_positives += (gg['class_ids'][valid_inds] > -1).sum()
            num_positives += (gg["class_ids"][valid_inds] > -1).sum()

        # find relevant classes (i.e. class_of_interest) and events without known class (i.e. generic class, -1)
        if eval_mode == 'per_class':
        if eval_mode == "per_class":
            class_inds = (class_ids == class_of_interest) | (class_ids == -1)
            boxes = boxes[class_inds, :]
            gen_class = gen_class[class_inds]
@@ -151,25 +189,27 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres
        gt_generic_class.append(gen_class)
        gt_class.append(class_ids)


    # loop through detections and keep track of those that have been assigned
    true_pos = np.zeros(confidence.shape[0])
    valid_inds = np.ones(confidence.shape[0]) == 1  # initialize to True
    sorted_inds = np.argsort(confidence)[::-1]  # sort high to low
    true_pos = np.zeros(confidence.shape[0])
    valid_inds = np.ones(confidence.shape[0]) == 1  # initialize to True
    sorted_inds = np.argsort(confidence)[::-1]  # sort high to low
    for ii, ind in enumerate(sorted_inds):
        gt_id = file_ids[ind]

        valid_det = False
        if gt_boxes[gt_id].shape[0] > 0:
            # compute overlap
            valid_det, det_ind = compute_affinity_1d(pred_boxes[ind], gt_boxes[gt_id],
                                                     threshold)
            valid_det, det_ind = compute_affinity_1d(
                pred_boxes[ind], gt_boxes[gt_id], threshold
            )

        # valid detection that has not already been assigned
        if valid_det and (gt_assigned[gt_id][det_ind] == 0):

            count_as_true_pos = True
            if eval_mode == 'top_class' and (gt_class[gt_id][det_ind] != pred_class[ind]):
            if eval_mode == "top_class" and (
                gt_class[gt_id][det_ind] != pred_class[ind]
            ):
                # needs to be the same class
                count_as_true_pos = False

@@ -181,40 +221,43 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres
            # if event is generic class (i.e. gt_generic_class[gt_id][det_ind] is True)
            # and eval_mode != 'detection', then ignore it
            if gt_generic_class[gt_id][det_ind]:
                if eval_mode == 'per_class' or eval_mode == 'top_class':
                if eval_mode == "per_class" or eval_mode == "top_class":
                    valid_inds[ii] = False


    # store threshold values - used for plotting
    conf_sorted = np.sort(confidence)[::-1][valid_inds]
    thresholds = np.linspace(0.1, 0.9, 9)
    thresholds_inds = np.zeros(len(thresholds), dtype=np.int)
    for ii, tt in enumerate(thresholds):
        thresholds_inds[ii] = np.argmin(conf_sorted > tt)
    thresholds_inds[thresholds_inds==0] = -1
    thresholds_inds[thresholds_inds == 0] = -1

    # compute precision and recall
    true_pos = true_pos[valid_inds]
    false_pos_c = np.cumsum(1-true_pos)
    true_pos_c = np.cumsum(true_pos)
    true_pos = true_pos[valid_inds]
    false_pos_c = np.cumsum(1 - true_pos)
    true_pos_c = np.cumsum(true_pos)

    recall = true_pos_c / num_positives
    precision = true_pos_c / np.maximum(true_pos_c + false_pos_c, np.finfo(np.float64).eps)
    precision = true_pos_c / np.maximum(
        true_pos_c + false_pos_c, np.finfo(np.float64).eps
    )

    results = {}
    results['recall'] = recall
    results['precision'] = precision
    results['num_gt'] = num_positives
    results["recall"] = recall
    results["precision"] = precision
    results["num_gt"] = num_positives

    results['thresholds'] = thresholds
    results['thresholds_inds'] = thresholds_inds
    results["thresholds"] = thresholds
    results["thresholds_inds"] = thresholds_inds

    if num_positives == 0:
        results['avg_prec'] = np.nan
        results['rec_at_x'] = np.nan
        results["avg_prec"] = np.nan
        results["rec_at_x"] = np.nan
    else:
        results['avg_prec'] = np.round(calc_average_precision(recall, precision), 5)
        results['rec_at_x'] = np.round(calc_recall_at_x(recall, precision), 5)
        results["avg_prec"] = np.round(
            calc_average_precision(recall, precision), 5
        )
        results["rec_at_x"] = np.round(calc_recall_at_x(recall, precision), 5)

    return results

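# The body of compute_affinity_1d is not shown in this hunk; from its call
# sites it matches a predicted box against the closest ground truth box in
# time. A hypothetical stand-in that matches on start-time distance with a
# fixed threshold (not necessarily the repository's exact criterion):
#
#   import numpy as np
#
#   def affinity_1d_demo(pred_box, gt_boxes, threshold):
#       # pred_box: [start, end, low_freq, high_freq]; gt_boxes: (N, 4)
#       score = np.abs(gt_boxes[:, 0] - pred_box[0])
#       valid_detection = score.min() <= threshold
#       return valid_detection, np.argmin(score)
#
#   gt = np.array([[0.10, 0.12, 20e3, 60e3], [0.50, 0.53, 25e3, 70e3]])
#   ok, ind = affinity_1d_demo(np.array([0.49, 0.52, 24e3, 68e3]), gt, 0.1)
#   assert ok and ind == 1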
@@ -230,19 +273,19 @@ def compute_file_accuracy_simple(gts, preds, num_classes):
    gt_valid = []
    pred_valid = []
    for ii in range(len(gts)):
        gt_class = np.unique(gts[ii]['class_ids'])
        gt_class = np.unique(gts[ii]["class_ids"])
        if len(gt_class) == 1 and gt_class[0] != -1:
            gt_valid.append(gt_class[0])
            pred = preds[ii]['class_probs'][:num_classes, :].T
            pred = preds[ii]["class_probs"][:num_classes, :].T
            pred_valid.append(np.argmax(pred.mean(0)))
    acc = (np.array(gt_valid) == np.array(pred_valid)).mean()

    res = {}
    res['num_valid_files'] = len(gt_valid)
    res['num_total_files'] = len(gts)
    res['gt_valid_file'] = gt_valid
    res['pred_valid_file'] = pred_valid
    res['file_acc'] = np.round(acc, 5)
    res["num_valid_files"] = len(gt_valid)
    res["num_total_files"] = len(gts)
    res["gt_valid_file"] = gt_valid
    res["pred_valid_file"] = pred_valid
    res["file_acc"] = np.round(acc, 5)
    return res


@@ -256,12 +299,20 @@ def compute_file_accuracy(gts, preds, num_classes):

    # compute min and max scoring range - then threshold
    min_val = 0
    mins = [pp['class_probs'].min() for pp in preds if pp['class_probs'].shape[1] > 0]
    mins = [
        pp["class_probs"].min()
        for pp in preds
        if pp["class_probs"].shape[1] > 0
    ]
    if len(mins) > 0:
        min_val = np.min(mins)

    max_val = 1.0
    maxes = [pp['class_probs'].max() for pp in preds if pp['class_probs'].shape[1] > 0]
    maxes = [
        pp["class_probs"].max()
        for pp in preds
        if pp["class_probs"].shape[1] > 0
    ]
    if len(maxes) > 0:
        max_val = np.max(maxes)

@@ -272,33 +323,37 @@ def compute_file_accuracy(gts, preds, num_classes):
    gt_valid = []
    pred_valid_all = []
    for ii in range(len(gts)):
        gt_class = np.unique(gts[ii]['class_ids'])
        gt_class = np.unique(gts[ii]["class_ids"])
        if len(gt_class) == 1 and gt_class[0] != -1:
            gt_valid.append(gt_class[0])
            pred = preds[ii]['class_probs'][:num_classes, :].T
            pred = preds[ii]["class_probs"][:num_classes, :].T
            p_class = np.zeros(len(thresh))
            for tt in range(len(thresh)):
                p_class[tt] = (pred*(pred>=thresh[tt])).sum(0).argmax()
                p_class[tt] = (pred * (pred >= thresh[tt])).sum(0).argmax()
            pred_valid_all.append(p_class)

    # pick the result corresponding to the overall best threshold
    pred_valid_all = np.vstack(pred_valid_all)
    acc_per_thresh = (np.array(gt_valid)[..., np.newaxis] == pred_valid_all).mean(0)
    acc_per_thresh = (
        np.array(gt_valid)[..., np.newaxis] == pred_valid_all
    ).mean(0)
    best_thresh = np.argmax(acc_per_thresh)
    best_acc = acc_per_thresh[best_thresh]
    pred_valid = pred_valid_all[:, best_thresh].astype(np.int).tolist()

    res = {}
    res['num_valid_files'] = len(gt_valid)
    res['num_total_files'] = len(gts)
    res['gt_valid_file'] = gt_valid
    res['pred_valid_file'] = pred_valid
    res['file_acc'] = np.round(best_acc, 5)
    res["num_valid_files"] = len(gt_valid)
    res["num_total_files"] = len(gts)
    res["gt_valid_file"] = gt_valid
    res["pred_valid_file"] = pred_valid
    res["file_acc"] = np.round(best_acc, 5)

    return res


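# compute_file_accuracy sweeps a grid of confidence thresholds between the min
# and max predicted scores and reports the accuracy at the single best
# threshold. A compact sketch of that selection step with toy per-threshold
# predictions (values are placeholders):
#
#   import numpy as np
#
#   gt_valid = np.array([0, 1, 1])
#   pred_valid_all = np.array([[0, 0], [1, 0], [1, 1]])  # rows: files, cols: thresholds
#   acc_per_thresh = (gt_valid[..., np.newaxis] == pred_valid_all).mean(0)
#   best_thresh = np.argmax(acc_per_thresh)
#   assert acc_per_thresh[best_thresh] == 1.0 and best_thresh == 0
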
def evaluate_predictions(gts, preds, class_names, detection_overlap, ignore_start_end=0.0):
def evaluate_predictions(
    gts, preds, class_names, detection_overlap, ignore_start_end=0.0
):
    """
    Computes metrics derived from the precision and recall.
    Assumes that gts and preds are both lists of the same lengths, with ground
@@ -307,24 +362,50 @@ def evaluate_predictions(gts, preds, class_names, detection_overlap, ignore_star
    Returns the overall detection results, and per class results
    """

    assert(len(gts) == len(preds))
    assert len(gts) == len(preds)
    num_classes = len(class_names)

    # evaluate detection on its own i.e. ignoring class
    det_results = compute_pre_rec(gts, preds, 'detection', None, num_classes, detection_overlap, ignore_start_end)
    top_class = compute_pre_rec(gts, preds, 'top_class', None, num_classes, detection_overlap, ignore_start_end)
    det_results['top_class'] = top_class
    det_results = compute_pre_rec(
        gts,
        preds,
        "detection",
        None,
        num_classes,
        detection_overlap,
        ignore_start_end,
    )
    top_class = compute_pre_rec(
        gts,
        preds,
        "top_class",
        None,
        num_classes,
        detection_overlap,
        ignore_start_end,
    )
    det_results["top_class"] = top_class

    # per class evaluation
    det_results['class_pr'] = []
    det_results["class_pr"] = []
    for cc in range(num_classes):
        res = compute_pre_rec(gts, preds, 'per_class', cc, num_classes, detection_overlap, ignore_start_end)
        res['name'] = class_names[cc]
        det_results['class_pr'].append(res)
        res = compute_pre_rec(
            gts,
            preds,
            "per_class",
            cc,
            num_classes,
            detection_overlap,
            ignore_start_end,
        )
        res["name"] = class_names[cc]
        det_results["class_pr"].append(res)

    # ignores classes that are not present in the test set
    det_results['avg_prec_class'] = np.mean([rs['avg_prec'] for rs in det_results['class_pr'] if rs['num_gt'] > 0])
    det_results['avg_prec_class'] = np.round(det_results['avg_prec_class'], 5)
    det_results["avg_prec_class"] = np.mean(
        [rs["avg_prec"] for rs in det_results["class_pr"] if rs["num_gt"] > 0]
    )
    det_results["avg_prec_class"] = np.round(det_results["avg_prec_class"], 5)

    # file level evaluation
    res_file = compute_file_accuracy(gts, preds, num_classes)

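# A sketch of how evaluate_predictions is driven, assuming gts/preds carry the
# per-file keys used above ('duration', 'start_times', 'class_probs',
# 'det_probs', ...); the variable names are placeholders:
#
#   results = evaluate_predictions(
#       ground_truths,                 # list of per-file GT dicts
#       predictions,                   # list of per-file prediction dicts
#       params["class_names"],
#       params["detection_overlap"],   # 1-D overlap threshold for matching
#       params["ignore_start_end"],    # trim detections near file edges
#   )
#   print(results["avg_prec"], results["top_class"]["avg_prec"])
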
@@ -7,7 +7,9 @@ def bbox_size_loss(pred_size, gt_size):
    Bounding box size loss. Only compute loss where there is a bounding box.
    """
    gt_size_mask = (gt_size > 0).float()
    return (F.l1_loss(pred_size*gt_size_mask, gt_size, reduction='sum') / (gt_size_mask.sum() + 1e-5))
    return F.l1_loss(pred_size * gt_size_mask, gt_size, reduction="sum") / (
        gt_size_mask.sum() + 1e-5
    )


def focal_loss(pred, gt, weights=None, valid_mask=None):
@@ -24,20 +26,25 @@ def focal_loss(pred, gt, weights=None, valid_mask=None):
    neg_inds = gt.lt(1).float()

    pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds
    neg_loss = torch.log(1 - pred + eps) * torch.pow(pred, alpha) * torch.pow(1 - gt, beta) * neg_inds
    neg_loss = (
        torch.log(1 - pred + eps)
        * torch.pow(pred, alpha)
        * torch.pow(1 - gt, beta)
        * neg_inds
    )

    if weights is not None:
        pos_loss = pos_loss*weights
        #neg_loss = neg_loss*weights
        pos_loss = pos_loss * weights
        # neg_loss = neg_loss*weights

    if valid_mask is not None:
        pos_loss = pos_loss*valid_mask
        neg_loss = neg_loss*valid_mask
        pos_loss = pos_loss * valid_mask
        neg_loss = neg_loss * valid_mask

    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    num_pos = pos_inds.float().sum()
    num_pos = pos_inds.float().sum()
    if num_pos == 0:
        loss = -neg_loss
    else:
@@ -47,10 +54,10 @@ def focal_loss(pred, gt, weights=None, valid_mask=None):

def mse_loss(pred, gt, weights=None, valid_mask=None):
    """
    Mean squared error loss.
    Mean squared error loss.
    """
    if valid_mask is None:
        op = ((gt-pred)**2).mean()
        op = ((gt - pred) ** 2).mean()
    else:
        op = (valid_mask*((gt-pred)**2)).sum() / valid_mask.sum()
        op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
    return op

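# This is the penalty-reduced focal loss of CornerNet/CenterNet. A
# self-contained sketch with the commonly used alpha=2, beta=4 (eps, alpha,
# and beta are defined in lines elided from this hunk, so these values are
# assumptions; the weights/valid_mask handling is omitted):
#
#   import torch
#
#   def focal_loss_demo(pred, gt, alpha=2, beta=4, eps=1e-5):
#       pos_inds = gt.eq(1).float()
#       neg_inds = gt.lt(1).float()
#       pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds
#       neg_loss = (
#           torch.log(1 - pred + eps)
#           * torch.pow(pred, alpha)
#           * torch.pow(1 - gt, beta)
#           * neg_inds
#       )
#       num_pos = pos_inds.sum()
#       loss = -(pos_loss.sum() + neg_loss.sum())
#       return loss if num_pos == 0 else loss / num_pos
#
#   pred = torch.tensor([[0.9, 0.2], [0.1, 0.8]])
#   gt = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
#   print(focal_loss_demo(pred, gt))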
@@ -1,32 +1,27 @@
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
import json
import argparse
import json
import warnings

import sys
sys.path.append(os.path.join('..', '..'))
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

import bat_detect.detector.parameters as parameters
import bat_detect.detector.models as models
from bat_detect.detector import models
from bat_detect.detector import parameters
from bat_detect.train import losses
import bat_detect.detector.post_process as pp
import bat_detect.utils.plot_utils as pu

import bat_detect.train.audio_dataloader as adl
import bat_detect.train.evaluate as evl
import bat_detect.train.train_utils as tu
import bat_detect.train.train_split as ts
import bat_detect.train.losses as losses
import bat_detect.train.train_utils as tu
import bat_detect.utils.plot_utils as pu

import warnings
warnings.filterwarnings("ignore", category=UserWarning)


def save_images_batch(model, data_loader, params):
    print('\nsaving images ...')
    print("\nsaving images ...")

    is_train_state = data_loader.dataset.is_train
    data_loader.dataset.is_train = False
@@ -36,67 +31,112 @@ def save_images_batch(model, data_loader, params):
    ind = 0  # first image in each batch
    with torch.no_grad():
        for batch_idx, inputs in enumerate(data_loader):
            data = inputs['spec'].to(params['device'])
            data = inputs["spec"].to(params["device"])
            outputs = model(data)

            spec_viz = inputs['spec_for_viz'].data.cpu().numpy()
            orig_index = inputs['file_id'][ind]
            plot_title = data_loader.dataset.data_anns[orig_index]['id']
            op_file_name = params['op_im_dir_test'] + data_loader.dataset.data_anns[orig_index]['id'] + '.jpg'
            save_image(spec_viz, outputs, ind, inputs, params, op_file_name, plot_title)
            spec_viz = inputs["spec_for_viz"].data.cpu().numpy()
            orig_index = inputs["file_id"][ind]
            plot_title = data_loader.dataset.data_anns[orig_index]["id"]
            op_file_name = (
                params["op_im_dir_test"]
                + data_loader.dataset.data_anns[orig_index]["id"]
                + ".jpg"
            )
            save_image(
                spec_viz,
                outputs,
                ind,
                inputs,
                params,
                op_file_name,
                plot_title,
            )

    data_loader.dataset.is_train = is_train_state
    data_loader.dataset.return_spec_for_viz = False


def save_image(spec_viz, outputs, ind, inputs, params, op_file_name, plot_title):
    pred_nms, _ = pp.run_nms(outputs, params, inputs['sampling_rate'].float())
    pred_hm = outputs['pred_det'][ind, 0, :].data.cpu().numpy()
def save_image(
    spec_viz, outputs, ind, inputs, params, op_file_name, plot_title
):
    pred_nms, _ = pp.run_nms(outputs, params, inputs["sampling_rate"].float())
    pred_hm = outputs["pred_det"][ind, 0, :].data.cpu().numpy()
    spec_viz = spec_viz[ind, 0, :]
    gt = parse_gt_data(inputs)[ind]
    sampling_rate = inputs['sampling_rate'][ind].item()
    duration = inputs['duration'][ind].item()
    gt = parse_gt_data(inputs)[ind]
    sampling_rate = inputs["sampling_rate"][ind].item()
    duration = inputs["duration"][ind].item()

    pu.plot_spec(spec_viz, sampling_rate, duration, gt, pred_nms[ind],
                 params, plot_title, op_file_name, pred_hm, plot_boxes=True, fixed_aspect=False)
    pu.plot_spec(
        spec_viz,
        sampling_rate,
        duration,
        gt,
        pred_nms[ind],
        params,
        plot_title,
        op_file_name,
        pred_hm,
        plot_boxes=True,
        fixed_aspect=False,
    )


def loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq):
def loss_fun(
    outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq
):

    # detection loss
    loss = params['det_loss_weight']*det_criterion(outputs['pred_det'], gt_det)
    loss = params["det_loss_weight"] * det_criterion(
        outputs["pred_det"], gt_det
    )

    # bounding box size loss
    loss += params['size_loss_weight']*losses.bbox_size_loss(outputs['pred_size'], gt_size)
    loss += params["size_loss_weight"] * losses.bbox_size_loss(
        outputs["pred_size"], gt_size
    )

    # classification loss
    valid_mask = (gt_class[:, :-1, :, :].sum(1) > 0).float().unsqueeze(1)
    p_class = outputs['pred_class'][:, :-1, :]
    loss += params['class_loss_weight']*det_criterion(p_class, gt_class[:, :-1, :], valid_mask=valid_mask)
    p_class = outputs["pred_class"][:, :-1, :]
    loss += params["class_loss_weight"] * det_criterion(
        p_class, gt_class[:, :-1, :], valid_mask=valid_mask
    )

    return loss

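# The classification term above is masked so it is only applied at spectrogram
# cells that carry a non-generic class label; the last class channel (the
# generic class) is excluded from both prediction and target. A shape-level
# sketch of that mask with a toy 4-class target:
#
#   import torch
#
#   gt_class = torch.zeros(1, 4, 2, 3)       # (batch, classes, freq, time)
#   gt_class[0, 1, 0, 1] = 1.0               # one labelled cell, class 1
#   valid_mask = (gt_class[:, :-1, :, :].sum(1) > 0).float().unsqueeze(1)
#   assert valid_mask.shape == (1, 1, 2, 3) and valid_mask.sum() == 1
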
def train(model, epoch, data_loader, det_criterion, optimizer, scheduler, params):
def train(
    model, epoch, data_loader, det_criterion, optimizer, scheduler, params
):

    model.train()

    train_loss = tu.AverageMeter()
    class_inv_freq = torch.from_numpy(np.array(params['class_inv_freq'], dtype=np.float32)).to(params['device'])
    class_inv_freq = torch.from_numpy(
        np.array(params["class_inv_freq"], dtype=np.float32)
    ).to(params["device"])
    class_inv_freq = class_inv_freq.unsqueeze(0).unsqueeze(2).unsqueeze(2)

    print('\nEpoch', epoch)
    print("\nEpoch", epoch)
    for batch_idx, inputs in enumerate(data_loader):

        data = inputs['spec'].to(params['device'])
        gt_det = inputs['y_2d_det'].to(params['device'])
        gt_size = inputs['y_2d_size'].to(params['device'])
        gt_class = inputs['y_2d_classes'].to(params['device'])
        data = inputs["spec"].to(params["device"])
        gt_det = inputs["y_2d_det"].to(params["device"])
        gt_size = inputs["y_2d_size"].to(params["device"])
        gt_class = inputs["y_2d_classes"].to(params["device"])

        optimizer.zero_grad()
        outputs = model(data)

        loss = loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq)
        loss = loss_fun(
            outputs,
            gt_det,
            gt_size,
            gt_class,
            det_criterion,
            params,
            class_inv_freq,
        )

        train_loss.update(loss.item(), data.shape[0])
        loss.backward()
@@ -104,13 +144,18 @@ def train(model, epoch, data_loader, det_criterion, optimizer, scheduler, params
        scheduler.step()

        if batch_idx % 50 == 0 and batch_idx != 0:
            print('[{}/{}]\tLoss: {:.4f}'.format(
                batch_idx * len(data), len(data_loader.dataset), train_loss.avg))
            print(
                "[{}/{}]\tLoss: {:.4f}".format(
                    batch_idx * len(data),
                    len(data_loader.dataset),
                    train_loss.avg,
                )
            )

    print('Train loss : {:.4f}'.format(train_loss.avg))
    print("Train loss : {:.4f}".format(train_loss.avg))

    res = {}
    res['train_loss'] = float(train_loss.avg)
    res["train_loss"] = float(train_loss.avg)
    return res


@@ -120,16 +165,18 @@ def test(model, epoch, data_loader, det_criterion, params):
    ground_truths = []
    test_loss = tu.AverageMeter()

    class_inv_freq = torch.from_numpy(np.array(params['class_inv_freq'], dtype=np.float32)).to(params['device'])
    class_inv_freq = torch.from_numpy(
        np.array(params["class_inv_freq"], dtype=np.float32)
    ).to(params["device"])
    class_inv_freq = class_inv_freq.unsqueeze(0).unsqueeze(2).unsqueeze(2)

    with torch.no_grad():
        for batch_idx, inputs in enumerate(data_loader):

            data = inputs['spec'].to(params['device'])
            gt_det = inputs['y_2d_det'].to(params['device'])
            gt_size = inputs['y_2d_size'].to(params['device'])
            gt_class = inputs['y_2d_classes'].to(params['device'])
            data = inputs["spec"].to(params["device"])
            gt_det = inputs["y_2d_det"].to(params["device"])
            gt_size = inputs["y_2d_size"].to(params["device"])
            gt_class = inputs["y_2d_classes"].to(params["device"])

            outputs = model(data)

@@ -139,41 +186,79 @@ def test(model, epoch, data_loader, det_criterion, params):
            # for kk in ['pred_det', 'pred_size', 'pred_class']:
            #     outputs[kk] = torch.cat([oo for oo in outputs[kk]], 2).unsqueeze(0)

            if params['save_test_image_during_train'] and batch_idx == 0:
            if params["save_test_image_during_train"] and batch_idx == 0:
                # for visualization - save the first prediction
                ind = 0
                orig_index = inputs['file_id'][ind]
                plot_title = data_loader.dataset.data_anns[orig_index]['id']
                op_file_name = params['op_im_dir'] + str(orig_index.item()).zfill(4) + '_' + str(epoch).zfill(4) + '_pred.jpg'
                save_image(data, outputs, ind, inputs, params, op_file_name, plot_title)
                orig_index = inputs["file_id"][ind]
                plot_title = data_loader.dataset.data_anns[orig_index]["id"]
                op_file_name = (
                    params["op_im_dir"]
                    + str(orig_index.item()).zfill(4)
                    + "_"
                    + str(epoch).zfill(4)
                    + "_pred.jpg"
                )
                save_image(
                    data,
                    outputs,
                    ind,
                    inputs,
                    params,
                    op_file_name,
                    plot_title,
                )

            loss = loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq)
            loss = loss_fun(
                outputs,
                gt_det,
                gt_size,
                gt_class,
                det_criterion,
                params,
                class_inv_freq,
            )
            test_loss.update(loss.item(), data.shape[0])

            # do NMS
            pred_nms, _ = pp.run_nms(outputs, params, inputs['sampling_rate'].float())
            pred_nms, _ = pp.run_nms(
                outputs, params, inputs["sampling_rate"].float()
            )
            predictions.extend(pred_nms)

            ground_truths.extend(parse_gt_data(inputs))

    res_det = evl.evaluate_predictions(ground_truths, predictions, params['class_names'],
                                       params['detection_overlap'], params['ignore_start_end'])
    res_det = evl.evaluate_predictions(
        ground_truths,
        predictions,
        params["class_names"],
        params["detection_overlap"],
        params["ignore_start_end"],
    )

    print('\nTest loss : {:.4f}'.format(test_loss.avg))
    print('Rec at 0.95 (det) : {:.4f}'.format(res_det['rec_at_x']))
    print('Avg prec (cls) : {:.4f}'.format(res_det['avg_prec']))
    print('File acc (cls) : {:.2f} - for {} out of {}'.format(res_det['file_acc'],
          res_det['num_valid_files'], res_det['num_total_files']))
    print('Cls Avg prec (cls) : {:.4f}'.format(res_det['avg_prec_class']))
    print("\nTest loss : {:.4f}".format(test_loss.avg))
    print("Rec at 0.95 (det) : {:.4f}".format(res_det["rec_at_x"]))
    print("Avg prec (cls) : {:.4f}".format(res_det["avg_prec"]))
    print(
        "File acc (cls) : {:.2f} - for {} out of {}".format(
            res_det["file_acc"],
            res_det["num_valid_files"],
            res_det["num_total_files"],
        )
    )
    print("Cls Avg prec (cls) : {:.4f}".format(res_det["avg_prec_class"]))

    print('\nPer class average precision')
    str_len = np.max([len(rs['name']) for rs in res_det['class_pr']]) + 5
    for cc, rs in enumerate(res_det['class_pr']):
        if rs['num_gt'] > 0:
            print(str(cc).ljust(5) + rs['name'].ljust(str_len) + '{:.4f}'.format(rs['avg_prec']))
    print("\nPer class average precision")
    str_len = np.max([len(rs["name"]) for rs in res_det["class_pr"]]) + 5
    for cc, rs in enumerate(res_det["class_pr"]):
        if rs["num_gt"] > 0:
            print(
                str(cc).ljust(5)
                + rs["name"].ljust(str_len)
                + "{:.4f}".format(rs["avg_prec"])
            )

    res = {}
    res['test_loss'] = float(test_loss.avg)
    res["test_loss"] = float(test_loss.avg)

    return res_det, res

@@ -181,176 +266,287 @@ def test(model, epoch, data_loader, det_criterion, params):
def parse_gt_data(inputs):
    # reads the torch arrays into a dictionary of numpy arrays, taking care to
    # remove padding data i.e. not valid ones
    keys = ['start_times', 'end_times', 'low_freqs', 'high_freqs', 'class_ids', 'individual_ids']
    keys = [
        "start_times",
        "end_times",
        "low_freqs",
        "high_freqs",
        "class_ids",
        "individual_ids",
    ]
    batch_data = []
    for ind in range(inputs['start_times'].shape[0]):
        is_valid = inputs['is_valid'][ind]==1
    for ind in range(inputs["start_times"].shape[0]):
        is_valid = inputs["is_valid"][ind] == 1
        gt = {}
        for kk in keys:
            gt[kk] = inputs[kk][ind][is_valid].numpy().astype(np.float32)
        gt['duration'] = inputs['duration'][ind].item()
        gt['file_id'] = inputs['file_id'][ind].item()
        gt['class_id_file'] = inputs['class_id_file'][ind].item()
        gt["duration"] = inputs["duration"][ind].item()
        gt["file_id"] = inputs["file_id"][ind].item()
        gt["class_id_file"] = inputs["class_id_file"][ind].item()
        batch_data.append(gt)
    return batch_data


def select_model(params):
    num_classes = len(params['class_names'])
    if params['model_name'] == 'Net2DFast':
        model = models.Net2DFast(params['num_filters'], num_classes=num_classes,
                                 emb_dim=params['emb_dim'], ip_height=params['ip_height'],
                                 resize_factor=params['resize_factor'])
    elif params['model_name'] == 'Net2DFastNoAttn':
        model = models.Net2DFastNoAttn(params['num_filters'], num_classes=num_classes,
                                       emb_dim=params['emb_dim'], ip_height=params['ip_height'],
                                       resize_factor=params['resize_factor'])
    elif params['model_name'] == 'Net2DFastNoCoordConv':
        model = models.Net2DFastNoCoordConv(params['num_filters'], num_classes=num_classes,
                                            emb_dim=params['emb_dim'], ip_height=params['ip_height'],
                                            resize_factor=params['resize_factor'])
    num_classes = len(params["class_names"])
    if params["model_name"] == "Net2DFast":
        model = models.Net2DFast(
            params["num_filters"],
            num_classes=num_classes,
            emb_dim=params["emb_dim"],
            ip_height=params["ip_height"],
            resize_factor=params["resize_factor"],
        )
    elif params["model_name"] == "Net2DFastNoAttn":
        model = models.Net2DFastNoAttn(
            params["num_filters"],
            num_classes=num_classes,
            emb_dim=params["emb_dim"],
            ip_height=params["ip_height"],
            resize_factor=params["resize_factor"],
        )
    elif params["model_name"] == "Net2DFastNoCoordConv":
        model = models.Net2DFastNoCoordConv(
            params["num_filters"],
            num_classes=num_classes,
            emb_dim=params["emb_dim"],
            ip_height=params["ip_height"],
            resize_factor=params["resize_factor"],
        )
    else:
        print('No valid network specified')
        print("No valid network specified")
    return model


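# Note that select_model falls through to `return model` with `model` unbound
# when `model_name` is unrecognised, which would raise an UnboundLocalError. A
# dictionary dispatch is a common alternative (a sketch only; constructor
# arguments elided):
#
#   MODELS = {
#       "Net2DFast": models.Net2DFast,
#       "Net2DFastNoAttn": models.Net2DFastNoAttn,
#       "Net2DFastNoCoordConv": models.Net2DFastNoCoordConv,
#   }
#   try:
#       model_cls = MODELS[params["model_name"]]
#   except KeyError:
#       raise ValueError("No valid network specified: " + params["model_name"])
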
if __name__ == "__main__":

    plt.close('all')
    plt.close("all")

    params = parameters.get_params(True)

    if torch.cuda.is_available():
        params['device'] = 'cuda'
        params["device"] = "cuda"
    else:
        params['device'] = 'cpu'
        params["device"] = "cpu"

    # setup arg parser and populate it with existing parameters - will not work with lists
    parser = argparse.ArgumentParser()
    parser.add_argument('data_dir', type=str,
                        help='Path to root of datasets')
    parser.add_argument('ann_dir', type=str,
                        help='Path to extracted annotations')
    parser.add_argument('--train_split', type=str, default='diff',  # diff, same
                        help='Which train split to use')
    parser.add_argument('--notes', type=str, default='',
                        help='Notes to save in text file')
    parser.add_argument('--do_not_save_images', action='store_false',
                        help='Do not save images at the end of training')
    parser.add_argument('--standardize_classs_names_ip', type=str,
                        default='Rhinolophus ferrumequinum;Rhinolophus hipposideros',
                        help='Will set low and high frequency the same for these classes. Separate names with ";"')
    parser.add_argument("data_dir", type=str, help="Path to root of datasets")
    parser.add_argument(
        "ann_dir", type=str, help="Path to extracted annotations"
    )
    parser.add_argument(
        "--train_split",
        type=str,
        default="diff",  # diff, same
        help="Which train split to use",
    )
    parser.add_argument(
        "--notes", type=str, default="", help="Notes to save in text file"
    )
    parser.add_argument(
        "--do_not_save_images",
        action="store_false",
        help="Do not save images at the end of training",
    )
    parser.add_argument(
        "--standardize_classs_names_ip",
        type=str,
        default="Rhinolophus ferrumequinum;Rhinolophus hipposideros",
        help='Will set low and high frequency the same for these classes. Separate names with ";"',
    )
    for key, val in params.items():
        parser.add_argument('--'+key, type=type(val), default=val)
        parser.add_argument("--" + key, type=type(val), default=val)
    params = vars(parser.parse_args())
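    # Because every entry of `params` is mirrored as an argparse flag above,
    # any scalar hyperparameter can be overridden from the command line. A
    # hypothetical invocation (script name, paths, and values illustrative):
    #
    #   python train.py /data/audio /data/anns --train_split diff \
    #       --lr 0.001 --num_epochs 200 --batch_size 8
    #
    # As the comment in the code notes, list-valued parameters cannot be
    # overridden this way, since they get `type=list` and will not parse.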

    # save notes file
    if params['notes'] != '':
        tu.write_notes_file(params['experiment'] + 'notes.txt', params['notes'])
    if params["notes"] != "":
        tu.write_notes_file(params["experiment"] + "notes.txt", params["notes"])

    # load the training and test meta data - there are different splits defined
    train_sets, test_sets = ts.get_train_test_data(params['ann_dir'], params['data_dir'], params['train_split'])
    train_sets_no_path, test_sets_no_path = ts.get_train_test_data('', '', params['train_split'])
    train_sets, test_sets = ts.get_train_test_data(
        params["ann_dir"], params["data_dir"], params["train_split"]
    )
    train_sets_no_path, test_sets_no_path = ts.get_train_test_data(
        "", "", params["train_split"]
    )

    # keep track of what we have trained on
    params['train_sets'] = train_sets_no_path
    params['test_sets'] = test_sets_no_path
    params["train_sets"] = train_sets_no_path
    params["test_sets"] = test_sets_no_path

    # load train annotations - merge them all together
    print('\nTraining on:')
    print("\nTraining on:")
    for tt in train_sets:
        print(tt['ann_path'])
    classes_to_ignore = params['classes_to_ignore']+params['generic_class']
    data_train, params['class_names'], params['class_inv_freq'] = \
        tu.load_set_of_anns(train_sets, classes_to_ignore, params['events_of_interest'], params['convert_to_genus'])
    params['genus_names'], params['genus_mapping'] = tu.get_genus_mapping(params['class_names'])
    params['class_names_short'] = tu.get_short_class_names(params['class_names'])
        print(tt["ann_path"])
    classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
    (
        data_train,
        params["class_names"],
        params["class_inv_freq"],
    ) = tu.load_set_of_anns(
        train_sets,
        classes_to_ignore,
        params["events_of_interest"],
        params["convert_to_genus"],
    )
    params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
        params["class_names"]
    )
    params["class_names_short"] = tu.get_short_class_names(
        params["class_names"]
    )

    # standardize the low and high frequency value for specified classes
    params['standardize_classs_names'] = params['standardize_classs_names_ip'].split(';')
    for cc in params['standardize_classs_names']:
        if cc in params['class_names']:
    params["standardize_classs_names"] = params[
        "standardize_classs_names_ip"
    ].split(";")
    for cc in params["standardize_classs_names"]:
        if cc in params["class_names"]:
            data_train = tu.standardize_low_freq(data_train, cc)
        else:
            print(cc, 'not found')
            print(cc, "not found")

    # train loader
    train_dataset = adl.AudioLoader(data_train, params, is_train=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params['batch_size'],
                                               shuffle=True, num_workers=params['num_workers'], pin_memory=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=params["batch_size"],
        shuffle=True,
        num_workers=params["num_workers"],
        pin_memory=True,
    )

    # test set
    print('\nTesting on:')
    print("\nTesting on:")
    for tt in test_sets:
        print(tt['ann_path'])
    data_test, _, _ = tu.load_set_of_anns(test_sets, classes_to_ignore, params['events_of_interest'], params['convert_to_genus'])
        print(tt["ann_path"])
    data_test, _, _ = tu.load_set_of_anns(
        test_sets,
        classes_to_ignore,
        params["events_of_interest"],
        params["convert_to_genus"],
    )
    data_train = tu.remove_dupes(data_train, data_test)
    test_dataset = adl.AudioLoader(data_test, params, is_train=False)
    # batch size of 1 because of variable file length
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                              shuffle=False, num_workers=params['num_workers'], pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=params["num_workers"],
        pin_memory=True,
    )

    inputs_train = next(iter(train_loader))
    # TODO remove params['ip_height'], this is just legacy
    params['ip_height'] = int(params['spec_height']*params['resize_factor'])
    print('\ntrain batch spec size :', inputs_train['spec'].shape)
    print('class target size :', inputs_train['y_2d_classes'].shape)
    params["ip_height"] = int(params["spec_height"] * params["resize_factor"])
    print("\ntrain batch spec size :", inputs_train["spec"].shape)
    print("class target size :", inputs_train["y_2d_classes"].shape)

    # select network
    model = select_model(params)
    model = model.to(params['device'])
    model = model.to(params["device"])

    optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'])
    #optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=0.9)
    scheduler = CosineAnnealingLR(optimizer, params['num_epochs'] * len(train_loader))
    if params['train_loss'] == 'mse':
    optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
    # optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=0.9)
    scheduler = CosineAnnealingLR(
        optimizer, params["num_epochs"] * len(train_loader)
    )
    if params["train_loss"] == "mse":
        det_criterion = losses.mse_loss
    elif params['train_loss'] == 'focal':
    elif params["train_loss"] == "focal":
        det_criterion = losses.focal_loss

    # save parameters to file
    with open(params['experiment'] + 'params.json', 'w') as da:
    with open(params["experiment"] + "params.json", "w") as da:
        json.dump(params, da, indent=2, sort_keys=True)

    # plotting
    train_plt_ls = pu.LossPlotter(params['experiment'] + 'train_loss.png', params['num_epochs']+1,
                                  ['train_loss'], None, None, ['epoch', 'train_loss'], logy=True)
    test_plt_ls = pu.LossPlotter(params['experiment'] + 'test_loss.png', params['num_epochs']+1,
                                 ['test_loss'], None, None, ['epoch', 'test_loss'], logy=True)
    test_plt = pu.LossPlotter(params['experiment'] + 'test.png', params['num_epochs']+1,
                              ['avg_prec', 'rec_at_x', 'avg_prec_class', 'file_acc', 'top_class'], [0,1], None, ['epoch', ''])
    test_plt_class = pu.LossPlotter(params['experiment'] + 'test_avg_prec.png', params['num_epochs']+1,
                                    params['class_names_short'], [0,1], params['class_names_short'], ['epoch', 'avg_prec'])

    train_plt_ls = pu.LossPlotter(
        params["experiment"] + "train_loss.png",
        params["num_epochs"] + 1,
        ["train_loss"],
        None,
        None,
        ["epoch", "train_loss"],
        logy=True,
    )
    test_plt_ls = pu.LossPlotter(
        params["experiment"] + "test_loss.png",
        params["num_epochs"] + 1,
        ["test_loss"],
        None,
        None,
        ["epoch", "test_loss"],
        logy=True,
    )
    test_plt = pu.LossPlotter(
        params["experiment"] + "test.png",
        params["num_epochs"] + 1,
        ["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"],
        [0, 1],
        None,
        ["epoch", ""],
    )
    test_plt_class = pu.LossPlotter(
        params["experiment"] + "test_avg_prec.png",
        params["num_epochs"] + 1,
        params["class_names_short"],
        [0, 1],
        params["class_names_short"],
        ["epoch", "avg_prec"],
    )

    #
    # main train loop
    for epoch in range(0, params['num_epochs']+1):
    for epoch in range(0, params["num_epochs"] + 1):

        train_loss = train(model, epoch, train_loader, det_criterion, optimizer, scheduler, params)
|
||||
train_plt_ls.update_and_save(epoch, [train_loss['train_loss']])
|
||||
train_loss = train(
|
||||
model,
|
||||
epoch,
|
||||
train_loader,
|
||||
det_criterion,
|
||||
optimizer,
|
||||
scheduler,
|
||||
params,
|
||||
)
|
||||
train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]])
|
||||
|
||||
if epoch % params['num_eval_epochs'] == 0:
|
||||
if epoch % params["num_eval_epochs"] == 0:
|
||||
# detection accuracy on test set
|
||||
test_res, test_loss = test(model, epoch, test_loader, det_criterion, params)
|
||||
test_plt_ls.update_and_save(epoch, [test_loss['test_loss']])
|
||||
test_plt.update_and_save(epoch, [test_res['avg_prec'], test_res['rec_at_x'],
|
||||
test_res['avg_prec_class'], test_res['file_acc'], test_res['top_class']['avg_prec']])
|
||||
test_plt_class.update_and_save(epoch, [rs['avg_prec'] for rs in test_res['class_pr']])
|
||||
pu.plot_pr_curve_class(params['experiment'] , 'test_pr', 'test_pr', test_res)
|
||||
|
||||
test_res, test_loss = test(
|
||||
model, epoch, test_loader, det_criterion, params
|
||||
)
|
||||
test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]])
|
||||
test_plt.update_and_save(
|
||||
epoch,
|
||||
[
|
||||
test_res["avg_prec"],
|
||||
test_res["rec_at_x"],
|
||||
test_res["avg_prec_class"],
|
||||
test_res["file_acc"],
|
||||
test_res["top_class"]["avg_prec"],
|
||||
],
|
||||
)
|
||||
test_plt_class.update_and_save(
|
||||
epoch, [rs["avg_prec"] for rs in test_res["class_pr"]]
|
||||
)
|
||||
pu.plot_pr_curve_class(
|
||||
params["experiment"], "test_pr", "test_pr", test_res
|
||||
)
|
||||
|
||||
# save trained model
|
||||
print('saving model to: ' + params['model_file_name'])
|
||||
op_state = {'epoch': epoch + 1,
|
||||
'state_dict': model.state_dict(),
|
||||
#'optimizer' : optimizer.state_dict(),
|
||||
'params' : params}
|
||||
torch.save(op_state, params['model_file_name'])
|
||||
|
||||
print("saving model to: " + params["model_file_name"])
|
||||
op_state = {
|
||||
"epoch": epoch + 1,
|
||||
"state_dict": model.state_dict(),
|
||||
#'optimizer' : optimizer.state_dict(),
|
||||
"params": params,
|
||||
}
|
||||
torch.save(op_state, params["model_file_name"])
|
||||
|
||||
# save an image with associated prediction for each batch in the test set
|
||||
if not args['do_not_save_images']:
|
||||
save_images_batch(model, test_loader, params)
|
||||
# TODO: args variable does not exist
|
||||
# if not args["do_not_save_images"]:
|
||||
# save_images_batch(model, test_loader, params)
|
||||
|
@@ -2,13 +2,14 @@
|
||||
Run scripts/extract_anns.py to generate these json files.
|
||||
"""
|
||||
|
||||
|
||||
def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):
|
||||
if split_name == 'diff':
|
||||
if split_name == "diff":
|
||||
train_sets, test_sets = split_diff(ann_dir, wav_dir, load_extra)
|
||||
elif split_name == 'same':
|
||||
elif split_name == "same":
|
||||
train_sets, test_sets = split_same(ann_dir, wav_dir, load_extra)
|
||||
else:
|
||||
print('Split not defined')
|
||||
print("Split not defined")
|
||||
assert False
|
||||
|
||||
return train_sets, test_sets
|
||||
@@ -18,73 +19,126 @@ def split_diff(ann_dir, wav_dir, load_extra=True):
|
||||
|
||||
train_sets = []
|
||||
if load_extra:
|
||||
train_sets.append({'dataset_name': 'BatDetective',
|
||||
'is_test': False,
|
||||
'is_binary': True, # just a bat / not bat dataset ie no classes
|
||||
'ann_path': ann_dir + 'train_set_bulgaria_batdetective_with_bbs.json',
|
||||
'wav_path': wav_dir + 'bat_detective/audio/'})
|
||||
train_sets.append({'dataset_name': 'bat_logger_qeop_empty',
|
||||
'is_test': False,
|
||||
'is_binary': True,
|
||||
'ann_path': ann_dir + 'bat_logger_qeop_empty.json',
|
||||
'wav_path': wav_dir + 'bat_logger_qeop_empty/audio/'})
|
||||
train_sets.append({'dataset_name': 'bat_logger_2016_empty',
|
||||
'is_test': False,
|
||||
'is_binary': True,
|
||||
'ann_path': ann_dir + 'train_set_bat_logger_2016_empty.json',
|
||||
'wav_path': wav_dir + 'bat_logger_2016/audio/'})
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "BatDetective",
|
||||
"is_test": False,
|
||||
"is_binary": True, # just a bat / not bat dataset ie no classes
|
||||
"ann_path": ann_dir
|
||||
+ "train_set_bulgaria_batdetective_with_bbs.json",
|
||||
"wav_path": wav_dir + "bat_detective/audio/",
|
||||
}
|
||||
)
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "bat_logger_qeop_empty",
|
||||
"is_test": False,
|
||||
"is_binary": True,
|
||||
"ann_path": ann_dir + "bat_logger_qeop_empty.json",
|
||||
"wav_path": wav_dir + "bat_logger_qeop_empty/audio/",
|
||||
}
|
||||
)
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "bat_logger_2016_empty",
|
||||
"is_test": False,
|
||||
"is_binary": True,
|
||||
"ann_path": ann_dir + "train_set_bat_logger_2016_empty.json",
|
||||
"wav_path": wav_dir + "bat_logger_2016/audio/",
|
||||
}
|
||||
)
|
||||
# train_sets.append({'dataset_name': 'brazil_data_binary',
|
||||
# 'is_test': False,
|
||||
# 'ann_path': ann_dir + 'brazil_data_binary.json',
|
||||
# 'wav_path': wav_dir + 'brazil_data/audio/'})
|
||||
|
||||
train_sets.append({'dataset_name': 'echobank',
|
||||
'is_test': False,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'Echobank_train_expert.json',
|
||||
'wav_path': wav_dir + 'echobank/audio/'})
|
||||
train_sets.append({'dataset_name': 'sn_scot_nor',
|
||||
'is_test': False,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'sn_scot_nor_0.5_expert.json',
|
||||
'wav_path': wav_dir + 'sn_scot_nor/audio/'})
|
||||
train_sets.append({'dataset_name': 'BCT_1_sec',
|
||||
'is_test': False,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'BCT_1_sec_train_expert.json',
|
||||
'wav_path': wav_dir + 'BCT_1_sec/audio/'})
|
||||
train_sets.append({'dataset_name': 'bcireland',
|
||||
'is_test': False,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'bcireland_expert.json',
|
||||
'wav_path': wav_dir + 'bcireland/audio/'})
|
||||
train_sets.append({'dataset_name': 'rhinolophus_steve_BCT',
|
||||
'is_test': False,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'rhinolophus_steve_BCT_expert.json',
|
||||
'wav_path': wav_dir + 'rhinolophus_steve_BCT/audio/'})
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "echobank",
|
||||
"is_test": False,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir + "Echobank_train_expert.json",
|
||||
"wav_path": wav_dir + "echobank/audio/",
|
||||
}
|
||||
)
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "sn_scot_nor",
|
||||
"is_test": False,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir + "sn_scot_nor_0.5_expert.json",
|
||||
"wav_path": wav_dir + "sn_scot_nor/audio/",
|
||||
}
|
||||
)
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "BCT_1_sec",
|
||||
"is_test": False,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir + "BCT_1_sec_train_expert.json",
|
||||
"wav_path": wav_dir + "BCT_1_sec/audio/",
|
||||
}
|
||||
)
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "bcireland",
|
||||
"is_test": False,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir + "bcireland_expert.json",
|
||||
"wav_path": wav_dir + "bcireland/audio/",
|
||||
}
|
||||
)
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "rhinolophus_steve_BCT",
|
||||
"is_test": False,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir + "rhinolophus_steve_BCT_expert.json",
|
||||
"wav_path": wav_dir + "rhinolophus_steve_BCT/audio/",
|
||||
}
|
||||
)
|
||||
|
||||
test_sets = []
|
||||
test_sets.append({'dataset_name': 'bat_data_martyn_2018',
|
||||
'is_test': True,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_train_expert.json',
|
||||
'wav_path': wav_dir + 'bat_data_martyn_2018/audio/'})
|
||||
test_sets.append({'dataset_name': 'bat_data_martyn_2018_test',
|
||||
'is_test': True,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_test_expert.json',
|
||||
'wav_path': wav_dir + 'bat_data_martyn_2018_test/audio/'})
|
||||
test_sets.append({'dataset_name': 'bat_data_martyn_2019',
|
||||
'is_test': True,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_train_expert.json',
|
||||
'wav_path': wav_dir + 'bat_data_martyn_2019/audio/'})
|
||||
test_sets.append({'dataset_name': 'bat_data_martyn_2019_test',
|
||||
'is_test': True,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_test_expert.json',
|
||||
'wav_path': wav_dir + 'bat_data_martyn_2019_test/audio/'})
|
||||
test_sets.append(
|
||||
{
|
||||
"dataset_name": "bat_data_martyn_2018",
|
||||
"is_test": True,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir
|
||||
+ "BritishBatCalls_MartynCooke_2018_1_sec_train_expert.json",
|
||||
"wav_path": wav_dir + "bat_data_martyn_2018/audio/",
|
||||
}
|
||||
)
|
||||
test_sets.append(
|
||||
{
|
||||
"dataset_name": "bat_data_martyn_2018_test",
|
||||
"is_test": True,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir
|
||||
+ "BritishBatCalls_MartynCooke_2018_1_sec_test_expert.json",
|
||||
"wav_path": wav_dir + "bat_data_martyn_2018_test/audio/",
|
||||
}
|
||||
)
|
||||
test_sets.append(
|
||||
{
|
||||
"dataset_name": "bat_data_martyn_2019",
|
||||
"is_test": True,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir
|
||||
+ "BritishBatCalls_MartynCooke_2019_1_sec_train_expert.json",
|
||||
"wav_path": wav_dir + "bat_data_martyn_2019/audio/",
|
||||
}
|
||||
)
|
||||
test_sets.append(
|
||||
{
|
||||
"dataset_name": "bat_data_martyn_2019_test",
|
||||
"is_test": True,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir
|
||||
+ "BritishBatCalls_MartynCooke_2019_1_sec_test_expert.json",
|
||||
"wav_path": wav_dir + "bat_data_martyn_2019_test/audio/",
|
||||
}
|
||||
)
|
||||
|
||||
return train_sets, test_sets
|
||||
|
||||
@@ -93,71 +147,124 @@ def split_same(ann_dir, wav_dir, load_extra=True):
|
||||
|
||||
train_sets = []
|
||||
if load_extra:
|
||||
train_sets.append({'dataset_name': 'BatDetective',
|
||||
'is_test': False,
|
||||
'is_binary': True,
|
||||
'ann_path': ann_dir + 'train_set_bulgaria_batdetective_with_bbs.json',
|
||||
'wav_path': wav_dir + 'bat_detective/audio/'})
|
||||
train_sets.append({'dataset_name': 'bat_logger_qeop_empty',
|
||||
'is_test': False,
|
||||
'is_binary': True,
|
||||
'ann_path': ann_dir + 'bat_logger_qeop_empty.json',
|
||||
'wav_path': wav_dir + 'bat_logger_qeop_empty/audio/'})
|
||||
train_sets.append({'dataset_name': 'bat_logger_2016_empty',
|
||||
'is_test': False,
|
||||
'is_binary': True,
|
||||
'ann_path': ann_dir + 'train_set_bat_logger_2016_empty.json',
|
||||
'wav_path': wav_dir + 'bat_logger_2016/audio/'})
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "BatDetective",
|
||||
"is_test": False,
|
||||
"is_binary": True,
|
||||
"ann_path": ann_dir
|
||||
+ "train_set_bulgaria_batdetective_with_bbs.json",
|
||||
"wav_path": wav_dir + "bat_detective/audio/",
|
||||
}
|
||||
)
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "bat_logger_qeop_empty",
|
||||
"is_test": False,
|
||||
"is_binary": True,
|
||||
"ann_path": ann_dir + "bat_logger_qeop_empty.json",
|
||||
"wav_path": wav_dir + "bat_logger_qeop_empty/audio/",
|
||||
}
|
||||
)
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "bat_logger_2016_empty",
|
||||
"is_test": False,
|
||||
"is_binary": True,
|
||||
"ann_path": ann_dir + "train_set_bat_logger_2016_empty.json",
|
||||
"wav_path": wav_dir + "bat_logger_2016/audio/",
|
||||
}
|
||||
)
|
||||
# train_sets.append({'dataset_name': 'brazil_data_binary',
|
||||
# 'is_test': False,
|
||||
# 'ann_path': ann_dir + 'brazil_data_binary.json',
|
||||
# 'wav_path': wav_dir + 'brazil_data/audio/'})
|
||||
|
||||
train_sets.append({'dataset_name': 'echobank',
|
||||
'is_test': False,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'Echobank_train_expert_TRAIN.json',
|
||||
'wav_path': wav_dir + 'echobank/audio/'})
|
||||
train_sets.append({'dataset_name': 'sn_scot_nor',
|
||||
'is_test': False,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'sn_scot_nor_0.5_expert_TRAIN.json',
|
||||
'wav_path': wav_dir + 'sn_scot_nor/audio/'})
|
||||
train_sets.append({'dataset_name': 'BCT_1_sec',
|
||||
'is_test': False,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'BCT_1_sec_train_expert_TRAIN.json',
|
||||
'wav_path': wav_dir + 'BCT_1_sec/audio/'})
|
||||
train_sets.append({'dataset_name': 'bcireland',
|
||||
'is_test': False,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'bcireland_expert_TRAIN.json',
|
||||
'wav_path': wav_dir + 'bcireland/audio/'})
|
||||
train_sets.append({'dataset_name': 'rhinolophus_steve_BCT',
|
||||
'is_test': False,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'rhinolophus_steve_BCT_expert_TRAIN.json',
|
||||
'wav_path': wav_dir + 'rhinolophus_steve_BCT/audio/'})
|
||||
train_sets.append({'dataset_name': 'bat_data_martyn_2018',
|
||||
'is_test': False,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TRAIN.json',
|
||||
'wav_path': wav_dir + 'bat_data_martyn_2018/audio/'})
|
||||
train_sets.append({'dataset_name': 'bat_data_martyn_2018_test',
|
||||
'is_test': False,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TRAIN.json',
|
||||
'wav_path': wav_dir + 'bat_data_martyn_2018_test/audio/'})
|
||||
train_sets.append({'dataset_name': 'bat_data_martyn_2019',
|
||||
'is_test': False,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TRAIN.json',
|
||||
'wav_path': wav_dir + 'bat_data_martyn_2019/audio/'})
|
||||
train_sets.append({'dataset_name': 'bat_data_martyn_2019_test',
|
||||
'is_test': False,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TRAIN.json',
|
||||
'wav_path': wav_dir + 'bat_data_martyn_2019_test/audio/'})
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "echobank",
|
||||
"is_test": False,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir + "Echobank_train_expert_TRAIN.json",
|
||||
"wav_path": wav_dir + "echobank/audio/",
|
||||
}
|
||||
)
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "sn_scot_nor",
|
||||
"is_test": False,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir + "sn_scot_nor_0.5_expert_TRAIN.json",
|
||||
"wav_path": wav_dir + "sn_scot_nor/audio/",
|
||||
}
|
||||
)
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "BCT_1_sec",
|
||||
"is_test": False,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir + "BCT_1_sec_train_expert_TRAIN.json",
|
||||
"wav_path": wav_dir + "BCT_1_sec/audio/",
|
||||
}
|
||||
)
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "bcireland",
|
||||
"is_test": False,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir + "bcireland_expert_TRAIN.json",
|
||||
"wav_path": wav_dir + "bcireland/audio/",
|
||||
}
|
||||
)
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "rhinolophus_steve_BCT",
|
||||
"is_test": False,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir + "rhinolophus_steve_BCT_expert_TRAIN.json",
|
||||
"wav_path": wav_dir + "rhinolophus_steve_BCT/audio/",
|
||||
}
|
||||
)
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "bat_data_martyn_2018",
|
||||
"is_test": False,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir
|
||||
+ "BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TRAIN.json",
|
||||
"wav_path": wav_dir + "bat_data_martyn_2018/audio/",
|
||||
}
|
||||
)
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "bat_data_martyn_2018_test",
|
||||
"is_test": False,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir
|
||||
+ "BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TRAIN.json",
|
||||
"wav_path": wav_dir + "bat_data_martyn_2018_test/audio/",
|
||||
}
|
||||
)
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "bat_data_martyn_2019",
|
||||
"is_test": False,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir
|
||||
+ "BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TRAIN.json",
|
||||
"wav_path": wav_dir + "bat_data_martyn_2019/audio/",
|
||||
}
|
||||
)
|
||||
train_sets.append(
|
||||
{
|
||||
"dataset_name": "bat_data_martyn_2019_test",
|
||||
"is_test": False,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir
|
||||
+ "BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TRAIN.json",
|
||||
"wav_path": wav_dir + "bat_data_martyn_2019_test/audio/",
|
||||
}
|
||||
)
|
||||
|
||||
# train_sets.append({'dataset_name': 'bat_data_martyn_2021_train',
|
||||
# 'is_test': False,
|
||||
@@ -171,51 +278,91 @@ def split_same(ann_dir, wav_dir, load_extra=True):
|
||||
# 'wav_path': wav_dir + 'volunteers_2021/audio/'})
|
||||
|
||||
test_sets = []
|
||||
test_sets.append({'dataset_name': 'echobank',
|
||||
'is_test': True,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'Echobank_train_expert_TEST.json',
|
||||
'wav_path': wav_dir + 'echobank/audio/'})
|
||||
test_sets.append({'dataset_name': 'sn_scot_nor',
|
||||
'is_test': True,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'sn_scot_nor_0.5_expert_TEST.json',
|
||||
'wav_path': wav_dir + 'sn_scot_nor/audio/'})
|
||||
test_sets.append({'dataset_name': 'BCT_1_sec',
|
||||
'is_test': True,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'BCT_1_sec_train_expert_TEST.json',
|
||||
'wav_path': wav_dir + 'BCT_1_sec/audio/'})
|
||||
test_sets.append({'dataset_name': 'bcireland',
|
||||
'is_test': True,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'bcireland_expert_TEST.json',
|
||||
'wav_path': wav_dir + 'bcireland/audio/'})
|
||||
test_sets.append({'dataset_name': 'rhinolophus_steve_BCT',
|
||||
'is_test': True,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'rhinolophus_steve_BCT_expert_TEST.json',
|
||||
'wav_path': wav_dir + 'rhinolophus_steve_BCT/audio/'})
|
||||
test_sets.append({'dataset_name': 'bat_data_martyn_2018',
|
||||
'is_test': True,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TEST.json',
|
||||
'wav_path': wav_dir + 'bat_data_martyn_2018/audio/'})
|
||||
test_sets.append({'dataset_name': 'bat_data_martyn_2018_test',
|
||||
'is_test': True,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TEST.json',
|
||||
'wav_path': wav_dir + 'bat_data_martyn_2018_test/audio/'})
|
||||
test_sets.append({'dataset_name': 'bat_data_martyn_2019',
|
||||
'is_test': True,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TEST.json',
|
||||
'wav_path': wav_dir + 'bat_data_martyn_2019/audio/'})
|
||||
test_sets.append({'dataset_name': 'bat_data_martyn_2019_test',
|
||||
'is_test': True,
|
||||
'is_binary': False,
|
||||
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TEST.json',
|
||||
'wav_path': wav_dir + 'bat_data_martyn_2019_test/audio/'})
|
||||
test_sets.append(
|
||||
{
|
||||
"dataset_name": "echobank",
|
||||
"is_test": True,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir + "Echobank_train_expert_TEST.json",
|
||||
"wav_path": wav_dir + "echobank/audio/",
|
||||
}
|
||||
)
|
||||
test_sets.append(
|
||||
{
|
||||
"dataset_name": "sn_scot_nor",
|
||||
"is_test": True,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir + "sn_scot_nor_0.5_expert_TEST.json",
|
||||
"wav_path": wav_dir + "sn_scot_nor/audio/",
|
||||
}
|
||||
)
|
||||
test_sets.append(
|
||||
{
|
||||
"dataset_name": "BCT_1_sec",
|
||||
"is_test": True,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir + "BCT_1_sec_train_expert_TEST.json",
|
||||
"wav_path": wav_dir + "BCT_1_sec/audio/",
|
||||
}
|
||||
)
|
||||
test_sets.append(
|
||||
{
|
||||
"dataset_name": "bcireland",
|
||||
"is_test": True,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir + "bcireland_expert_TEST.json",
|
||||
"wav_path": wav_dir + "bcireland/audio/",
|
||||
}
|
||||
)
|
||||
test_sets.append(
|
||||
{
|
||||
"dataset_name": "rhinolophus_steve_BCT",
|
||||
"is_test": True,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir + "rhinolophus_steve_BCT_expert_TEST.json",
|
||||
"wav_path": wav_dir + "rhinolophus_steve_BCT/audio/",
|
||||
}
|
||||
)
|
||||
test_sets.append(
|
||||
{
|
||||
"dataset_name": "bat_data_martyn_2018",
|
||||
"is_test": True,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir
|
||||
+ "BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TEST.json",
|
||||
"wav_path": wav_dir + "bat_data_martyn_2018/audio/",
|
||||
}
|
||||
)
|
||||
test_sets.append(
|
||||
{
|
||||
"dataset_name": "bat_data_martyn_2018_test",
|
||||
"is_test": True,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir
|
||||
+ "BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TEST.json",
|
||||
"wav_path": wav_dir + "bat_data_martyn_2018_test/audio/",
|
||||
}
|
||||
)
|
||||
test_sets.append(
|
||||
{
|
||||
"dataset_name": "bat_data_martyn_2019",
|
||||
"is_test": True,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir
|
||||
+ "BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TEST.json",
|
||||
"wav_path": wav_dir + "bat_data_martyn_2019/audio/",
|
||||
}
|
||||
)
|
||||
test_sets.append(
|
||||
{
|
||||
"dataset_name": "bat_data_martyn_2019_test",
|
||||
"is_test": True,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_dir
|
||||
+ "BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TEST.json",
|
||||
"wav_path": wav_dir + "bat_data_martyn_2019_test/audio/",
|
||||
}
|
||||
)
|
||||
|
||||
# test_sets.append({'dataset_name': 'bat_data_martyn_2021_test',
|
||||
# 'is_test': True,
|
||||
|
@@ -1,42 +1,52 @@
import numpy as np
import random
import os
import glob
import json
import os
import random

import numpy as np


def write_notes_file(file_name, text):
    with open(file_name, 'a') as da:
        da.write(text + '\n')
    with open(file_name, "a") as da:
        da.write(text + "\n")


def get_blank_dataset_dict(dataset_name, is_test, ann_path, wav_path):
    ddict = {'dataset_name': dataset_name, 'is_test': is_test, 'is_binary': False,
             'ann_path': ann_path, 'wav_path': wav_path}
    ddict = {
        "dataset_name": dataset_name,
        "is_test": is_test,
        "is_binary": False,
        "ann_path": ann_path,
        "wav_path": wav_path,
    }
    return ddict


def get_short_class_names(class_names, str_len=3):
    class_names_short = []
    for cc in class_names:
        class_names_short.append(' '.join([sp[:str_len] for sp in cc.split(' ')]))
        class_names_short.append(
            " ".join([sp[:str_len] for sp in cc.split(" ")])
        )
    return class_names_short
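# Hedged usage sketch, not part of the commit: with the default str_len=3 each
# word of a species name is truncated, e.g. "Myotis daubentonii" -> "Myo dau".
# The species names below are illustrative only.
example_names = ["Myotis daubentonii", "Pipistrellus pipistrellus"]
assert get_short_class_names(example_names) == ["Myo dau", "Pip pip"]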


def remove_dupes(data_train, data_test):
    test_ids = [dd['id'] for dd in data_test]
    test_ids = [dd["id"] for dd in data_test]
    data_train_prune = []
    for aa in data_train:
        if aa['id'] not in test_ids:
        if aa["id"] not in test_ids:
            data_train_prune.append(aa)
    diff = len(data_train) - len(data_train_prune)
    if diff != 0:
        print(diff, 'items removed from train set')
        print(diff, "items removed from train set")
    return data_train_prune


def get_genus_mapping(class_names):
    genus_names, genus_mapping = np.unique([cc.split(' ')[0] for cc in class_names], return_inverse=True)
    genus_names, genus_mapping = np.unique(
        [cc.split(" ")[0] for cc in class_names], return_inverse=True
    )
    return genus_names.tolist(), genus_mapping.tolist()
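# Hedged usage sketch, not part of the commit: np.unique(..., return_inverse=True)
# yields the sorted unique genus names and, for each input class, the index of
# its genus. The species names below are illustrative only.
genus_names, genus_mapping = get_genus_mapping(
    ["Myotis daubentonii", "Myotis nattereri", "Plecotus auritus"]
)
assert genus_names == ["Myotis", "Plecotus"]
assert genus_mapping == [0, 0, 1]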
@@ -47,97 +57,110 @@ def standardize_low_freq(data, class_of_interest):
|
||||
low_freqs = []
|
||||
high_freqs = []
|
||||
for dd in data:
|
||||
for aa in dd['annotation']:
|
||||
if aa['class'] == class_of_interest:
|
||||
low_freqs.append(aa['low_freq'])
|
||||
high_freqs.append(aa['high_freq'])
|
||||
for aa in dd["annotation"]:
|
||||
if aa["class"] == class_of_interest:
|
||||
low_freqs.append(aa["low_freq"])
|
||||
high_freqs.append(aa["high_freq"])
|
||||
|
||||
low_mean = np.mean(low_freqs)
|
||||
high_mean = np.mean(high_freqs)
|
||||
assert(low_mean < high_mean)
|
||||
assert low_mean < high_mean
|
||||
|
||||
print('\nStandardizing low and high frequency for:')
|
||||
print("\nStandardizing low and high frequency for:")
|
||||
print(class_of_interest)
|
||||
print('low: ', round(low_mean, 2))
|
||||
print('high: ', round(high_mean, 2))
|
||||
print("low: ", round(low_mean, 2))
|
||||
print("high: ", round(high_mean, 2))
|
||||
|
||||
# only set the low freq, high stays the same
|
||||
# assumes that low_mean < high_mean
|
||||
for dd in data:
|
||||
for aa in dd['annotation']:
|
||||
if aa['class'] == class_of_interest:
|
||||
aa['low_freq'] = low_mean
|
||||
if aa['high_freq'] < low_mean:
|
||||
aa['high_freq'] = high_mean
|
||||
for aa in dd["annotation"]:
|
||||
if aa["class"] == class_of_interest:
|
||||
aa["low_freq"] = low_mean
|
||||
if aa["high_freq"] < low_mean:
|
||||
aa["high_freq"] = high_mean
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def load_set_of_anns(data, classes_to_ignore=[], events_of_interest=None,
|
||||
convert_to_genus=False, verbose=True, list_of_anns=False,
|
||||
filter_issues=False, name_replace=False):
|
||||
def load_set_of_anns(
|
||||
data,
|
||||
classes_to_ignore=[],
|
||||
events_of_interest=None,
|
||||
convert_to_genus=False,
|
||||
verbose=True,
|
||||
list_of_anns=False,
|
||||
filter_issues=False,
|
||||
name_replace=False,
|
||||
):
|
||||
|
||||
    # load the annotations
    anns = []
    if list_of_anns:
        # path to list of individual json files
        anns.extend(load_anns_from_path(data['ann_path'], data['wav_path']))
        anns.extend(load_anns_from_path(data["ann_path"], data["wav_path"]))
    else:
        # dictionary of datasets
        for dd in data:
            anns.extend(load_anns(dd['ann_path'], dd['wav_path']))
            anns.extend(load_anns(dd["ann_path"], dd["wav_path"]))

    # discard unannotated files
    anns = [aa for aa in anns if aa['annotated'] is True]
    anns = [aa for aa in anns if aa["annotated"] is True]

    # filter files that have annotation issues - if the input is a dictionary of
    # datasets, this will likely have already been done
    if filter_issues:
        anns = [aa for aa in anns if aa['issues'] is False]
        anns = [aa for aa in anns if aa["issues"] is False]
|
||||
|
||||
# check for some basic formatting errors with class names
|
||||
for ann in anns:
|
||||
for aa in ann['annotation']:
|
||||
aa['class'] = aa['class'].strip()
|
||||
for aa in ann["annotation"]:
|
||||
aa["class"] = aa["class"].strip()
|
||||
|
||||
# only load specified events - i.e. types of calls
|
||||
if events_of_interest is not None:
|
||||
for ann in anns:
|
||||
filtered_events = []
|
||||
for aa in ann['annotation']:
|
||||
if aa['event'] in events_of_interest:
|
||||
for aa in ann["annotation"]:
|
||||
if aa["event"] in events_of_interest:
|
||||
filtered_events.append(aa)
|
||||
ann['annotation'] = filtered_events
|
||||
ann["annotation"] = filtered_events
|
||||
|
||||
# change class names
|
||||
# replace_names will be a dictionary mapping input name to output
|
||||
if type(name_replace) is dict:
|
||||
for ann in anns:
|
||||
for aa in ann['annotation']:
|
||||
if aa['class'] in name_replace:
|
||||
aa['class'] = name_replace[aa['class']]
|
||||
for aa in ann["annotation"]:
|
||||
if aa["class"] in name_replace:
|
||||
aa["class"] = name_replace[aa["class"]]
|
||||
|
||||
# convert everything to genus name
|
||||
if convert_to_genus:
|
||||
for ann in anns:
|
||||
for aa in ann['annotation']:
|
||||
aa['class'] = aa['class'].split(' ')[0]
|
||||
for aa in ann["annotation"]:
|
||||
aa["class"] = aa["class"].split(" ")[0]
|
||||
|
||||
# get unique class names
|
||||
class_names_all = []
|
||||
for ann in anns:
|
||||
for aa in ann['annotation']:
|
||||
if aa['class'] not in classes_to_ignore:
|
||||
class_names_all.append(aa['class'])
|
||||
for aa in ann["annotation"]:
|
||||
if aa["class"] not in classes_to_ignore:
|
||||
class_names_all.append(aa["class"])
|
||||
|
||||
class_names, class_cnts = np.unique(class_names_all, return_counts=True)
|
||||
class_inv_freq = (class_cnts.sum() / (len(class_names) * class_cnts.astype(np.float32)))
|
||||
class_inv_freq = class_cnts.sum() / (
|
||||
len(class_names) * class_cnts.astype(np.float32)
|
||||
)
|
||||
|
||||
if verbose:
|
||||
print('Class count:')
|
||||
print("Class count:")
|
||||
str_len = np.max([len(cc) for cc in class_names]) + 5
|
||||
for cc in range(len(class_names)):
|
||||
print(str(cc).ljust(5) + class_names[cc].ljust(str_len) + str(class_cnts[cc]))
|
||||
print(
|
||||
str(cc).ljust(5)
|
||||
+ class_names[cc].ljust(str_len)
|
||||
+ str(class_cnts[cc])
|
||||
)
|
||||
|
||||
if len(classes_to_ignore) == 0:
|
||||
return anns
|
||||
@@ -150,36 +173,37 @@ def load_anns(ann_file_name, raw_audio_dir):
|
||||
anns = json.load(da)
|
||||
|
||||
for aa in anns:
|
||||
aa['file_path'] = raw_audio_dir + aa['id']
|
||||
aa["file_path"] = raw_audio_dir + aa["id"]
|
||||
|
||||
return anns
|
||||
|
||||
|
||||
def load_anns_from_path(ann_file_dir, raw_audio_dir):
|
||||
files = glob.glob(ann_file_dir + '*.json')
|
||||
files = glob.glob(ann_file_dir + "*.json")
|
||||
anns = []
|
||||
for ff in files:
|
||||
with open(ff) as da:
|
||||
ann = json.load(da)
|
||||
ann['file_path'] = raw_audio_dir + ann['id']
|
||||
ann["file_path"] = raw_audio_dir + ann["id"]
|
||||
anns.append(ann)
|
||||
|
||||
return anns
|
||||
|
||||
|
||||
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()
    """Computes and stores the average and current value"""

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    def __init__(self):
        self.reset()

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
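# Hedged usage sketch, not part of the commit: the meter keeps a running mean
# weighted by batch size, as typically used to track the loss across an epoch.
meter = AverageMeter()
meter.update(0.5, n=32)  # batch of 32 samples with mean loss 0.5
meter.update(0.3, n=16)  # batch of 16 samples with mean loss 0.3
assert abs(meter.avg - (0.5 * 32 + 0.3 * 16) / 48) < 1e-9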
475 bat_detect/types.py Normal file
@@ -0,0 +1,475 @@
|
||||
"""Types used in the code base."""
|
||||
from typing import List, NamedTuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
from typing import TypedDict
|
||||
except ImportError:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
try:
|
||||
from typing import Protocol
|
||||
except ImportError:
|
||||
from typing_extensions import Protocol
|
||||
|
||||
|
||||
try:
|
||||
from typing import NotRequired
|
||||
except ImportError:
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Annotation",
|
||||
"DetectionModel",
|
||||
"FileAnnotations",
|
||||
"ModelOutput",
|
||||
"ModelParameters",
|
||||
"NonMaximumSuppressionConfig",
|
||||
"PredictionResults",
|
||||
"ProcessingConfiguration",
|
||||
"ResultParams",
|
||||
"RunResults",
|
||||
"SpectrogramParameters",
|
||||
]
|
||||
|
||||
|
||||
class SpectrogramParameters(TypedDict):
|
||||
"""Parameters for generating spectrograms."""
|
||||
|
||||
fft_win_length: float
|
||||
"""Length of the FFT window in seconds."""
|
||||
|
||||
fft_overlap: float
|
||||
"""Percentage of overlap between FFT windows."""
|
||||
|
||||
spec_height: int
|
||||
"""Height of the spectrogram in pixels."""
|
||||
|
||||
resize_factor: float
|
||||
"""Factor to resize the spectrogram by."""
|
||||
|
||||
spec_divide_factor: int
|
||||
"""Factor to divide the spectrogram by."""
|
||||
|
||||
max_freq: int
|
||||
"""Maximum frequency to display in the spectrogram."""
|
||||
|
||||
min_freq: int
|
||||
"""Minimum frequency to display in the spectrogram."""
|
||||
|
||||
spec_scale: str
|
||||
"""Scale to use for the spectrogram."""
|
||||
|
||||
denoise_spec_avg: bool
|
||||
"""Whether to denoise the spectrogram by averaging."""
|
||||
|
||||
max_scale_spec: bool
|
||||
"""Whether to scale the spectrogram so that its max is 1."""
|
||||
|
||||
|
||||
class ModelParameters(TypedDict):
|
||||
"""Model parameters."""
|
||||
|
||||
model_name: str
|
||||
"""Model name."""
|
||||
|
||||
num_filters: int
|
||||
"""Number of filters."""
|
||||
|
||||
emb_dim: int
|
||||
"""Embedding dimension."""
|
||||
|
||||
ip_height: int
|
||||
"""Input height in pixels."""
|
||||
|
||||
resize_factor: float
|
||||
"""Resize factor."""
|
||||
|
||||
class_names: List[str]
|
||||
"""Class names. The model is trained to detect these classes."""
|
||||
|
||||
|
||||
DictWithClass = TypedDict("DictWithClass", {"class": str})
|
||||
|
||||
|
||||
class Annotation(DictWithClass):
|
||||
"""Format of annotations.
|
||||
|
||||
This is the format of a single annotation as expected by the annotation
|
||||
tool.
|
||||
"""
|
||||
|
||||
start_time: float
|
||||
"""Start time in seconds."""
|
||||
|
||||
end_time: float
|
||||
"""End time in seconds."""
|
||||
|
||||
low_freq: int
|
||||
"""Low frequency in Hz."""
|
||||
|
||||
high_freq: int
|
||||
"""High frequency in Hz."""
|
||||
|
||||
class_prob: float
|
||||
"""Probability of class assignment."""
|
||||
|
||||
det_prob: float
|
||||
"""Probability of detection."""
|
||||
|
||||
individual: str
|
||||
"""Individual ID."""
|
||||
|
||||
event: str
|
||||
"""Type of detected event."""
|
||||
|
||||
|
||||
class FileAnnotations(TypedDict):
|
||||
"""Format of results.
|
||||
|
||||
This is the format of the results expected by the annotation tool.
|
||||
"""
|
||||
|
||||
id: str
|
||||
"""File ID."""
|
||||
|
||||
annotated: bool
|
||||
"""Whether file has been annotated."""
|
||||
|
||||
duration: float
|
||||
"""Duration of audio file."""
|
||||
|
||||
issues: bool
|
||||
"""Whether file has issues."""
|
||||
|
||||
time_exp: float
|
||||
"""Time expansion factor."""
|
||||
|
||||
class_name: str
|
||||
"""Class predicted at file level"""
|
||||
|
||||
notes: str
|
||||
"""Notes of file."""
|
||||
|
||||
annotation: List[Annotation]
|
||||
"""List of annotations."""
|
||||
|
||||
|
||||
class RunResults(TypedDict):
|
||||
"""Run results."""
|
||||
|
||||
pred_dict: FileAnnotations
|
||||
"""Predictions in the format expected by the annotation tool."""
|
||||
|
||||
spec_feats: NotRequired[List[np.ndarray]]
|
||||
"""Spectrogram features."""
|
||||
|
||||
spec_feat_names: NotRequired[List[str]]
|
||||
"""Spectrogram feature names."""
|
||||
|
||||
cnn_feats: NotRequired[List[np.ndarray]]
|
||||
"""CNN features."""
|
||||
|
||||
cnn_feat_names: NotRequired[List[str]]
|
||||
"""CNN feature names."""
|
||||
|
||||
spec_slices: NotRequired[List[np.ndarray]]
|
||||
"""Spectrogram slices."""
|
||||
|
||||
|
||||
class ResultParams(TypedDict):
|
||||
"""Result parameters."""
|
||||
|
||||
class_names: List[str]
|
||||
"""Class names."""
|
||||
|
||||
spec_features: bool
|
||||
"""Whether to return spectrogram features."""
|
||||
|
||||
cnn_features: bool
|
||||
"""Whether to return CNN features."""
|
||||
|
||||
spec_slices: bool
|
||||
"""Whether to return spectrogram slices."""
|
||||
|
||||
|
||||
class ProcessingConfiguration(TypedDict):
|
||||
"""Parameters for processing audio files."""
|
||||
|
||||
# audio parameters
|
||||
target_samp_rate: int
|
||||
"""Target sampling rate of the audio."""
|
||||
|
||||
fft_win_length: float
|
||||
"""Length of the FFT window in seconds."""
|
||||
|
||||
fft_overlap: float
|
||||
"""Length of the FFT window in samples."""
|
||||
|
||||
resize_factor: float
|
||||
"""Factor to resize the spectrogram by."""
|
||||
|
||||
spec_divide_factor: int
|
||||
"""Factor to divide the spectrogram by."""
|
||||
|
||||
spec_height: int
|
||||
"""Height of the spectrogram in pixels."""
|
||||
|
||||
spec_scale: str
|
||||
"""Scale to use for the spectrogram."""
|
||||
|
||||
denoise_spec_avg: bool
|
||||
"""Whether to denoise the spectrogram by averaging."""
|
||||
|
||||
max_scale_spec: bool
|
||||
"""Whether to scale the spectrogram so that its max is 1."""
|
||||
|
||||
scale_raw_audio: bool
|
||||
"""Whether to scale the raw audio to be between -1 and 1."""
|
||||
|
||||
class_names: List[str]
|
||||
"""Names of the classes the model can detect."""
|
||||
|
||||
detection_threshold: float
|
||||
"""Threshold for detection probability."""
|
||||
|
||||
time_expansion: Optional[float]
|
||||
"""Time expansion factor of the processed recordings."""
|
||||
|
||||
top_n: int
|
||||
"""Number of top detections to keep."""
|
||||
|
||||
return_raw_preds: bool
|
||||
"""Whether to return raw predictions."""
|
||||
|
||||
max_duration: Optional[float]
|
||||
"""Maximum duration of audio file to process in seconds."""
|
||||
|
||||
nms_kernel_size: int
|
||||
"""Size of the kernel for non-maximum suppression."""
|
||||
|
||||
max_freq: int
|
||||
"""Maximum frequency to consider in Hz."""
|
||||
|
||||
min_freq: int
|
||||
"""Minimum frequency to consider in Hz."""
|
||||
|
||||
nms_top_k_per_sec: float
|
||||
"""Number of top detections to keep per second."""
|
||||
|
||||
quiet: bool
|
||||
"""Whether to suppress output."""
|
||||
|
||||
chunk_size: float
|
||||
"""Size of chunks to process in seconds."""
|
||||
|
||||
cnn_features: bool
|
||||
"""Whether to return CNN features."""
|
||||
|
||||
spec_features: bool
|
||||
"""Whether to return spectrogram features."""
|
||||
|
||||
spec_slices: bool
|
||||
"""Whether to return spectrogram slices."""
|
||||
|
||||
|
||||
class ModelOutput(NamedTuple):
|
||||
"""Output of the detection model.
|
||||
|
||||
Each of the tensors has a shape of
|
||||
|
||||
`(batch_size, num_channels, spec_height, spec_width)`.
|
||||
|
||||
Where `spec_height` and `spec_width` are the height and width of the
|
||||
input spectrograms.
|
||||
|
||||
They contain localised information of:
|
||||
|
||||
1. The probability of a bounding box detection at the given location.
|
||||
2. The predicted size of the bounding box at the given location.
|
||||
3. The probabilities of each class at the given location.
|
||||
4. Same as 3. but before softmax.
|
||||
5. Features used to make the predictions at the given location.
|
||||
"""
|
||||
|
||||
pred_det: torch.Tensor
|
||||
"""Tensor with predict detection probabilities."""
|
||||
|
||||
pred_size: torch.Tensor
|
||||
"""Tensor with predicted bounding box sizes."""
|
||||
|
||||
pred_class: torch.Tensor
|
||||
"""Tensor with predicted class probabilities."""
|
||||
|
||||
pred_class_un_norm: torch.Tensor
|
||||
"""Tensor with predicted class probabilities before softmax."""
|
||||
|
||||
features: torch.Tensor
|
||||
"""Tensor with intermediate features."""
|
||||
|
||||
|
||||
class PredictionResults(TypedDict):
|
||||
"""Results of the prediction.
|
||||
|
||||
Each key is a list of length `num_detections` containing the
|
||||
corresponding values for each detection.
|
||||
"""
|
||||
|
||||
det_probs: np.ndarray
|
||||
"""Detection probabilities."""
|
||||
|
||||
x_pos: np.ndarray
|
||||
"""X position of the detection in pixels."""
|
||||
|
||||
y_pos: np.ndarray
|
||||
"""Y position of the detection in pixels."""
|
||||
|
||||
bb_width: np.ndarray
|
||||
"""Width of the detection in pixels."""
|
||||
|
||||
bb_height: np.ndarray
|
||||
"""Height of the detection in pixels."""
|
||||
|
||||
start_times: np.ndarray
|
||||
"""Start times of the detections in seconds."""
|
||||
|
||||
end_times: np.ndarray
|
||||
"""End times of the detections in seconds."""
|
||||
|
||||
low_freqs: np.ndarray
|
||||
"""Low frequencies of the detections in Hz."""
|
||||
|
||||
high_freqs: np.ndarray
|
||||
"""High frequencies of the detections in Hz."""
|
||||
|
||||
class_probs: np.ndarray
|
||||
"""Class probabilities."""
|
||||
|
||||
|
||||
class DetectionModel(Protocol):
|
||||
"""Protocol for detection models.
|
||||
|
||||
This protocol is used to define the interface for the detection models.
|
||||
This allows us to use the same code for training and inference, even
|
||||
though the models are different.
|
||||
"""
|
||||
|
||||
num_classes: int
|
||||
"""Number of classes the model can classify."""
|
||||
|
||||
emb_dim: int
|
||||
"""Dimension of the embedding vector."""
|
||||
|
||||
num_filts: int
|
||||
"""Number of filters in the model."""
|
||||
|
||||
resize_factor: float
|
||||
"""Factor by which the input is resized."""
|
||||
|
||||
ip_height_rs: int
|
||||
"""Height of the input image."""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ip: torch.Tensor,
|
||||
return_feats: bool = False,
|
||||
) -> ModelOutput:
|
||||
"""Forward pass of the model."""
|
||||
...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
ip: torch.Tensor,
|
||||
return_feats: bool = False,
|
||||
) -> ModelOutput:
|
||||
"""Forward pass of the model."""
|
||||
...
|
||||
|
||||
|
||||
class NonMaximumSuppressionConfig(TypedDict):
|
||||
"""Configuration for non-maximum suppression."""
|
||||
|
||||
nms_kernel_size: int
|
||||
"""Size of the kernel for non-maximum suppression."""
|
||||
|
||||
max_freq: int
|
||||
"""Maximum frequency to consider in Hz."""
|
||||
|
||||
min_freq: int
|
||||
"""Minimum frequency to consider in Hz."""
|
||||
|
||||
fft_win_length: float
|
||||
"""Length of the FFT window in seconds."""
|
||||
|
||||
fft_overlap: float
|
||||
"""Overlap of the FFT windows in seconds."""
|
||||
|
||||
resize_factor: float
|
||||
"""Factor by which the input was resized."""
|
||||
|
||||
nms_top_k_per_sec: float
|
||||
"""Number of top detections to keep per second."""
|
||||
|
||||
detection_threshold: float
|
||||
"""Threshold for detection probability."""
|
||||
|
||||
|
||||
class HeatmapParameters(TypedDict):
|
||||
"""Parameters that control the heatmap generation function."""
|
||||
|
||||
class_names: List[str]
|
||||
|
||||
fft_win_length: float
|
||||
"""Length of the FFT window in seconds."""
|
||||
|
||||
fft_overlap: float
|
||||
"""Percentage of the FFT windows overlap."""
|
||||
|
||||
resize_factor: float
|
||||
"""Factor by which the input was resized."""
|
||||
|
||||
min_freq: int
|
||||
"""Minimum frequency to consider in Hz."""
|
||||
|
||||
max_freq: int
|
||||
"""Maximum frequency to consider in Hz."""
|
||||
|
||||
target_sigma: float
|
||||
"""Sigma for the Gaussian kernel. Controls the width of the points in
|
||||
the heatmap."""
|
||||
|
||||
|
||||
class AnnotationGroup(TypedDict):
|
||||
"""Group of annotations.
|
||||
|
||||
Each key is a numpy array of length `num_annotations` containing the
|
||||
corresponding values for each annotation.
|
||||
"""
|
||||
|
||||
start_times: np.ndarray
|
||||
"""Start times of the annotations in seconds."""
|
||||
|
||||
end_times: np.ndarray
|
||||
"""End times of the annotations in seconds."""
|
||||
|
||||
low_freqs: np.ndarray
|
||||
"""Low frequencies of the annotations in Hz."""
|
||||
|
||||
high_freqs: np.ndarray
|
||||
"""High frequencies of the annotations in Hz."""
|
||||
|
||||
class_ids: np.ndarray
|
||||
"""Class IDs of the annotations."""
|
||||
|
||||
individual_ids: np.ndarray
|
||||
"""Individual IDs of the annotations."""
|
||||
|
||||
x_inds: NotRequired[np.ndarray]
|
||||
"""X coordinate of the annotations in the spectrogram."""
|
||||
|
||||
y_inds: NotRequired[np.ndarray]
|
||||
"""Y coordinate of the annotations in the spectrogram."""
|
@@ -1,91 +1,207 @@
|
||||
import numpy as np
|
||||
from . import wavfile
|
||||
import warnings
|
||||
import torch
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import librosa
|
||||
import librosa.core.spectrum
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from bat_detect.detector.parameters import (
|
||||
DENOISE_SPEC_AVG,
|
||||
DETECTION_THRESHOLD,
|
||||
FFT_OVERLAP,
|
||||
FFT_WIN_LENGTH_S,
|
||||
MAX_FREQ_HZ,
|
||||
MAX_SCALE_SPEC,
|
||||
MIN_FREQ_HZ,
|
||||
NMS_KERNEL_SIZE,
|
||||
NMS_TOP_K_PER_SEC,
|
||||
RESIZE_FACTOR,
|
||||
SCALE_RAW_AUDIO,
|
||||
SPEC_DIVIDE_FACTOR,
|
||||
SPEC_HEIGHT,
|
||||
SPEC_SCALE,
|
||||
)
|
||||
|
||||
from . import wavfile
|
||||
|
||||
try:
|
||||
from typing import TypedDict
|
||||
except ImportError:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
__all__ = [
|
||||
"load_audio",
|
||||
"generate_spectrogram",
|
||||
"pad_audio",
|
||||
"SpectrogramParameters",
|
||||
"DEFAULT_SPECTROGRAM_PARAMETERS",
|
||||
]
|
||||
|
||||
|
||||
def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
    nfft = np.floor(fft_win_length*sampling_rate) # int() uses floor
    noverlap = np.floor(fft_overlap*nfft)
    return (time_in_file*sampling_rate-noverlap) / (nfft - noverlap)
    nfft = np.floor(fft_win_length * sampling_rate)  # int() uses floor
    noverlap = np.floor(fft_overlap * nfft)
    return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)


# NOTE this is also defined in post_process
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
    nfft = np.floor(fft_win_length*sampling_rate)
    noverlap = np.floor(fft_overlap*nfft)
    return ((x_pos*(nfft - noverlap)) + noverlap) / sampling_rate
    #return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
    nfft = np.floor(fft_win_length * sampling_rate)
    noverlap = np.floor(fft_overlap * nfft)
    return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
    # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
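# Hedged worked example, not part of the commit: at a hypothetical 256 kHz
# sampling rate with fft_win_length=0.002 s and fft_overlap=0.75, nfft is 512
# samples and the hop is 128 samples, so one spectrogram column covers 0.5 ms.
x = time_to_x_coords(0.1, 256000, 0.002, 0.75)  # (25600 - 384) / 128 = 197.0
t = x_coords_to_time(x, 256000, 0.002, 0.75)    # maps back to 0.1 s
assert abs(t - 0.1) < 1e-9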
def generate_spectrogram(audio, sampling_rate, params, return_spec_for_viz=False, check_spec_size=True):
|
||||
def generate_spectrogram(
|
||||
audio,
|
||||
sampling_rate,
|
||||
params,
|
||||
return_spec_for_viz=False,
|
||||
check_spec_size=True,
|
||||
):
|
||||
|
||||
# generate spectrogram
|
||||
spec = gen_mag_spectrogram(audio, sampling_rate, params['fft_win_length'], params['fft_overlap'])
|
||||
spec = gen_mag_spectrogram(
|
||||
audio,
|
||||
sampling_rate,
|
||||
params["fft_win_length"],
|
||||
params["fft_overlap"],
|
||||
)
|
||||
|
||||
# crop to min/max freq
|
||||
max_freq = round(params['max_freq']*params['fft_win_length'])
|
||||
min_freq = round(params['min_freq']*params['fft_win_length'])
|
||||
max_freq = round(params["max_freq"] * params["fft_win_length"])
|
||||
min_freq = round(params["min_freq"] * params["fft_win_length"])
|
||||
if spec.shape[0] < max_freq:
|
||||
freq_pad = max_freq - spec.shape[0]
|
||||
spec = np.vstack((np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec))
|
||||
spec_cropped = spec[-max_freq:spec.shape[0]-min_freq, :]
|
||||
spec = np.vstack(
|
||||
(np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec)
|
||||
)
|
||||
spec_cropped = spec[-max_freq : spec.shape[0] - min_freq, :]
|
||||
|
||||
if params['spec_scale'] == 'log':
|
||||
log_scaling = 2.0 * (1.0 / sampling_rate) * (1.0/(np.abs(np.hanning(int(params['fft_win_length']*sampling_rate)))**2).sum())
|
||||
#log_scaling = (1.0 / sampling_rate)*0.1
|
||||
#log_scaling = (1.0 / sampling_rate)*10e4
|
||||
spec = np.log1p(log_scaling*spec_cropped)
|
||||
elif params['spec_scale'] == 'pcen':
|
||||
if params["spec_scale"] == "log":
|
||||
log_scaling = (
|
||||
2.0
|
||||
* (1.0 / sampling_rate)
|
||||
* (
|
||||
1.0
|
||||
/ (
|
||||
np.abs(
|
||||
np.hanning(
|
||||
int(params["fft_win_length"] * sampling_rate)
|
||||
)
|
||||
)
|
||||
** 2
|
||||
).sum()
|
||||
)
|
||||
)
|
||||
# log_scaling = (1.0 / sampling_rate)*0.1
|
||||
# log_scaling = (1.0 / sampling_rate)*10e4
|
||||
spec = np.log1p(log_scaling * spec_cropped)
|
||||
elif params["spec_scale"] == "pcen":
|
||||
spec = pcen(spec_cropped, sampling_rate)
|
||||
elif params['spec_scale'] == 'none':
|
||||
|
||||
elif params["spec_scale"] == "none":
|
||||
pass
|
||||
|
||||
if params['denoise_spec_avg']:
|
||||
if params["denoise_spec_avg"]:
|
||||
spec = spec - np.mean(spec, 1)[:, np.newaxis]
|
||||
spec.clip(min=0, out=spec)
|
||||
|
||||
if params['max_scale_spec']:
|
||||
if params["max_scale_spec"]:
|
||||
spec = spec / (spec.max() + 10e-6)
|
||||
|
||||
# needs to be divisible by specific factor - if not it should have been padded
|
||||
#if check_spec_size:
|
||||
#assert((int(spec.shape[0]*params['resize_factor']) % params['spec_divide_factor']) == 0)
|
||||
#assert((int(spec.shape[1]*params['resize_factor']) % params['spec_divide_factor']) == 0)
|
||||
# if check_spec_size:
|
||||
# assert((int(spec.shape[0]*params['resize_factor']) % params['spec_divide_factor']) == 0)
|
||||
# assert((int(spec.shape[1]*params['resize_factor']) % params['spec_divide_factor']) == 0)
|
||||
|
||||
# for visualization purposes - use log scaled spectrogram
|
||||
if return_spec_for_viz:
|
||||
log_scaling = 2.0 * (1.0 / sampling_rate) * (1.0/(np.abs(np.hanning(int(params['fft_win_length']*sampling_rate)))**2).sum())
|
||||
spec_for_viz = np.log1p(log_scaling*spec_cropped).astype(np.float32)
|
||||
log_scaling = (
|
||||
2.0
|
||||
* (1.0 / sampling_rate)
|
||||
* (
|
||||
1.0
|
||||
/ (
|
||||
np.abs(
|
||||
np.hanning(
|
||||
int(params["fft_win_length"] * sampling_rate)
|
||||
)
|
||||
)
|
||||
** 2
|
||||
).sum()
|
||||
)
|
||||
)
|
||||
spec_for_viz = np.log1p(log_scaling * spec_cropped).astype(np.float32)
|
||||
else:
|
||||
spec_for_viz = None
|
||||
|
||||
return spec, spec_for_viz
|
||||
|
||||
|
||||
def load_audio_file(audio_file, time_exp_fact, target_samp_rate, scale=False, max_duration=False):
|
||||
def load_audio(
|
||||
audio_file: str,
|
||||
time_exp_fact: float,
|
||||
target_samp_rate: int,
|
||||
scale: bool = False,
|
||||
max_duration: Optional[float] = None,
|
||||
) -> Tuple[int, np.ndarray]:
|
||||
"""Load an audio file and resample it to the target sampling rate.
|
||||
|
||||
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
|
||||
Only mono files are supported.
|
||||
|
||||
    Args:
        audio_file (str): Path to the audio file.
        time_exp_fact (float): Time expansion factor of the recording.
        target_samp_rate (int): Target sampling rate.
        scale (bool): Whether to scale the audio to [-1, 1].
        max_duration (float): Maximum duration of the audio in seconds.
|
||||
|
||||
Returns:
|
||||
sampling_rate: The sampling rate of the audio.
|
||||
audio_raw: The audio signal in a numpy array.
|
||||
|
||||
Raises:
|
||||
ValueError: If the audio file is stereo.
|
||||
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=wavfile.WavFileWarning)
|
||||
#sampling_rate, audio_raw = wavfile.read(audio_file)
|
||||
audio_raw, sampling_rate = librosa.load(audio_file, sr=None)
|
||||
warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
|
||||
# sampling_rate, audio_raw = wavfile.read(audio_file)
|
||||
audio_raw, sampling_rate = librosa.load(
|
||||
audio_file,
|
||||
sr=None,
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
if len(audio_raw.shape) > 1:
|
||||
raise Exception('Currently does not handle stereo files')
|
||||
raise ValueError("Currently does not handle stereo files")
|
||||
|
||||
sampling_rate = sampling_rate * time_exp_fact
|
||||
|
||||
# resample - need to do this after correcting for time expansion
|
||||
sampling_rate_old = sampling_rate
|
||||
sampling_rate = target_samp_rate
|
||||
audio_raw = librosa.resample(audio_raw, orig_sr=sampling_rate_old, target_sr=sampling_rate, res_type='polyphase')
|
||||
if sampling_rate_old != sampling_rate:
|
||||
audio_raw = librosa.resample(
|
||||
audio_raw,
|
||||
orig_sr=sampling_rate_old,
|
||||
target_sr=sampling_rate,
|
||||
res_type="polyphase",
|
||||
)
|
||||
|
||||
# clipping maximum duration
|
||||
if max_duration is not False:
|
||||
max_duration = np.minimum(int(sampling_rate*max_duration), audio_raw.shape[0])
|
||||
if max_duration is not None:
|
||||
max_duration = int(
|
||||
np.minimum(
|
||||
int(sampling_rate * max_duration),
|
||||
audio_raw.shape[0],
|
||||
)
|
||||
)
|
||||
audio_raw = audio_raw[:max_duration]
|
||||
|
||||
# convert to float32 and scale
|
||||
audio_raw = audio_raw.astype(np.float32)
|
||||
|
||||
# scale to [-1, 1]
|
||||
if scale:
|
||||
audio_raw = audio_raw - audio_raw.mean()
|
||||
audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
|
||||
@@ -93,38 +209,53 @@ def load_audio_file(audio_file, time_exp_fact, target_samp_rate, scale=False, ma
|
||||
return sampling_rate, audio_raw
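# Hedged usage sketch, not part of the commit: loading a hypothetical 10x
# time-expanded recording and resampling it to 256 kHz. The file name and
# rates are illustrative only.
# sampling_rate, audio = load_audio(
#     "example.wav", time_exp_fact=10.0, target_samp_rate=256000, scale=True
# )
# assert sampling_rate == 256000 and audio.dtype == np.float32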
|
||||
|
||||
|
||||
def pad_audio(audio_raw, fs, ms, overlap_perc, resize_factor, divide_factor, fixed_width=None):
|
||||
def pad_audio(
|
||||
audio_raw,
|
||||
fs,
|
||||
ms,
|
||||
overlap_perc,
|
||||
resize_factor,
|
||||
divide_factor,
|
||||
fixed_width=None,
|
||||
):
|
||||
# Adds zeros to the end of the raw data so that the generated spectrogram
|
||||
# will be evenly divisible by `divide_factor`
|
||||
# Also deals with very short audio clips and fixed_width during training
|
||||
|
||||
# This code could be clearer, clean up
|
||||
nfft = int(ms*fs)
|
||||
noverlap = int(overlap_perc*nfft)
|
||||
nfft = int(ms * fs)
|
||||
noverlap = int(overlap_perc * nfft)
|
||||
step = nfft - noverlap
|
||||
min_size = int(divide_factor*(1.0/resize_factor))
|
||||
spec_width = ((audio_raw.shape[0]-noverlap)//step)
|
||||
min_size = int(divide_factor * (1.0 / resize_factor))
|
||||
spec_width = (audio_raw.shape[0] - noverlap) // step
|
||||
spec_width_rs = spec_width * resize_factor
|
||||
|
||||
if fixed_width is not None and spec_width < fixed_width:
|
||||
# too small
|
||||
# used during training to ensure all the batches are the same size
|
||||
diff = fixed_width*step + noverlap - audio_raw.shape[0]
|
||||
audio_raw = np.hstack((audio_raw, np.zeros(diff, dtype=audio_raw.dtype)))
|
||||
diff = fixed_width * step + noverlap - audio_raw.shape[0]
|
||||
audio_raw = np.hstack(
|
||||
(audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
|
||||
)
|
||||
|
||||
elif fixed_width is not None and spec_width > fixed_width:
|
||||
# too big
|
||||
# used during training to ensure all the batches are the same size
|
||||
diff = fixed_width*step + noverlap - audio_raw.shape[0]
|
||||
diff = fixed_width * step + noverlap - audio_raw.shape[0]
|
||||
audio_raw = audio_raw[:diff]
|
||||
|
||||
elif spec_width_rs < min_size or (np.floor(spec_width_rs) % divide_factor) != 0:
|
||||
elif (
|
||||
spec_width_rs < min_size
|
||||
or (np.floor(spec_width_rs) % divide_factor) != 0
|
||||
):
|
||||
# need to be at least min_size
|
||||
div_amt = np.ceil(spec_width_rs / float(divide_factor))
|
||||
div_amt = np.maximum(1, div_amt)
|
||||
target_size = int(div_amt*divide_factor*(1.0/resize_factor))
|
||||
diff = target_size*step + noverlap - audio_raw.shape[0]
|
||||
audio_raw = np.hstack((audio_raw, np.zeros(diff, dtype=audio_raw.dtype)))
|
||||
target_size = int(div_amt * divide_factor * (1.0 / resize_factor))
|
||||
diff = target_size * step + noverlap - audio_raw.shape[0]
|
||||
audio_raw = np.hstack(
|
||||
(audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
|
||||
)
|
||||
|
||||
return audio_raw
|
||||
|
||||
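To make the padding arithmetic concrete, a small worked example with illustrative numbers (none of these values come from the repository):

# Illustrative numbers only.
fs, ms, overlap_perc = 256000, 0.002, 0.75
nfft = int(ms * fs)                  # 512-sample FFT window
noverlap = int(overlap_perc * nfft)  # 384 overlapping samples
step = nfft - noverlap               # hop of 128 samples

n = 10000                            # samples of audio
spec_width = (n - noverlap) // step  # 75 spectrogram columns

# With resize_factor=0.5 and divide_factor=32:
#   spec_width_rs = 37.5 and min_size = 64, so the final branch pads the
#   audio out to target_size = 128 columns (6768 extra zero samples).
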
@@ -133,14 +264,16 @@ def gen_mag_spectrogram(x, fs, ms, overlap_perc):
    # Computes magnitude spectrogram by specifying time.

    x = x.astype(np.float32)
    nfft = int(ms*fs)
    noverlap = int(overlap_perc*nfft)
    nfft = int(ms * fs)
    noverlap = int(overlap_perc * nfft)

    # window data
    step = nfft - noverlap

    # compute spec
    spec, _ = librosa.core.spectrum._spectrogram(y=x, power=1, n_fft=nfft, hop_length=step, center=False)
    spec, _ = librosa.core.spectrum._spectrogram(
        y=x, power=1, n_fft=nfft, hop_length=step, center=False
    )

    # remove DC component and flip vertical orientation
    spec = np.flipud(spec[1:, :])
@@ -149,8 +282,8 @@ def gen_mag_spectrogram(x, fs, ms, overlap_perc):


def gen_mag_spectrogram_pt(x, fs, ms, overlap_perc):
    nfft = int(ms*fs)
    nstep = round((1.0-overlap_perc)*nfft)
    nfft = int(ms * fs)
    nstep = round((1.0 - overlap_perc) * nfft)

    han_win = torch.hann_window(nfft, periodic=False).to(x.device)

@@ -158,12 +291,14 @@ def gen_mag_spectrogram_pt(x, fs, ms, overlap_perc):
    spec = complex_spec.pow(2.0).sum(-1)

    # remove DC component and flip vertically
    spec = torch.flipud(spec[0, 1:,:])
    spec = torch.flipud(spec[0, 1:, :])

    return spec

def pcen(spec_cropped, sampling_rate):
    # TODO should be passing hop_length too i.e. step
    spec = librosa.pcen(spec_cropped * (2**31), sr=sampling_rate/10).astype(np.float32)
    spec = librosa.pcen(spec_cropped * (2**31), sr=sampling_rate / 10).astype(
        np.float32
    )
    return spec

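One note on the constants here: librosa's PCEN is typically fed magnitudes at roughly 32-bit integer scale, which appears to be why the float spectrogram is multiplied by 2**31, and the `sr / 10` looks like a workaround for the hop length that the TODO says should be passed explicitly. A hedged usage sketch, assuming the inputs below already exist:

# Sketch only - audio and sampling_rate are assumed to exist.
spec = gen_mag_spectrogram(audio, fs=sampling_rate, ms=0.002, overlap_perc=0.75)
spec_cropped = spec[:128, :]  # hypothetical frequency crop
spec_pcen = pcen(spec_cropped, sampling_rate)
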
File diff suppressed because it is too large
@@ -1,63 +1,107 @@
import numpy as np
import matplotlib.pyplot as plt
import json
from sklearn.metrics import confusion_matrix

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import patches
from matplotlib.collections import PatchCollection

from . import audio_utils as au
from sklearn.metrics import confusion_matrix


def create_box_image(spec, fig, detections_ip, start_time, end_time, duration, params, max_val, hide_axis=True, plot_class_names=False):
def create_box_image(
    spec,
    fig,
    detections_ip,
    start_time,
    end_time,
    duration,
    params,
    max_val,
    hide_axis=True,
    plot_class_names=False,
):
    # filter detections
    stop_time = start_time + duration
    detections = []
    for bb in detections_ip:
        if (bb['start_time'] >= start_time) and (bb['start_time'] < stop_time-0.02): #(bb['end_time'] < end_time):
        if (bb["start_time"] >= start_time) and (
            bb["start_time"] < stop_time - 0.02
        ):  # (bb['end_time'] < end_time):
            detections.append(bb)

    # create figure
    freq_scale = 1000  # turn Hz to kHz
    min_freq = params['min_freq']//freq_scale
    max_freq = params['max_freq']//freq_scale
    min_freq = params["min_freq"] // freq_scale
    max_freq = params["max_freq"] // freq_scale
    y_extent = [0, duration, min_freq, max_freq]

    if hide_axis:
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
        ax.set_axis_off()
        fig.add_axes(ax)
    else:
        ax = plt.gca()

    plt.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent, vmin=0, vmax=max_val)
    plt.imshow(
        spec,
        aspect="auto",
        cmap="plasma",
        extent=y_extent,
        vmin=0,
        vmax=max_val,
    )
    boxes = plot_bounding_box_patch_ann(detections, freq_scale, start_time)
    ax.add_collection(PatchCollection(boxes, match_original=True))
    plt.grid(False)

    if plot_class_names:
        for ii, bb in enumerate(boxes):
            txt = ' '.join([sp[:3] for sp in detections_ip[ii]['class'].split(' ')])
            font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()}
            txt = " ".join(
                [sp[:3] for sp in detections_ip[ii]["class"].split(" ")]
            )
            font_info = {
                "color": "white",
                "size": 10,
                "weight": "bold",
                "alpha": bb.get_alpha(),
            }
            y_pos = bb.get_xy()[1] + bb.get_height()
            if y_pos > (max_freq - 10):
                y_pos = max_freq - 10
            plt.gca().text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)

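The notebook changed later in this same commit calls this helper in essentially the following shape; a condensed sketch, where `spec`, `detections`, `spec_duration`, and `params` are assumed to be prepared as in that notebook:

# Condensed from the notebook usage later in this commit.
fig = plt.figure(
    1, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100, frameon=False
)
create_box_image(
    spec, fig, detections,
    start_time=0.0, end_time=spec_duration, duration=spec_duration,
    params=params, max_val=spec.max() * 1.1,
    hide_axis=False, plot_class_names=True,
)
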
def save_ann_spec(op_path, spec, min_freq, max_freq, duration, start_time, title_text='', anns=None):
def save_ann_spec(
    op_path,
    spec,
    min_freq,
    max_freq,
    duration,
    start_time,
    title_text="",
    anns=None,
):
    # create figure and plot boxes
    freq_scale = 1000  # turn Hz to kHz
    min_freq = min_freq//freq_scale
    max_freq = max_freq//freq_scale
    min_freq = min_freq // freq_scale
    max_freq = max_freq // freq_scale
    y_extent = [0, duration, min_freq, max_freq]

    plt.close('all')
    fig = plt.figure(0, figsize=(spec.shape[1]/100, spec.shape[0]/100), dpi=100)
    plt.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent, vmin=0, vmax=spec.max()*1.1)
    plt.close("all")
    fig = plt.figure(
        0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100
    )
    plt.imshow(
        spec,
        aspect="auto",
        cmap="plasma",
        extent=y_extent,
        vmin=0,
        vmax=spec.max() * 1.1,
    )

    plt.ylabel('Freq - kHz')
    plt.xlabel('Time - secs')
    if title_text != '':
    plt.ylabel("Freq - kHz")
    plt.xlabel("Time - secs")
    if title_text != "":
        plt.title(title_text)
    plt.tight_layout()

@@ -66,122 +110,185 @@ def save_ann_spec(op_path, spec, min_freq, max_freq, duration, start_time, title
        boxes = plot_bounding_box_patch_ann(anns, freq_scale, start_time)
        plt.gca().add_collection(PatchCollection(boxes, match_original=True))
        for ii, bb in enumerate(boxes):
            txt = ' '.join([sp[:3] for sp in anns[ii]['class'].split(' ')])
            font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()}
            txt = " ".join([sp[:3] for sp in anns[ii]["class"].split(" ")])
            font_info = {
                "color": "white",
                "size": 10,
                "weight": "bold",
                "alpha": bb.get_alpha(),
            }
            y_pos = bb.get_xy()[1] + bb.get_height()
            if y_pos > (max_freq - 10):
                y_pos = max_freq - 10
            plt.gca().text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)

    print('Saving figure to:', op_path)
    print("Saving figure to:", op_path)
    plt.savefig(op_path)

def plot_pts(fig_id, feats, class_names, colors, marker_size=4.0, plot_legend=False):
def plot_pts(
    fig_id, feats, class_names, colors, marker_size=4.0, plot_legend=False
):
    plt.figure(fig_id)
    un_class, labels = np.unique(class_names, return_inverse=True)
    un_labels = np.unique(labels)
    if un_labels.shape[0] > len(colors):
        colors = [plt.cm.jet(float(ii)/un_labels.shape[0]) for ii in un_labels]
        colors = [
            plt.cm.jet(float(ii) / un_labels.shape[0]) for ii in un_labels
        ]

    for ii, u in enumerate(un_labels):
        inds = np.where(labels==u)[0]
        plt.scatter(feats[inds, 0], feats[inds, 1], c=colors[ii], label=str(un_class[ii]), s=marker_size)
        inds = np.where(labels == u)[0]
        plt.scatter(
            feats[inds, 0],
            feats[inds, 1],
            c=colors[ii],
            label=str(un_class[ii]),
            s=marker_size,
        )
    if plot_legend:
        plt.legend()
    plt.xticks([])
    plt.yticks([])
    plt.title('downsampled features')
    plt.title("downsampled features")

def plot_bounding_box_patch(pred, freq_scale, ecolor='w'):
def plot_bounding_box_patch(pred, freq_scale, ecolor="w"):
    patch_collect = []
    for bb in range(len(pred['start_times'])):
        xx = pred['start_times'][bb]
        ww = pred['end_times'][bb] - pred['start_times'][bb]
        yy = pred['low_freqs'][bb] / freq_scale
        hh = (pred['high_freqs'][bb] - pred['low_freqs'][bb]) / freq_scale
    for bb in range(len(pred["start_times"])):
        xx = pred["start_times"][bb]
        ww = pred["end_times"][bb] - pred["start_times"][bb]
        yy = pred["low_freqs"][bb] / freq_scale
        hh = (pred["high_freqs"][bb] - pred["low_freqs"][bb]) / freq_scale

        if 'det_probs' in pred.keys():
            alpha_val = pred['det_probs'][bb]
        if "det_probs" in pred.keys():
            alpha_val = pred["det_probs"][bb]
        else:
            alpha_val = 1.0
        patch_collect.append(patches.Rectangle((xx, yy), ww, hh, linewidth=1,
                             edgecolor=ecolor, facecolor='none', alpha=alpha_val))
        patch_collect.append(
            patches.Rectangle(
                (xx, yy),
                ww,
                hh,
                linewidth=1,
                edgecolor=ecolor,
                facecolor="none",
                alpha=alpha_val,
            )
        )
    return patch_collect

def plot_bounding_box_patch_ann(anns, freq_scale, start_time):
    patch_collect = []
    for aa in range(len(anns)):
        xx = anns[aa]['start_time'] - start_time
        ww = anns[aa]['end_time'] - anns[aa]['start_time']
        yy = anns[aa]['low_freq'] / freq_scale
        hh = (anns[aa]['high_freq'] - anns[aa]['low_freq']) / freq_scale
        if 'det_prob' in anns[aa]:
            alpha = anns[aa]['det_prob']
        xx = anns[aa]["start_time"] - start_time
        ww = anns[aa]["end_time"] - anns[aa]["start_time"]
        yy = anns[aa]["low_freq"] / freq_scale
        hh = (anns[aa]["high_freq"] - anns[aa]["low_freq"]) / freq_scale
        if "det_prob" in anns[aa]:
            alpha = anns[aa]["det_prob"]
        else:
            alpha = 1.0
        patch_collect.append(patches.Rectangle((xx,yy), ww, hh, linewidth=1,
                             edgecolor='w', facecolor='none', alpha=alpha))
        patch_collect.append(
            patches.Rectangle(
                (xx, yy),
                ww,
                hh,
                linewidth=1,
                edgecolor="w",
                facecolor="none",
                alpha=alpha,
            )
        )
    return patch_collect

def plot_spec(spec, sampling_rate, duration, gt, pred, params, plot_title,
              op_file_name, pred_2d_hm, plot_boxes=True, fixed_aspect=True):
def plot_spec(
    spec,
    sampling_rate,
    duration,
    gt,
    pred,
    params,
    plot_title,
    op_file_name,
    pred_2d_hm,
    plot_boxes=True,
    fixed_aspect=True,
):

    if fixed_aspect:
        # output image will be this width irrespective of the duration of the audio file
        width = 12
    else:
        width = 12*duration
        width = 12 * duration

    fig = plt.figure(1, figsize=(width, 8))
    ax0 = plt.axes([0.05, 0.65, 0.9, 0.30]) # l b w h
    ax0 = plt.axes([0.05, 0.65, 0.9, 0.30])  # l b w h
    ax1 = plt.axes([0.05, 0.33, 0.9, 0.30])
    ax2 = plt.axes([0.05, 0.01, 0.9, 0.30])

    freq_scale = 1000  # turn Hz to kHz
    #duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])
    y_extent = [0, duration, params['min_freq']//freq_scale, params['max_freq']//freq_scale]
    # duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])
    y_extent = [
        0,
        duration,
        params["min_freq"] // freq_scale,
        params["max_freq"] // freq_scale,
    ]

    # plot gt boxes
    ax0.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent)
    ax0.imshow(spec, aspect="auto", cmap="plasma", extent=y_extent)
    ax0.xaxis.set_ticklabels([])
    font_info = {'color': 'white', 'size': 12, 'weight': 'bold'}
    ax0.text(0, params['min_freq']//freq_scale, 'Ground Truth', fontdict=font_info)
    font_info = {"color": "white", "size": 12, "weight": "bold"}
    ax0.text(
        0, params["min_freq"] // freq_scale, "Ground Truth", fontdict=font_info
    )

    plt.grid(False)
    if plot_boxes:
        boxes = plot_bounding_box_patch(gt, freq_scale)
        ax0.add_collection(PatchCollection(boxes, match_original=True))
        for ii, bb in enumerate(boxes):
            class_id = int(gt['class_ids'][ii])
            class_id = int(gt["class_ids"][ii])
            if class_id < 0:
                txt = params['generic_class'][0]
                txt = params["generic_class"][0]
            else:
                txt = params['class_names_short'][class_id]
            font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()}
                txt = params["class_names_short"][class_id]
            font_info = {
                "color": "white",
                "size": 10,
                "weight": "bold",
                "alpha": bb.get_alpha(),
            }
            y_pos = bb.get_xy()[1] + bb.get_height()
            ax0.text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)

    # plot predicted boxes
    ax1.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent)
    ax1.imshow(spec, aspect="auto", cmap="plasma", extent=y_extent)
    ax1.xaxis.set_ticklabels([])
    font_info = {'color': 'white', 'size': 12, 'weight': 'bold'}
    ax1.text(0, params['min_freq']//freq_scale, 'Prediction', fontdict=font_info)
    font_info = {"color": "white", "size": 12, "weight": "bold"}
    ax1.text(
        0, params["min_freq"] // freq_scale, "Prediction", fontdict=font_info
    )

    plt.grid(False)
    if plot_boxes:
        boxes = plot_bounding_box_patch(pred, freq_scale)
        ax1.add_collection(PatchCollection(boxes, match_original=True))
        for ii, bb in enumerate(boxes):
            if pred['class_probs'].shape[0] > len(params['class_names_short']):
                class_id = pred['class_probs'][:-1, ii].argmax()
            if pred["class_probs"].shape[0] > len(params["class_names_short"]):
                class_id = pred["class_probs"][:-1, ii].argmax()
            else:
                class_id = pred['class_probs'][:, ii].argmax()
            txt = params['class_names_short'][class_id]
            font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()}
                class_id = pred["class_probs"][:, ii].argmax()
            txt = params["class_names_short"][class_id]
            font_info = {
                "color": "white",
                "size": 10,
                "weight": "bold",
                "alpha": bb.get_alpha(),
            }
            y_pos = bb.get_xy()[1] + bb.get_height()
            ax1.text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)

@@ -190,10 +297,18 @@ def plot_spec(spec, sampling_rate, duration, gt, pred, params, plot_title,
    min_val = 0.0 if pred_2d_hm.min() > 0.0 else pred_2d_hm.min()
    max_val = 1.0 if pred_2d_hm.max() < 1.0 else pred_2d_hm.max()

    ax2.imshow(pred_2d_hm, aspect='auto', cmap='plasma', extent=y_extent, clim=[min_val, max_val])
    #ax2.xaxis.set_ticklabels([])
    font_info = {'color': 'white', 'size': 12, 'weight': 'bold'}
    ax2.text(0, params['min_freq']//freq_scale, 'Heatmap', fontdict=font_info)
    ax2.imshow(
        pred_2d_hm,
        aspect="auto",
        cmap="plasma",
        extent=y_extent,
        clim=[min_val, max_val],
    )
    # ax2.xaxis.set_ticklabels([])
    font_info = {"color": "white", "size": 12, "weight": "bold"}
    ax2.text(
        0, params["min_freq"] // freq_scale, "Heatmap", fontdict=font_info
    )

    plt.grid(False)

@@ -204,107 +319,149 @@ def plot_spec(spec, sampling_rate, duration, gt, pred, params, plot_title,
    plt.close(1)

def plot_pr_curve(op_dir, plt_title, file_name, results, file_type='png', title_text=''):
    precision = results['precision']
    recall = results['recall']
    avg_prec = results['avg_prec']
def plot_pr_curve(
    op_dir, plt_title, file_name, results, file_type="png", title_text=""
):
    precision = results["precision"]
    recall = results["recall"]
    avg_prec = results["avg_prec"]

    plt.figure(0, figsize=(10,8))
    plt.figure(0, figsize=(10, 8))
    plt.plot(recall, precision)
    plt.ylabel('Precision', fontsize=20)
    plt.xlabel('Recall', fontsize=20)
    if title_text != '':
        plt.title(title_text, fontdict={'fontsize': 28})
    plt.ylabel("Precision", fontsize=20)
    plt.xlabel("Recall", fontsize=20)
    if title_text != "":
        plt.title(title_text, fontdict={"fontsize": 28})
    else:
        plt.title(plt_title + ' {:.3f}\n'.format(avg_prec))
    plt.xlim(0,1.02)
    plt.ylim(0,1.02)
        plt.title(plt_title + " {:.3f}\n".format(avg_prec))
    plt.xlim(0, 1.02)
    plt.ylim(0, 1.02)
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(op_dir + file_name + '.' + file_type)
    plt.savefig(op_dir + file_name + "." + file_type)
    plt.close(0)

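For reference, the `results` dictionary consumed here only needs the three keys read at the top of the function; a minimal sketch with made-up values:

# Minimal illustrative input for plot_pr_curve (values are made up).
results = {
    "precision": [1.0, 0.9, 0.8],
    "recall": [0.1, 0.5, 0.9],
    "avg_prec": 0.82,
}
plot_pr_curve("plots/", "PR curve", "pr_example", results)
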
def plot_pr_curve_class(op_dir, plt_title, file_name, results, file_type='png', title_text=''):
    plt.figure(0, figsize=(10,8))
    plt.ylabel('Precision', fontsize=20)
    plt.xlabel('Recall', fontsize=20)
    plt.xlim(0,1.02)
    plt.ylim(0,1.02)
def plot_pr_curve_class(
    op_dir, plt_title, file_name, results, file_type="png", title_text=""
):
    plt.figure(0, figsize=(10, 8))
    plt.ylabel("Precision", fontsize=20)
    plt.xlabel("Recall", fontsize=20)
    plt.xlim(0, 1.02)
    plt.ylim(0, 1.02)
    plt.grid(True)
    linestyles = ['-', ':', '--']
    markers = ['o', 'v', '>', '^', '<', 's', 'P', 'X', '*']
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    linestyles = ["-", ":", "--"]
    markers = ["o", "v", ">", "^", "<", "s", "P", "X", "*"]
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    # plot the PR curves
    for ii, rr in enumerate(results['class_pr']):
        class_name = ' '.join([sp[:3] for sp in rr['name'].split(' ')])
        cur_color = colors[int(ii%10)]
        plt.plot(rr['recall'], rr['precision'], label=class_name, color=cur_color,
                 linestyle=linestyles[int(ii//10)], lw=2.5)
    for ii, rr in enumerate(results["class_pr"]):
        class_name = " ".join([sp[:3] for sp in rr["name"].split(" ")])
        cur_color = colors[int(ii % 10)]
        plt.plot(
            rr["recall"],
            rr["precision"],
            label=class_name,
            color=cur_color,
            linestyle=linestyles[int(ii // 10)],
            lw=2.5,
        )

        #print(class_name)
        # print(class_name)
        # plot the location of the confidence threshold values
        for jj, tt in enumerate(rr['thresholds']):
            ind = rr['thresholds_inds'][jj]
        for jj, tt in enumerate(rr["thresholds"]):
            ind = rr["thresholds_inds"][jj]
            if ind > -1:
                plt.plot(rr['recall'][ind], rr['precision'][ind], markers[jj],
                         color=cur_color, ms=10)
                #print(np.round(tt,2), np.round(rr['recall'][ind],3), np.round(rr['precision'][ind],3))
                plt.plot(
                    rr["recall"][ind],
                    rr["precision"][ind],
                    markers[jj],
                    color=cur_color,
                    ms=10,
                )
                # print(np.round(tt,2), np.round(rr['recall'][ind],3), np.round(rr['precision'][ind],3))

    if title_text != '':
        plt.title(title_text, fontdict={'fontsize': 28})
    if title_text != "":
        plt.title(title_text, fontdict={"fontsize": 28})
    else:
        plt.title(plt_title + ' {:.3f}\n'.format(results['avg_prec_class']))
    plt.legend(loc='lower left', prop={'size': 14})
        plt.title(plt_title + " {:.3f}\n".format(results["avg_prec_class"]))
    plt.legend(loc="lower left", prop={"size": 14})
    plt.tight_layout()
    plt.savefig(op_dir + file_name + '.' + file_type)
    plt.savefig(op_dir + file_name + "." + file_type)
    plt.close(0)

def plot_confusion_matrix(op_dir, op_file, gt, pred, file_acc, class_names_long, verbose=False, file_type='png', title_text=''):
def plot_confusion_matrix(
    op_dir,
    op_file,
    gt,
    pred,
    file_acc,
    class_names_long,
    verbose=False,
    file_type="png",
    title_text="",
):
    # shorten the class names for plotting
    class_names = []
    for cc in class_names_long:
        class_name_sm = ''.join([cc_sm[:3] + ' ' for cc_sm in cc.split(' ')])[:-1]
        class_name_sm = "".join([cc_sm[:3] + " " for cc_sm in cc.split(" ")])[
            :-1
        ]
        class_names.append(class_name_sm)

    num_classes = len(class_names)
    cm = confusion_matrix(gt, pred, labels=np.arange(num_classes)).astype(np.float32)
    cm = confusion_matrix(gt, pred, labels=np.arange(num_classes)).astype(
        np.float32
    )
    cm_norm = cm.sum(1)

    valid_inds = np.where(cm_norm > 0)[0]
    cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
    cm[np.where(cm_norm ==- 0)[0], :] = np.nan
    cm[np.where(cm_norm == -0)[0], :] = np.nan

    if verbose:
        print('Per class accuracy:')
        print("Per class accuracy:")
        str_len = np.max([len(cc) for cc in class_names_long]) + 5
        accs = np.diag(cm)
        for ii, cc in enumerate(class_names_long):
            if np.isnan(accs[ii]):
                print(str(ii).ljust(5) + cc.ljust(str_len))
            else:
                print(str(ii).ljust(5) + cc.ljust(str_len) + '{:.2f}'.format(accs[ii]*100))
                print(
                    str(ii).ljust(5)
                    + cc.ljust(str_len)
                    + "{:.2f}".format(accs[ii] * 100)
                )

    plt.figure(0, figsize=(10,8))
    plt.imshow(cm, vmin=0, vmax=1, cmap='plasma')
    plt.figure(0, figsize=(10, 8))
    plt.imshow(cm, vmin=0, vmax=1, cmap="plasma")
    plt.colorbar()
    plt.xticks(np.arange(cm.shape[1]), class_names, rotation='vertical')
    plt.xticks(np.arange(cm.shape[1]), class_names, rotation="vertical")
    plt.yticks(np.arange(cm.shape[0]), class_names)
    plt.xlabel('Predicted', fontsize=20)
    plt.ylabel('Ground Truth', fontsize=20)
    if title_text != '':
        plt.title(title_text, fontdict={'fontsize': 28})
    plt.xlabel("Predicted", fontsize=20)
    plt.ylabel("Ground Truth", fontsize=20)
    if title_text != "":
        plt.title(title_text, fontdict={"fontsize": 28})
    else:
        plt.title(op_file + ' {:.3f}\n'.format(file_acc))
        plt.title(op_file + " {:.3f}\n".format(file_acc))
    plt.tight_layout()
    plt.savefig(op_dir + op_file + '.' + file_type)
    plt.close('all')
    plt.savefig(op_dir + op_file + "." + file_type)
    plt.close("all")

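The normalization above turns raw counts into per-class recall along each row, with all-zero rows flagged as NaN; a tiny numeric example:

# Row-normalizing a 2x2 confusion matrix (illustrative numbers).
import numpy as np

cm = np.array([[8.0, 2.0], [0.0, 0.0]])
cm_norm = cm.sum(1)                        # [10., 0.]
valid = np.where(cm_norm > 0)[0]
cm[valid, :] = cm[valid, :] / cm_norm[valid][..., np.newaxis]
# -> [[0.8, 0.2], [0., 0.]], and the empty second row is then set to NaN
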
class LossPlotter(object):
    def __init__(self, op_file_name, duration, labels, ylim, class_names, axis_labels=None, logy=False):
    def __init__(
        self,
        op_file_name,
        duration,
        labels,
        ylim,
        class_names,
        axis_labels=None,
        logy=False,
    ):
        self.reset()
        self.op_file_name = op_file_name
        self.duration = duration  # length of x axis
@@ -327,11 +484,16 @@ class LossPlotter(object):
        self.save_confusion_matrix(gt, pred)

    def save_plot(self):
        linestyles = ['-', ':', '--']
        plt.figure(0, figsize=(8,5))
        linestyles = ["-", ":", "--"]
        plt.figure(0, figsize=(8, 5))
        for ii in range(len(self.vals[0])):
            l_vals = [vv[ii] for vv in self.vals]
            plt.plot(self.epochs, l_vals, label=self.labels[ii], linestyle=linestyles[int(ii//10)])
            plt.plot(
                self.epochs,
                l_vals,
                label=self.labels[ii],
                linestyle=linestyles[int(ii // 10)],
            )
        plt.xlim(0, np.maximum(self.duration, len(self.vals)))
        if self.ylim is not None:
            plt.ylim(self.ylim[0], self.ylim[1])
@@ -339,33 +501,41 @@ class LossPlotter(object):
            plt.xlabel(self.axis_labels[0])
            plt.ylabel(self.axis_labels[1])
        if self.logy:
            plt.gca().set_yscale('log')
            plt.gca().set_yscale("log")
        plt.grid(True)
        plt.legend(bbox_to_anchor=(1.01, 1), loc='upper left', borderaxespad=0.0)
        plt.legend(
            bbox_to_anchor=(1.01, 1), loc="upper left", borderaxespad=0.0
        )
        plt.tight_layout()
        plt.savefig(self.op_file_name)
        plt.close(0)

    def save_json(self):
        data = {}
        data['epochs'] = self.epochs
        data["epochs"] = self.epochs
        for ii in range(len(self.vals[0])):
            data[self.labels[ii]] = [round(vv[ii],4) for vv in self.vals]
        with open(self.op_file_name[:-4] + '.json', 'w') as da:
            data[self.labels[ii]] = [round(vv[ii], 4) for vv in self.vals]
        with open(self.op_file_name[:-4] + ".json", "w") as da:
            json.dump(data, da, indent=2)

    def save_confusion_matrix(self, gt, pred):
        plt.figure(0)
        cm = confusion_matrix(gt, pred, np.arange(len(self.class_names))).astype(np.float32)
        cm = confusion_matrix(
            gt, pred, labels=np.arange(len(self.class_names))
        ).astype(np.float32)
        cm_norm = cm.sum(1)
        valid_inds = np.where(cm_norm > 0)[0]
        cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
        plt.imshow(cm, vmin=0, vmax=1, cmap='plasma')
        cm[valid_inds, :] = (
            cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
        )
        plt.imshow(cm, vmin=0, vmax=1, cmap="plasma")
        plt.colorbar()
        plt.xticks(np.arange(cm.shape[1]), self.class_names, rotation='vertical')
        plt.xticks(
            np.arange(cm.shape[1]), self.class_names, rotation="vertical"
        )
        plt.yticks(np.arange(cm.shape[0]), self.class_names)
        plt.xlabel('Predicted')
        plt.ylabel('Ground Truth')
        plt.xlabel("Predicted")
        plt.ylabel("Ground Truth")
        plt.tight_layout()
        plt.savefig(self.op_file_name[:-4] + '_cm.png')
        plt.savefig(self.op_file_name[:-4] + "_cm.png")
        plt.close(0)

@@ -1,19 +1,46 @@
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import patches
from sklearn.svm import LinearSVC
from matplotlib.axes._axes import _log as matplotlib_axes_logger
matplotlib_axes_logger.setLevel('ERROR')
from sklearn.svm import LinearSVC

matplotlib_axes_logger.setLevel("ERROR")


colors = ['#e6194B', '#3cb44b', '#ffe119', '#4363d8', '#f58231', '#911eb4',
          '#42d4f4', '#f032e6', '#bfef45', '#fabebe', '#469990', '#e6beff',
          '#9A6324', '#fffac8', '#800000', '#aaffc3', '#808000', '#ffd8b1',
          '#000075', '#a9a9a9']
colors = [
    "#e6194B",
    "#3cb44b",
    "#ffe119",
    "#4363d8",
    "#f58231",
    "#911eb4",
    "#42d4f4",
    "#f032e6",
    "#bfef45",
    "#fabebe",
    "#469990",
    "#e6beff",
    "#9A6324",
    "#fffac8",
    "#800000",
    "#aaffc3",
    "#808000",
    "#ffd8b1",
    "#000075",
    "#a9a9a9",
]


class InteractivePlotter:
    def __init__(self, feats_ds, feats, spec_slices, call_info, freq_lims, allow_training):
    def __init__(
        self,
        feats_ds,
        feats,
        spec_slices,
        call_info,
        freq_lims,
        allow_training,
    ):
        """
        Plots 2D low dimensional features on left and corresponding spectrograms on
        the right.
@@ -24,78 +51,123 @@ class InteractivePlotter:

        self.spec_slices = spec_slices
        self.call_info = call_info
        #_, self.labels = np.unique([cc['class'] for cc in call_info], return_inverse=True)
        # _, self.labels = np.unique([cc['class'] for cc in call_info], return_inverse=True)
        self.labels = np.zeros(len(call_info), dtype=np.int)
        self.annotated = np.zeros(self.labels.shape[0], dtype=np.int)  # can populate this with 1's where we have labels
        self.labels_cols = [colors[self.labels[ii]] for ii in range(len(self.labels))]
        self.annotated = np.zeros(
            self.labels.shape[0], dtype=np.int
        )  # can populate this with 1's where we have labels
        self.labels_cols = [
            colors[self.labels[ii]] for ii in range(len(self.labels))
        ]
        self.freq_lims = freq_lims

        self.allow_training = allow_training
        self.pt_size = 5.0
        self.spec_pad = 0.2  # this much padding has been applied to the spec slices
        self.spec_pad = (
            0.2  # this much padding has been applied to the spec slices
        )
        self.fig_width = 12
        self.fig_height = 8

        self.current_id = 0
        max_ind = np.argmax([ss.shape[1] for ss in self.spec_slices])
        self.max_width = self.spec_slices[max_ind].shape[1]
        self.blank_spec = np.zeros((self.spec_slices[0].shape[0], self.max_width))
        self.blank_spec = np.zeros(
            (self.spec_slices[0].shape[0], self.max_width)
        )

    def plot(self, fig_id):
        self.fig, self.ax = plt.subplots(nrows=1, ncols=2, num=fig_id, figsize=(self.fig_width, self.fig_height),
                                         gridspec_kw={'width_ratios': [2, 1]})
        self.fig, self.ax = plt.subplots(
            nrows=1,
            ncols=2,
            num=fig_id,
            figsize=(self.fig_width, self.fig_height),
            gridspec_kw={"width_ratios": [2, 1]},
        )
        plt.tight_layout()

        # plot 2D TSNE features
        self.low_dim_plt = self.ax[0].scatter(self.feats_ds[:, 0], self.feats_ds[:, 1],
                                              c=self.labels_cols, s=self.pt_size, picker=5)
        self.ax[0].set_title('TSNE of Call Features')
        self.low_dim_plt = self.ax[0].scatter(
            self.feats_ds[:, 0],
            self.feats_ds[:, 1],
            c=self.labels_cols,
            s=self.pt_size,
            picker=5,
        )
        self.ax[0].set_title("TSNE of Call Features")
        self.ax[0].set_xticks([])
        self.ax[0].set_yticks([])

        # plot clip from spectrogram
        spec_min_max = (0, self.blank_spec.shape[1], self.freq_lims[0], self.freq_lims[1])
        self.ax[1].imshow(self.blank_spec, extent=spec_min_max, cmap='plasma', aspect='auto')
        spec_min_max = (
            0,
            self.blank_spec.shape[1],
            self.freq_lims[0],
            self.freq_lims[1],
        )
        self.ax[1].imshow(
            self.blank_spec, extent=spec_min_max, cmap="plasma", aspect="auto"
        )
        self.spec_im = self.ax[1].get_images()[0]
        self.ax[1].set_title('Spectrogram')
        self.ax[1].grid(color='w', linewidth=0.5)
        self.ax[1].set_title("Spectrogram")
        self.ax[1].grid(color="w", linewidth=0.5)
        self.ax[1].set_xticks([])
        self.ax[1].set_ylabel('kHz')
        self.ax[1].set_ylabel("kHz")

        bbox_orig = patches.Rectangle((0,0),0,0, edgecolor='w', linewidth=0, fill=False)
        bbox_orig = patches.Rectangle(
            (0, 0), 0, 0, edgecolor="w", linewidth=0, fill=False
        )
        self.ax[1].add_patch(bbox_orig)

        self.annot = self.ax[0].annotate('', xy=(0,0), xytext=(20,20), textcoords='offset points',
                                         bbox=dict(boxstyle='round', fc='w'), arrowprops=dict(arrowstyle='->'))
        self.annot = self.ax[0].annotate(
            "",
            xy=(0, 0),
            xytext=(20, 20),
            textcoords="offset points",
            bbox=dict(boxstyle="round", fc="w"),
            arrowprops=dict(arrowstyle="->"),
        )
        self.annot.set_visible(False)

        self.fig.canvas.mpl_connect('motion_notify_event', self.mouse_hover)
        self.fig.canvas.mpl_connect('key_press_event', self.key_press)

        self.fig.canvas.mpl_connect("motion_notify_event", self.mouse_hover)
        self.fig.canvas.mpl_connect("key_press_event", self.key_press)

    def mouse_hover(self, event):
        vis = self.annot.get_visible()
        if event.inaxes == self.ax[0]:
            cont, ind = self.low_dim_plt.contains(event)
            if cont:
                self.current_id = ind['ind'][0]
                self.current_id = ind["ind"][0]

                # copy spec into full window - probably a better way of doing this
                new_spec = self.blank_spec.copy()
                w_diff = (self.blank_spec.shape[1] - self.spec_slices[self.current_id].shape[1])//2
                new_spec[:, w_diff:self.spec_slices[self.current_id].shape[1]+w_diff] = self.spec_slices[self.current_id]
                w_diff = (
                    self.blank_spec.shape[1]
                    - self.spec_slices[self.current_id].shape[1]
                ) // 2
                new_spec[
                    :,
                    w_diff : self.spec_slices[self.current_id].shape[1]
                    + w_diff,
                ] = self.spec_slices[self.current_id]
                self.spec_im.set_data(new_spec)
                self.spec_im.set_clim(vmin=0, vmax=new_spec.max())

                # draw bounding box around call
                self.ax[1].patches[0].remove()
                spec_width_orig = self.spec_slices[self.current_id].shape[1]/(1.0+2.0*self.spec_pad)
                xx = w_diff + self.spec_pad*spec_width_orig
                spec_width_orig = self.spec_slices[self.current_id].shape[1] / (
                    1.0 + 2.0 * self.spec_pad
                )
                xx = w_diff + self.spec_pad * spec_width_orig
                ww = spec_width_orig
                yy = self.call_info[self.current_id]['low_freq']/1000
                hh = (self.call_info[self.current_id]['high_freq']-self.call_info[self.current_id]['low_freq'])/1000
                bbox = patches.Rectangle((xx,yy),ww,hh, edgecolor='r', linewidth=0.5, fill=False)
                yy = self.call_info[self.current_id]["low_freq"] / 1000
                hh = (
                    self.call_info[self.current_id]["high_freq"]
                    - self.call_info[self.current_id]["low_freq"]
                ) / 1000
                bbox = patches.Rectangle(
                    (xx, yy), ww, hh, edgecolor="r", linewidth=0.5, fill=False
                )
                self.ax[1].add_patch(bbox)

                # update annotation arrow
@@ -104,38 +176,52 @@ class InteractivePlotter:
                self.annot.set_visible(True)

                # write call info
                info_str = self.call_info[self.current_id]['file_name'] + ', time=' \
                    + str(round(self.call_info[self.current_id]['start_time'],3)) \
                    + ', prob=' + str(round(self.call_info[self.current_id]['det_prob'],3))
                info_str = (
                    self.call_info[self.current_id]["file_name"]
                    + ", time="
                    + str(
                        round(self.call_info[self.current_id]["start_time"], 3)
                    )
                    + ", prob="
                    + str(round(self.call_info[self.current_id]["det_prob"], 3))
                )
                self.ax[0].set_xlabel(info_str)

                # redraw
                self.fig.canvas.draw_idle()

    def key_press(self, event):
        if event.key.isdigit():
            self.labels_cols[self.current_id] = colors[int(event.key)]
            self.labels[self.current_id] = int(event.key)
            self.annotated[self.current_id] = 1
        elif event.key == 'enter' and self.allow_training:
        elif event.key == "enter" and self.allow_training:
            self.train_classifier()
        elif event.key == 'x' and self.allow_training:
        elif event.key == "x" and self.allow_training:
            self.get_classifier_params()

        self.ax[0].scatter(self.feats_ds[:, 0], self.feats_ds[:, 1],
                           c=self.labels_cols, s=self.pt_size)
        self.ax[0].scatter(
            self.feats_ds[:, 0],
            self.feats_ds[:, 1],
            c=self.labels_cols,
            s=self.pt_size,
        )
        self.fig.canvas.draw_idle()

    def train_classifier(self):
        # TODO maybe it's better to classify in 2D space - but then can't be linear ...
        inds = np.where(self.annotated == 1)[0]
        labs_un, labs_inds = np.unique(self.labels[inds], return_inverse=True)

        if labs_un.shape[0] > 1:  # needs at least 2 classes
            self.clf = LinearSVC(C=1.0, penalty='l2', loss='squared_hinge', tol=0.0001,
                                 intercept_scaling=1.0, max_iter=2000)
            self.clf = LinearSVC(
                C=1.0,
                penalty="l2",
                loss="squared_hinge",
                tol=0.0001,
                intercept_scaling=1.0,
                max_iter=2000,
            )

            self.clf.fit(self.feats[inds, :], self.labels[inds])

@@ -145,14 +231,13 @@ class InteractivePlotter:
            for ii in inds_unlab:
                self.labels_cols[ii] = colors[self.labels[ii]]
        else:
            print('Not enough data - please label more classes.')
            print("Not enough data - please label more classes.")

    def get_classifier_params(self):
        res = {}
        if self.clf is None:
            print('Model not trained!')
            print("Model not trained!")
        else:
            res['weights'] = self.clf.coef_.astype(np.float32)
            res['biases'] = self.clf.intercept_.astype(np.float32)
            res["weights"] = self.clf.coef_.astype(np.float32)
            res["biases"] = self.clf.intercept_.astype(np.float32)
        return res

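The weights and biases exported here are just the parameters of a linear decision function, so they can be applied outside scikit-learn with a single matrix product; a sketch under that assumption (the `plotter` and `feats` names are placeholders):

# Applying the exported linear classifier manually (sketch).
import numpy as np

params = plotter.get_classifier_params()  # assumes a trained InteractivePlotter
scores = feats @ params["weights"].T + params["biases"]
pred_labels = scores.argmax(axis=1)       # multi-class case
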
@@ -8,23 +8,25 @@ Functions
`write`: Write a numpy array as a WAV file.

"""
from __future__ import division, print_function, absolute_import
from __future__ import absolute_import, division, print_function

import sys
import numpy
import struct
import warnings
import os
import struct
import sys
import warnings

import numpy


class WavFileWarning(UserWarning):
    pass


_big_endian = False

WAVE_FORMAT_PCM = 0x0001
WAVE_FORMAT_IEEE_FLOAT = 0x0003
WAVE_FORMAT_EXTENSIBLE = 0xfffe
WAVE_FORMAT_EXTENSIBLE = 0xFFFE
KNOWN_WAVE_FORMATS = (WAVE_FORMAT_PCM, WAVE_FORMAT_IEEE_FLOAT)

# assumes file pointer is immediately
@@ -33,10 +35,10 @@ KNOWN_WAVE_FORMATS = (WAVE_FORMAT_PCM, WAVE_FORMAT_IEEE_FLOAT)

def _read_fmt_chunk(fid):
    if _big_endian:
        fmt = '>'
        fmt = ">"
    else:
        fmt = '<'
    res = struct.unpack(fmt+'iHHIIHH',fid.read(20))
        fmt = "<"
    res = struct.unpack(fmt + "iHHIIHH", fid.read(20))
    size, comp, noc, rate, sbytes, ba, bits = res
    if comp not in KNOWN_WAVE_FORMATS or size > 16:
        comp = WAVE_FORMAT_PCM
@@ -51,41 +53,42 @@ def _read_fmt_chunk(fid):
# after the 'data' id
def _read_data_chunk(fid, comp, noc, bits, mmap=False):
    if _big_endian:
        fmt = '>i'
        fmt = ">i"
    else:
        fmt = '<i'
    size = struct.unpack(fmt,fid.read(4))[0]
        fmt = "<i"
    size = struct.unpack(fmt, fid.read(4))[0]

    bytes = bits//8
    bytes = bits // 8
    if bits == 8:
        dtype = 'u1'
        dtype = "u1"
    else:
        if _big_endian:
            dtype = '>'
            dtype = ">"
        else:
            dtype = '<'
            dtype = "<"
        if comp == 1:
            dtype += 'i%d' % bytes
            dtype += "i%d" % bytes
        else:
            dtype += 'f%d' % bytes
            dtype += "f%d" % bytes
    if not mmap:
        data = numpy.fromstring(fid.read(size), dtype=dtype)
    else:
        start = fid.tell()
        data = numpy.memmap(fid, dtype=dtype, mode='c', offset=start,
                            shape=(size//bytes,))
        data = numpy.memmap(
            fid, dtype=dtype, mode="c", offset=start, shape=(size // bytes,)
        )
        fid.seek(start + size)

    if noc > 1:
        data = data.reshape(-1,noc)
        data = data.reshape(-1, noc)
    return data


def _skip_unknown_chunk(fid):
    if _big_endian:
        fmt = '>i'
        fmt = ">i"
    else:
        fmt = '<i'
        fmt = "<i"

    data = fid.read(4)
    size = struct.unpack(fmt, data)[0]
@@ -95,22 +98,23 @@ def _skip_unknown_chunk(fid):
def _read_riff_chunk(fid):
    global _big_endian
    str1 = fid.read(4)
    if str1 == b'RIFX':
    if str1 == b"RIFX":
        _big_endian = True
    elif str1 != b'RIFF':
    elif str1 != b"RIFF":
        raise ValueError("Not a WAV file.")
    if _big_endian:
        fmt = '>I'
        fmt = ">I"
    else:
        fmt = '<I'
        fmt = "<I"
    fsize = struct.unpack(fmt, fid.read(4))[0] + 8
    str2 = fid.read(4)
    if (str2 != b'WAVE'):
    if str2 != b"WAVE":
        raise ValueError("Not a WAV file.")
    if str1 == b'RIFX':
    if str1 == b"RIFX":
        _big_endian = True
    return fsize


# open a wave-file

@@ -145,11 +149,11 @@ def read(filename, mmap=False):
    data-type determined from the file.

    """
    if hasattr(filename,'read'):
    if hasattr(filename, "read"):
        fid = filename
        mmap = False
    else:
        fid = open(filename, 'rb')
        fid = open(filename, "rb")

    try:

@@ -169,16 +173,16 @@ def read(filename, mmap=False):
        noc = 1
        bits = 8
        comp = WAVE_FORMAT_PCM
        while (fid.tell() < fsize):
        while fid.tell() < fsize:
            # read the next chunk
            chunk_id = fid.read(4)
            if chunk_id == b'fmt ':
            if chunk_id == b"fmt ":
                size, comp, noc, rate, sbytes, ba, bits = _read_fmt_chunk(fid)
            elif chunk_id == b'fact':
            elif chunk_id == b"fact":
                _skip_unknown_chunk(fid)
            elif chunk_id == b'data':
            elif chunk_id == b"data":
                data = _read_data_chunk(fid, comp, noc, bits, mmap=mmap)
            elif chunk_id == b'LIST':
            elif chunk_id == b"LIST":
                # Someday this could be handled properly but for now skip it
                _skip_unknown_chunk(fid)

@@ -187,13 +191,14 @@ def read(filename, mmap=False):
                # warnings.warn("Chunk (non-data) not understood, skipping it.", WavFileWarning)
                # _skip_unknown_chunk(fid)
    finally:
        if not hasattr(filename,'read'):
        if not hasattr(filename, "read"):
            fid.close()
        else:
            fid.seek(0)

    return rate, data


# Write a wave-file
# sample rate, data

@@ -221,26 +226,30 @@ def write(filename, rate, data):
    (Nsamples, Nchannels).

    """
    if hasattr(filename, 'write'):
    if hasattr(filename, "write"):
        fid = filename
    else:
        fid = open(filename, 'wb')
        fid = open(filename, "wb")

    try:
        # kind of numeric data in the numpy array
        dkind = data.dtype.kind
        if not (dkind == 'i' or dkind == 'f' or (dkind == 'u' and data.dtype.itemsize == 1)):
        if not (
            dkind == "i"
            or dkind == "f"
            or (dkind == "u" and data.dtype.itemsize == 1)
        ):
            raise ValueError("Unsupported data type '%s'" % data.dtype)

        # wav header stuff
        # http://soundfile.sapp.org/doc/WaveFormat/
        fid.write(b'RIFF')
        fid.write(b"RIFF")
        # placeholder for chunk size (updated later)
        fid.write(b'\x00\x00\x00\x00')
        fid.write(b'WAVE')
        fid.write(b"\x00\x00\x00\x00")
        fid.write(b"WAVE")
        # fmt chunk
        fid.write(b'fmt ')
        if dkind == 'f':
        fid.write(b"fmt ")
        if dkind == "f":
            # comp stands for compression. PCM = 1
            comp = 3
        else:
@@ -253,7 +262,7 @@ def write(filename, rate, data):
        bits = data.dtype.itemsize * 8
        # number of bytes per second, at the specified sampling rate,
        # bits per sample and number of channels (just needed for wav header)
        sbytes = rate*(bits // 8)*noc
        sbytes = rate * (bits // 8) * noc
        # number of bytes per sample
        ba = noc * (bits // 8)

@@ -261,11 +270,15 @@ def write(filename, rate, data):
        # Write the data (16, comp, noc, etc) in the correct binary format
        # for the wav header. the string format (first arg) specifies how many bytes for each
        # value.
        fid.write(struct.pack('<ihHIIHH', 16, comp, noc, rate, sbytes, ba, bits))
        fid.write(
            struct.pack("<ihHIIHH", 16, comp, noc, rate, sbytes, ba, bits)
        )
        # data chunk: the word 'data' followed by the size followed by the actual data
        fid.write(b'data')
        fid.write(struct.pack('<i', data.nbytes))
        if data.dtype.byteorder == '>' or (data.dtype.byteorder == '=' and sys.byteorder == 'big'):
        fid.write(b"data")
        fid.write(struct.pack("<i", data.nbytes))
        if data.dtype.byteorder == ">" or (
            data.dtype.byteorder == "=" and sys.byteorder == "big"
        ):
            data = data.byteswap()
        _array_tofile(fid, data)

@@ -273,19 +286,22 @@ def write(filename, rate, data):
        # position at start of the file (replacing the 4 bytes of zeros)
        size = fid.tell()
        fid.seek(4)
        fid.write(struct.pack('<i', size-8))
        fid.write(struct.pack("<i", size - 8))

    finally:
        if not hasattr(filename,'write'):
        if not hasattr(filename, "write"):
            fid.close()
        else:
            fid.seek(0)


if sys.version_info[0] >= 3:

    def _array_tofile(fid, data):
        # ravel gives a c-contiguous buffer
        fid.write(data.ravel().view('b').data)
        fid.write(data.ravel().view("b").data)

else:

    def _array_tofile(fid, data):
        fid.write(data.tostring())

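This vendored module appears to be derived from `scipy.io.wavfile` and keeps its two-function interface, so round-tripping a file looks the same as with SciPy; a small sketch (the file names are placeholders):

# Round-trip sketch using this vendored wavfile module (paths are placeholders).
rate, data = read("input.wav")
print(rate, data.dtype, data.shape)
write("copy.wav", rate, data)
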
@@ -56,9 +56,9 @@
   "source": [
    "# setup the arguments\n",
    "args = du.get_default_bd_args()\n",
    "args['detection_threshold'] = 0.3\n",
    "args['time_expansion_factor'] = 1\n",
    "args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'\n",
    "args[\"detection_threshold\"] = 0.3\n",
    "args[\"time_expansion_factor\"] = 1\n",
    "args[\"model_path\"] = \"models/Net2DFast_UK_same.pth.tar\"\n",
    "max_duration = 2.0"
   ]
  },
@@ -69,7 +69,7 @@
   "outputs": [],
   "source": [
    "# load the model\n",
    "model, params = du.load_model(args['model_path'])"
    "model, params = du.load_model(args[\"model_path\"])"
   ]
  },
  {
@@ -86,13 +86,13 @@
   "outputs": [],
   "source": [
    "# choose an audio file\n",
    "audio_file = 'example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav'\n",
    "audio_file = \"example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav\"\n",
    "\n",
    "# the following lines are only needed in Colab\n",
    "# alternatively you can upload your own file\n",
    "#from google.colab import files\n",
    "#uploaded = files.upload()\n",
    "#audio_file = list(uploaded.keys())[0]"
    "# from google.colab import files\n",
    "# uploaded = files.upload()\n",
    "# audio_file = list(uploaded.keys())[0]"
   ]
  },
  {
@@ -102,7 +102,9 @@
   "outputs": [],
   "source": [
    "# run the model\n",
    "results = du.process_file(audio_file, model, params, args, max_duration=max_duration)"
    "results = du.process_file(\n",
    "    audio_file, model, params, args, max_duration=max_duration\n",
    ")"
   ]
  },
  {
@@ -144,13 +146,17 @@
   }
  ],
   "source": [
    "# print summary info for the individual detections\n",
    "print('Results for ' + results['pred_dict']['id'])\n",
    "print('{} calls detected\\n'.format(len(results['pred_dict']['annotation'])))\n",
    "print(\"Results for \" + results[\"pred_dict\"][\"id\"])\n",
    "print(\"{} calls detected\\n\".format(len(results[\"pred_dict\"][\"annotation\"])))\n",
    "\n",
    "print('time\\tprob\\tlfreq\\tspecies_name')\n",
    "for ann in results['pred_dict']['annotation']:\n",
    "    print('{}\\t{}\\t{}\\t{}'.format(ann['start_time'], ann['class_prob'], ann['low_freq'], ann['class']))"
    "print(\"time\\tprob\\tlfreq\\tspecies_name\")\n",
    "for ann in results[\"pred_dict\"][\"annotation\"]:\n",
    "    print(\n",
    "        \"{}\\t{}\\t{}\\t{}\".format(\n",
    "            ann[\"start_time\"], ann[\"class_prob\"], ann[\"low_freq\"], ann[\"class\"]\n",
    "        )\n",
    "    )"
   ]
  },
  {
@@ -174,10 +180,16 @@
   }
  ],
   "source": [
    "# read the audio file\n",
    "sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'], params['target_samp_rate'], params['scale_raw_audio'], max_duration=max_duration)\n",
    "sampling_rate, audio = au.load_audio_file(\n",
    "    audio_file,\n",
    "    args[\"time_expansion_factor\"],\n",
    "    params[\"target_samp_rate\"],\n",
    "    params[\"scale_raw_audio\"],\n",
    "    max_duration=max_duration,\n",
    ")\n",
    "duration = audio.shape[0] / sampling_rate\n",
    "print('File duration: {} seconds'.format(duration))"
    "print(\"File duration: {} seconds\".format(duration))"
   ]
  },
  {
@@ -187,7 +199,9 @@
   "outputs": [],
   "source": [
    "# generate spectrogram for visualization\n",
    "spec, spec_viz = au.generate_spectrogram(audio, sampling_rate, params, True, False)"
    "spec, spec_viz = au.generate_spectrogram(\n",
    "    audio, sampling_rate, params, True, False\n",
    ")"
   ]
  },
  {
@@ -210,12 +224,33 @@
    "# display the detections on top of the spectrogram\n",
    "# note, if the audio file is very long, this image will be very large - best to crop the audio first\n",
    "start_time = 0.0\n",
    "detections = [ann for ann in results['pred_dict']['annotation']]\n",
    "fig = plt.figure(1, figsize=(spec.shape[1]/100, spec.shape[0]/100), dpi=100, frameon=False)\n",
    "spec_duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])\n",
    "viz.create_box_image(spec, fig, detections, start_time, start_time+spec_duration, spec_duration, params, spec.max()*1.1, False, True)\n",
    "plt.ylabel('Freq - kHz')\n",
    "plt.xlabel('Time - secs')\n",
    "detections = [ann for ann in results[\"pred_dict\"][\"annotation\"]]\n",
    "fig = plt.figure(\n",
    "    1,\n",
    "    figsize=(spec.shape[1] / 100, spec.shape[0] / 100),\n",
    "    dpi=100,\n",
    "    frameon=False,\n",
    ")\n",
    "spec_duration = au.x_coords_to_time(\n",
    "    spec.shape[1],\n",
    "    sampling_rate,\n",
    "    params[\"fft_win_length\"],\n",
    "    params[\"fft_overlap\"],\n",
    ")\n",
    "viz.create_box_image(\n",
    "    spec,\n",
    "    fig,\n",
    "    detections,\n",
    "    start_time,\n",
    "    start_time + spec_duration,\n",
    "    spec_duration,\n",
    "    params,\n",
    "    spec.max() * 1.1,\n",
    "    False,\n",
    "    True,\n",
    ")\n",
    "plt.ylabel(\"Freq - kHz\")\n",
    "plt.xlabel(\"Time - secs\")\n",
    "plt.title(os.path.basename(audio_file))\n",
    "plt.show()"
   ]
79
pyproject.toml
Normal file
@@ -0,0 +1,79 @@
[tool.pdm]
[tool.pdm.dev-dependencies]
dev = [
    "pytest>=7.2.2",
]

[project]
name = "batdetect2"
version = "0.2.0"
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
authors = [
    { "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
    { "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" }
]
dependencies = [
    "librosa",
    "matplotlib",
    "numpy",
    "pandas",
    "scikit-learn",
    "scipy",
    "torch<2",
    "torchaudio",
    "torchvision",
    "click",
]
requires-python = ">=3.8,<3.11"
readme = "README.md"
license = { text = "CC-by-nc-4" }
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Science/Research",
    "Natural Language :: English",
    "Operating System :: OS Independent",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Software Development :: Libraries :: Python Modules",
    "Topic :: Multimedia :: Sound/Audio :: Analysis",
]
keywords = [
    "bat",
    "echolocation",
    "deep learning",
    "audio",
    "machine learning",
    "classification",
    "detection",
]

[build-system]
requires = ["pdm-pep517>=1.0.0"]
build-backend = "pdm.pep517.api"

[project.scripts]
batdetect2 = "bat_detect.cli:cli"

[tool.black]
line-length = 80

[[tool.mypy.overrides]]
module = [
    "librosa",
    "pandas",
]
ignore_missing_imports = true

[tool.pylsp-mypy]
enabled = false
live_mode = true
strict = true

[tool.pyright]
include = [
    "bat_detect",
    "tests",
]
venvPath = "."
venv = ".venv"
@@ -7,3 +7,4 @@ scipy==1.9.3
torch==1.13.0
torchaudio==0.13.0
torchvision==0.14.0
click
@ -1,67 +1,5 @@
import os
import argparse
import bat_detect.utils.detector_utils as du


def main(args):

    print('Loading model: ' + args['model_path'])
    model, params = du.load_model(args['model_path'])

    print('\nInput directory: ' + args['audio_dir'])
    files = du.get_audio_files(args['audio_dir'])
    print('Number of audio files: {}'.format(len(files)))
    print('\nSaving results to: ' + args['ann_dir'])

    # process files
    error_files = []
    for ii, audio_file in enumerate(files):
        print('\n' + str(ii).ljust(6) + os.path.basename(audio_file))
        try:
            results = du.process_file(audio_file, model, params, args)
            if args['save_preds_if_empty'] or (len(results['pred_dict']['annotation']) > 0):
                results_path = audio_file.replace(args['audio_dir'], args['ann_dir'])
                du.save_results_to_file(results, results_path)
        except:
            error_files.append(audio_file)
            print("Error processing file!")

    print('\nResults saved to: ' + args['ann_dir'])

    if len(error_files) > 0:
        print('\nUnable to process the following files:')
        for err in error_files:
            print(' ' + err)

"""Run bat_detect.command.main() from the command line."""
from bat_detect.cli import detect


if __name__ == "__main__":

    info_str = '\nBatDetect2 - Detection and Classification\n' + \
               ' Assumes audio files are mono, not stereo.\n' + \
               ' Spaces in the input paths will throw an error. Wrap in quotes "".\n' + \
               ' Input files should be short in duration e.g. < 30 seconds.\n'

    print(info_str)
    parser = argparse.ArgumentParser()
    parser.add_argument('audio_dir', type=str, help='Input directory for audio')
    parser.add_argument('ann_dir', type=str, help='Output directory for where the predictions will be stored')
    parser.add_argument('detection_threshold', type=float, help='Cut-off probability for detector e.g. 0.1')
    parser.add_argument('--cnn_features', action='store_true', default=False, dest='cnn_features',
                        help='Extracts CNN call features')
    parser.add_argument('--spec_features', action='store_true', default=False, dest='spec_features',
                        help='Extracts low level call features')
    parser.add_argument('--time_expansion_factor', type=int, default=1, dest='time_expansion_factor',
                        help='The time expansion factor used for all files (default is 1)')
    parser.add_argument('--quiet', action='store_true', default=False, dest='quiet',
                        help='Minimize output printing')
    parser.add_argument('--save_preds_if_empty', action='store_true', default=False, dest='save_preds_if_empty',
                        help='Save empty annotation file if no detections made.')
    parser.add_argument('--model_path', type=str, default='models/Net2DFast_UK_same.pth.tar',
                        help='Path to trained BatDetect2 model')
    args = vars(parser.parse_args())

    args['spec_slices'] = False  # used for visualization
    args['chunk_size'] = 2  # if files greater than this amount (seconds) they will be broken down into small chunks
    args['ann_dir'] = os.path.join(args['ann_dir'], '')

    main(args)
    detect()
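In practice the replacement changes the invocation rather than the behaviour; with illustrative paths, the old and new entry points are called as:

# before: standalone argparse script
python run_batdetect.py example_data/audio results/ 0.3

# after: console script installed via [project.scripts]
batdetect2 detect example_data/audio results/ 0.3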
@ -3,62 +3,95 @@ Loads a set of annotations corresponding to a dataset and saves an image which
is the mean spectrogram for each class.
"""

import argparse
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import os
import argparse
import sys
import viz_helpers as vz

sys.path.append(os.path.join('..'))
import bat_detect.train.train_utils as tu
sys.path.append(os.path.join(".."))
import bat_detect.detector.parameters as parameters
import bat_detect.utils.audio_utils as au
import bat_detect.train.train_split as ts

import bat_detect.train.train_utils as tu
import bat_detect.utils.audio_utils as au


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('audio_path', type=str, help='Input directory for audio')
    parser.add_argument('op_dir', type=str,
                        help='Path to where single annotation json file is stored')
    parser.add_argument('--ann_file', type=str,
                        help='Path to where single annotation json file is stored')
    parser.add_argument('--uk_split', type=str, default='',
                        help='Set as: diff or same')
    parser.add_argument('--file_type', type=str, default='png',
                        help='Type of image to save png or pdf')
    parser.add_argument(
        "audio_path", type=str, help="Input directory for audio"
    )
    parser.add_argument(
        "op_dir",
        type=str,
        help="Path to where single annotation json file is stored",
    )
    parser.add_argument(
        "--ann_file",
        type=str,
        help="Path to where single annotation json file is stored",
    )
    parser.add_argument(
        "--uk_split", type=str, default="", help="Set as: diff or same"
    )
    parser.add_argument(
        "--file_type",
        type=str,
        default="png",
        help="Type of image to save png or pdf",
    )
    args = vars(parser.parse_args())

    if not os.path.isdir(args['op_dir']):
        os.makedirs(args['op_dir'])
    if not os.path.isdir(args["op_dir"]):
        os.makedirs(args["op_dir"])

    params = parameters.get_params(False)
    params['smooth_spec'] = False
    params['spec_width'] = 48
    params['norm_type'] = 'log'  # log, pcen
    params['aud_pad'] = 0.005
    classes_to_ignore = params['classes_to_ignore'] + params['generic_class']

    params["smooth_spec"] = False
    params["spec_width"] = 48
    params["norm_type"] = "log"  # log, pcen
    params["aud_pad"] = 0.005
    classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]

    # load train annotations
    if args['uk_split'] == '':
        print('\nLoading:', args['ann_file'], '\n')
        dataset_name = os.path.basename(args['ann_file']).replace('.json', '')
    if args["uk_split"] == "":
        print("\nLoading:", args["ann_file"], "\n")
        dataset_name = os.path.basename(args["ann_file"]).replace(".json", "")
        datasets = []
        datasets.append(tu.get_blank_dataset_dict(dataset_name, False, args['ann_file'], args['audio_path']))
        datasets.append(
            tu.get_blank_dataset_dict(
                dataset_name, False, args["ann_file"], args["audio_path"]
            )
        )
    else:
        # load uk data - special case
        print('\nLoading:', args['uk_split'], '\n')
        dataset_name = 'uk_' + args['uk_split']  # should be uk_diff, or uk_same
        datasets, _ = ts.get_train_test_data(args['ann_file'], args['audio_path'], args['uk_split'], load_extra=False)
        print("\nLoading:", args["uk_split"], "\n")
        dataset_name = "uk_" + args["uk_split"]  # should be uk_diff, or uk_same
        datasets, _ = ts.get_train_test_data(
            args["ann_file"],
            args["audio_path"],
            args["uk_split"],
            load_extra=False,
        )

    anns, class_names, _ = tu.load_set_of_anns(datasets, classes_to_ignore, params['events_of_interest'])
    anns, class_names, _ = tu.load_set_of_anns(
        datasets, classes_to_ignore, params["events_of_interest"]
    )
    class_names_order = range(len(class_names))

    x_train, y_train = vz.load_data(anns, params, class_names, smooth_spec=params['smooth_spec'], norm_type=params['norm_type'])
    x_train, y_train = vz.load_data(
        anns,
        params,
        class_names,
        smooth_spec=params["smooth_spec"],
        norm_type=params["norm_type"],
    )

    op_file_name = os.path.join(args['op_dir'], dataset_name + '.' + args['file_type'])
    vz.save_summary_image(x_train, y_train, class_names, params, op_file_name, class_names_order)
    print('\nImage saved to:', op_file_name)
    op_file_name = os.path.join(
        args["op_dir"], dataset_name + "." + args["file_type"]
    )
    vz.save_summary_image(
        x_train, y_train, class_names, params, op_file_name, class_names_order
    )
    print("\nImage saved to:", op_file_name)
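The summary image this script produces is the per-class average of the stacked call spectrograms returned by vz.load_data. A toy numpy sketch of the core reduction (shapes are illustrative; save_summary_image in viz_helpers computes exactly this per panel):

import numpy as np

# stand-ins for vz.load_data output: (num_calls, freq_bins, spec_width)
x_train = np.random.rand(10, 128, 48).astype(np.float32)
y_train = np.array([0, 0, 1, 1, 1, 0, 1, 0, 1, 0])

k = 0  # class index
mean_spec_k = x_train[y_train == k].mean(0)  # (128, 48) mean spectrogram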
@ -7,24 +7,27 @@ Will save images with:
3) spectrogram with predicted boxes
"""

import numpy as np
import sys
import os
import argparse
import matplotlib.pyplot as plt
import json
import os
import sys

sys.path.append(os.path.join('..'))
import matplotlib.pyplot as plt
import numpy as np

sys.path.append(os.path.join(".."))
import bat_detect.evaluate.evaluate_models as evlm
import bat_detect.utils.audio_utils as au
import bat_detect.utils.detector_utils as du
import bat_detect.utils.plot_utils as viz
import bat_detect.utils.audio_utils as au


def filter_anns(anns, start_time, stop_time):
    anns_op = []
    for aa in anns:
        if (aa['start_time'] >= start_time) and (aa['start_time'] < stop_time-0.02):
        if (aa["start_time"] >= start_time) and (
            aa["start_time"] < stop_time - 0.02
        ):
            anns_op.append(aa)
    return anns_op

@ -32,85 +35,175 @@ def filter_anns(anns, start_time, stop_time):
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('audio_file', type=str, help='Path to audio file')
    parser.add_argument('model_path', type=str, help='Path to BatDetect model')
    parser.add_argument('--ann_file', type=str, default='', help='Path to annotation file')
    parser.add_argument('--op_dir', type=str, default='plots/',
                        help='Output directory for plots')
    parser.add_argument('--file_type', type=str, default='png',
                        help='Type of image to save png or pdf')
    parser.add_argument('--title_text', type=str, default='',
                        help='Text to add as title of plots')
    parser.add_argument('--detection_threshold', type=float, default=0.2,
                        help='Threshold for output detections')
    parser.add_argument('--start_time', type=float, default=0.0,
                        help='Start time for cropped file')
    parser.add_argument('--stop_time', type=float, default=0.5,
                        help='End time for cropped file')
    parser.add_argument('--time_expansion_factor', type=int, default=1,
                        help='Time expansion factor')

    parser.add_argument("audio_file", type=str, help="Path to audio file")
    parser.add_argument("model_path", type=str, help="Path to BatDetect model")
    parser.add_argument(
        "--ann_file", type=str, default="", help="Path to annotation file"
    )
    parser.add_argument(
        "--op_dir",
        type=str,
        default="plots/",
        help="Output directory for plots",
    )
    parser.add_argument(
        "--file_type",
        type=str,
        default="png",
        help="Type of image to save png or pdf",
    )
    parser.add_argument(
        "--title_text",
        type=str,
        default="",
        help="Text to add as title of plots",
    )
    parser.add_argument(
        "--detection_threshold",
        type=float,
        default=0.2,
        help="Threshold for output detections",
    )
    parser.add_argument(
        "--start_time",
        type=float,
        default=0.0,
        help="Start time for cropped file",
    )
    parser.add_argument(
        "--stop_time",
        type=float,
        default=0.5,
        help="End time for cropped file",
    )
    parser.add_argument(
        "--time_expansion_factor",
        type=int,
        default=1,
        help="Time expansion factor",
    )

    args_cmd = vars(parser.parse_args())

    # load the model
    bd_args = du.get_default_bd_args()
    model, params_bd = du.load_model(args_cmd['model_path'])
    bd_args['detection_threshold'] = args_cmd['detection_threshold']
    bd_args['time_expansion_factor'] = args_cmd['time_expansion_factor']


    # load the model
    bd_args = du.get_default_run_config()
    model, params_bd = du.load_model(args_cmd["model_path"])
    bd_args["detection_threshold"] = args_cmd["detection_threshold"]
    bd_args["time_expansion_factor"] = args_cmd["time_expansion_factor"]

    # load the annotation if it exists
    gt_present = False
    if args_cmd['ann_file'] != '':
        if os.path.isfile(args_cmd['ann_file']):
            with open(args_cmd['ann_file']) as da:
    if args_cmd["ann_file"] != "":
        if os.path.isfile(args_cmd["ann_file"]):
            with open(args_cmd["ann_file"]) as da:
                gt_anns = json.load(da)
            gt_anns = filter_anns(gt_anns['annotation'], args_cmd['start_time'], args_cmd['stop_time'])
            gt_anns = filter_anns(
                gt_anns["annotation"],
                args_cmd["start_time"],
                args_cmd["stop_time"],
            )
            gt_present = True
        else:
            print('Annotation file not found: ', args_cmd['ann_file'])
            print("Annotation file not found: ", args_cmd["ann_file"])

    # load the audio file
    if not os.path.isfile(args_cmd['audio_file']):
        print('Audio file not found: ', args_cmd['audio_file'])
    if not os.path.isfile(args_cmd["audio_file"]):
        print("Audio file not found: ", args_cmd["audio_file"])
        sys.exit()

    # load audio and crop
    print('\nProcessing: ' + os.path.basename(args_cmd['audio_file']))
    print('\nOutput directory: ' + args_cmd['op_dir'])
    sampling_rate, audio = au.load_audio_file(args_cmd['audio_file'], args_cmd['time_exp'],
                                              params_bd['target_samp_rate'], params_bd['scale_raw_audio'])
    st_samp = int(sampling_rate*args_cmd['start_time'])
    en_samp = int(sampling_rate*args_cmd['stop_time'])
    print("\nProcessing: " + os.path.basename(args_cmd["audio_file"]))
    print("\nOutput directory: " + args_cmd["op_dir"])
    sampling_rate, audio = au.load_audio(
        args_cmd["audio_file"],
        args_cmd["time_exp"],
        params_bd["target_samp_rate"],
        params_bd["scale_raw_audio"],
    )
    st_samp = int(sampling_rate * args_cmd["start_time"])
    en_samp = int(sampling_rate * args_cmd["stop_time"])
    if en_samp > audio.shape[0]:
        audio = np.hstack((audio, np.zeros((en_samp) - audio.shape[0], dtype=audio.dtype)))
        audio = np.hstack(
            (audio, np.zeros((en_samp) - audio.shape[0], dtype=audio.dtype))
        )
    audio = audio[st_samp:en_samp]

    duration = audio.shape[0] / sampling_rate
    print('File duration: {} seconds'.format(duration))
    print("File duration: {} seconds".format(duration))

    # create spec for viz
    spec, _ = au.generate_spectrogram(audio, sampling_rate, params_bd, True, False)
    spec, _ = au.generate_spectrogram(
        audio, sampling_rate, params_bd, True, False
    )

    run_config = {
        **params_bd,
        **bd_args,
    }

    # run model and filter detections so only keep ones in relevant time range
    results = du.process_file(args_cmd['audio_file'], model, params_bd, bd_args)
    pred_anns = filter_anns(results['pred_dict']['annotation'], args_cmd['start_time'], args_cmd['stop_time'])
    print(len(pred_anns), 'Detections')
    results = du.process_file(args_cmd["audio_file"], model, run_config)
    pred_anns = filter_anns(
        results["pred_dict"]["annotation"],
        args_cmd["start_time"],
        args_cmd["stop_time"],
    )
    print(len(pred_anns), "Detections")

    # save output
    if not os.path.isdir(args_cmd['op_dir']):
        os.makedirs(args_cmd['op_dir'])

    if not os.path.isdir(args_cmd["op_dir"]):
        os.makedirs(args_cmd["op_dir"])

    # create output file names
    op_path_clean = os.path.basename(args_cmd['audio_file'])[:-4] + '_clean.' + args_cmd['file_type']
    op_path_clean = os.path.join(args_cmd['op_dir'], op_path_clean)
    op_path_pred = os.path.basename(args_cmd['audio_file'])[:-4] + '_pred.' + args_cmd['file_type']
    op_path_pred = os.path.join(args_cmd['op_dir'], op_path_pred)
    op_path_clean = (
        os.path.basename(args_cmd["audio_file"])[:-4]
        + "_clean."
        + args_cmd["file_type"]
    )
    op_path_clean = os.path.join(args_cmd["op_dir"], op_path_clean)
    op_path_pred = (
        os.path.basename(args_cmd["audio_file"])[:-4]
        + "_pred."
        + args_cmd["file_type"]
    )
    op_path_pred = os.path.join(args_cmd["op_dir"], op_path_pred)

    # create and save images
    viz.save_ann_spec(op_path_clean, spec, params_bd['min_freq'], params_bd['max_freq'], duration, args_cmd['start_time'], '', None)
    viz.save_ann_spec(op_path_pred, spec, params_bd['min_freq'], params_bd['max_freq'], duration, args_cmd['start_time'], '', pred_anns)
    viz.save_ann_spec(
        op_path_clean,
        spec,
        params_bd["min_freq"],
        params_bd["max_freq"],
        duration,
        args_cmd["start_time"],
        "",
        None,
    )
    viz.save_ann_spec(
        op_path_pred,
        spec,
        params_bd["min_freq"],
        params_bd["max_freq"],
        duration,
        args_cmd["start_time"],
        "",
        pred_anns,
    )

    if gt_present:
        op_path_gt = os.path.basename(args_cmd['audio_file'])[:-4] + '_gt.' + args_cmd['file_type']
        op_path_gt = os.path.join(args_cmd['op_dir'], op_path_gt)
        viz.save_ann_spec(op_path_gt, spec, params_bd['min_freq'], params_bd['max_freq'], duration, args_cmd['start_time'], '', gt_anns)
        op_path_gt = (
            os.path.basename(args_cmd["audio_file"])[:-4]
            + "_gt."
            + args_cmd["file_type"]
        )
        op_path_gt = os.path.join(args_cmd["op_dir"], op_path_gt)
        viz.save_ann_spec(
            op_path_gt,
            spec,
            params_bd["min_freq"],
            params_bd["max_freq"],
            duration,
            args_cmd["start_time"],
            "",
            gt_anns,
        )
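The recurring refactor in this commit, replacing process_file(file, model, params, args) with process_file(file, model, run_config), leans on Python's dict-unpacking merge, where later mappings override earlier ones on shared keys:

# values are illustrative defaults taken from elsewhere in this commit
params_bd = {"detection_threshold": 0.01, "target_samp_rate": 256000}
bd_args = {"detection_threshold": 0.2, "time_expansion_factor": 1}

run_config = {
    **params_bd,
    **bd_args,
}
# bd_args wins on the shared key:
# {'detection_threshold': 0.2, 'target_samp_rate': 256000, 'time_expansion_factor': 1}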
@ -8,163 +8,263 @@ Notes:
Best to use system one - see ffmpeg_path.
"""

from scipy.io import wavfile
import argparse
import os
import shutil
import sys

import matplotlib.pyplot as plt
import numpy as np
import argparse
from scipy.io import wavfile

import sys
sys.path.append(os.path.join('..'))
sys.path.append(os.path.join(".."))
import bat_detect.detector.parameters as parameters
import bat_detect.utils.audio_utils as au
import bat_detect.utils.plot_utils as viz
import bat_detect.utils.detector_utils as du

import bat_detect.utils.plot_utils as viz

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('audio_file', type=str, help='Path to input audio file')
    parser.add_argument('model_path', type=str, help='Path to trained BatDetect model')
    parser.add_argument('--op_dir', type=str, default='generated_vids/', help='Path to output directory')
    parser.add_argument('--no_detector', action='store_true', help='Do not run detector')
    parser.add_argument('--plot_class_names_off', action='store_true', help='Do not plot class names')
    parser.add_argument('--disable_axis', action='store_true', help='Do not plot axis')
    parser.add_argument('--detection_threshold', type=float, default=0.2, help='Cut-off probability for detector')
    parser.add_argument('--time_expansion_factor', type=int, default=1, dest='time_expansion_factor',
                        help='The time expansion factor used for all files (default is 1)')
    parser.add_argument("audio_file", type=str, help="Path to input audio file")
    parser.add_argument(
        "model_path", type=str, help="Path to trained BatDetect model"
    )
    parser.add_argument(
        "--op_dir",
        type=str,
        default="generated_vids/",
        help="Path to output directory",
    )
    parser.add_argument(
        "--no_detector", action="store_true", help="Do not run detector"
    )
    parser.add_argument(
        "--plot_class_names_off",
        action="store_true",
        help="Do not plot class names",
    )
    parser.add_argument(
        "--disable_axis", action="store_true", help="Do not plot axis"
    )
    parser.add_argument(
        "--detection_threshold",
        type=float,
        default=0.2,
        help="Cut-off probability for detector",
    )
    parser.add_argument(
        "--time_expansion_factor",
        type=int,
        default=1,
        dest="time_expansion_factor",
        help="The time expansion factor used for all files (default is 1)",
    )
    args_cmd = vars(parser.parse_args())

    # file of interest
    audio_file = args_cmd['audio_file']
    op_dir = args_cmd['op_dir']
    op_str = '_output'
    ffmpeg_path = '/usr/bin/'
    audio_file = args_cmd["audio_file"]
    op_dir = args_cmd["op_dir"]
    op_str = "_output"
    ffmpeg_path = "/usr/bin/"

    if not os.path.isfile(audio_file):
        print('Audio file not found: ', audio_file)
        print("Audio file not found: ", audio_file)
        sys.exit()

    if not os.path.isfile(args_cmd['model_path']):
        print('Model not found: ', model_path)
    if not os.path.isfile(args_cmd["model_path"]):
        print("Model not found: ", args_cmd["model_path"])
        sys.exit()

    start_time = 0.0
    duration = 0.5
    reveal_boxes = True  # makes the boxes appear one at a time
    fps = 24
    dpi = 100

    op_dir_tmp = os.path.join(op_dir, 'op_tmp_vids', '')
    op_dir_tmp = os.path.join(op_dir, "op_tmp_vids", "")
    if not os.path.isdir(op_dir_tmp):
        os.makedirs(op_dir_tmp)
    if not os.path.isdir(op_dir):
        os.makedirs(op_dir)

    params = parameters.get_params(False)
    args = du.get_default_bd_args()
    args['time_expansion_factor'] = args_cmd['time_expansion_factor']
    args['detection_threshold'] = args_cmd['detection_threshold']

    args = du.get_default_run_config()
    args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
    args["detection_threshold"] = args_cmd["detection_threshold"]

    # load audio file
    print('\nProcessing: ' + os.path.basename(audio_file))
    print('\nOutput directory: ' + op_dir)
    sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'], params['target_samp_rate'])
    audio = audio[int(sampling_rate*start_time):int(sampling_rate*start_time + sampling_rate*duration)]
    print("\nProcessing: " + os.path.basename(audio_file))
    print("\nOutput directory: " + op_dir)
    sampling_rate, audio = au.load_audio(
        audio_file, args["time_expansion_factor"], params["target_samp_rate"]
    )
    audio = audio[
        int(sampling_rate * start_time) : int(
            sampling_rate * start_time + sampling_rate * duration
        )
    ]
    audio_orig = audio.copy()
    audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'],
                         params['fft_overlap'], params['resize_factor'],
                         params['spec_divide_factor'])
    audio = au.pad_audio(
        audio,
        sampling_rate,
        params["fft_win_length"],
        params["fft_overlap"],
        params["resize_factor"],
        params["spec_divide_factor"],
    )

    # generate spectrogram
    spec, _ = au.generate_spectrogram(audio, sampling_rate, params, True)
    max_val = spec.max()*1.1
    max_val = spec.max() * 1.1

    if not args_cmd["no_detector"]:
        print(" Loading model and running detector on entire file ...")
        model, det_params = du.load_model(args_cmd["model_path"])
        det_params["detection_threshold"] = args["detection_threshold"]

    if not args_cmd['no_detector']:
        print(' Loading model and running detector on entire file ...')
        model, det_params = du.load_model(args_cmd['model_path'])
        det_params['detection_threshold'] = args['detection_threshold']
        results = du.process_file(audio_file, model, det_params, args)
        run_config = {
            **det_params,
            **args,
        }
        results = du.process_file(audio_file, model, run_config)

        print(' Processing detections and plotting ...')
        print(" Processing detections and plotting ...")
        detections = []
        for bb in results['pred_dict']['annotation']:
            if (bb['start_time'] >= start_time) and (bb['end_time'] < start_time+duration):
        for bb in results["pred_dict"]["annotation"]:
            if (bb["start_time"] >= start_time) and (
                bb["end_time"] < start_time + duration
            ):
                detections.append(bb)

        # plot boxes
        fig = plt.figure(1, figsize=(spec.shape[1]/dpi, spec.shape[0]/dpi), dpi=dpi)
        duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])
        viz.create_box_image(spec, fig, detections, start_time, start_time+duration, duration, params, max_val,
                             plot_class_names=not args_cmd['plot_class_names_off'])
        op_im_file_boxes = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '_boxes.png')
        fig = plt.figure(
            1, figsize=(spec.shape[1] / dpi, spec.shape[0] / dpi), dpi=dpi
        )
        duration = au.x_coords_to_time(
            spec.shape[1],
            sampling_rate,
            params["fft_win_length"],
            params["fft_overlap"],
        )
        viz.create_box_image(
            spec,
            fig,
            detections,
            start_time,
            start_time + duration,
            duration,
            params,
            max_val,
            plot_class_names=not args_cmd["plot_class_names_off"],
        )
        op_im_file_boxes = os.path.join(
            op_dir, os.path.basename(audio_file)[:-4] + op_str + "_boxes.png"
        )
        fig.savefig(op_im_file_boxes, dpi=dpi)
        plt.close(1)
        spec_with_boxes = plt.imread(op_im_file_boxes)

    print(' Saving audio file ...')
    if args['time_expansion_factor']==1:
        sampling_rate_op = int(sampling_rate/10.0)
    print(" Saving audio file ...")
    if args["time_expansion_factor"] == 1:
        sampling_rate_op = int(sampling_rate / 10.0)
    else:
        sampling_rate_op = sampling_rate
    op_audio_file = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '.wav')
    op_audio_file = os.path.join(
        op_dir, os.path.basename(audio_file)[:-4] + op_str + ".wav"
    )
    wavfile.write(op_audio_file, sampling_rate_op, audio_orig)

    print(' Saving image ...')
    op_im_file = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '.png')
    plt.imsave(op_im_file, spec, vmin=0, vmax=max_val, cmap='plasma')
    print(" Saving image ...")
    op_im_file = os.path.join(
        op_dir, os.path.basename(audio_file)[:-4] + op_str + ".png"
    )
    plt.imsave(op_im_file, spec, vmin=0, vmax=max_val, cmap="plasma")
    spec_blank = plt.imread(op_im_file)

    # create figure
    freq_scale = 1000  # turn Hz to kHz
    min_freq = params['min_freq']//freq_scale
    max_freq = params['max_freq']//freq_scale
    min_freq = params["min_freq"] // freq_scale
    max_freq = params["max_freq"] // freq_scale
    y_extent = [0, duration, min_freq, max_freq]

    print(' Saving video frames ...')
    print(" Saving video frames ...")
    # save images that will be combined into video
    # will either plot with or without boxes
    for ii, col in enumerate(np.linspace(0, spec.shape[1]-1, int(fps*duration*10))):
        if not args_cmd['no_detector']:
    for ii, col in enumerate(
        np.linspace(0, spec.shape[1] - 1, int(fps * duration * 10))
    ):
        if not args_cmd["no_detector"]:
            spec_op = spec_with_boxes.copy()
            if ii > 0:
                spec_op[:, int(col), :] = 1.0
                if reveal_boxes:
                    spec_op[:, int(col)+1:, :] = spec_blank[:, int(col)+1:, :]
                    spec_op[:, int(col) + 1 :, :] = spec_blank[
                        :, int(col) + 1 :, :
                    ]
            elif ii == 0 and reveal_boxes:
                spec_op = spec_blank

            if not args_cmd['disable_axis']:
                plt.close('all')
                fig = plt.figure(ii, figsize=(1.2*(spec_op.shape[1]/dpi), 1.5*(spec_op.shape[0]/dpi)), dpi=dpi)
                plt.xlabel('Time - seconds')
                plt.ylabel('Frequency - kHz')
                plt.imshow(spec_op, vmin=0, vmax=1.0, cmap='plasma', extent=y_extent, aspect='auto')
            if not args_cmd["disable_axis"]:
                plt.close("all")
                fig = plt.figure(
                    ii,
                    figsize=(
                        1.2 * (spec_op.shape[1] / dpi),
                        1.5 * (spec_op.shape[0] / dpi),
                    ),
                    dpi=dpi,
                )
                plt.xlabel("Time - seconds")
                plt.ylabel("Frequency - kHz")
                plt.imshow(
                    spec_op,
                    vmin=0,
                    vmax=1.0,
                    cmap="plasma",
                    extent=y_extent,
                    aspect="auto",
                )
                plt.tight_layout()
                fig.savefig(op_dir_tmp + str(ii).zfill(4) + '.png', dpi=dpi)
                fig.savefig(op_dir_tmp + str(ii).zfill(4) + ".png", dpi=dpi)
            else:
                plt.imsave(op_dir_tmp + str(ii).zfill(4) + '.png', spec_op, vmin=0, vmax=1.0, cmap='plasma')
                plt.imsave(
                    op_dir_tmp + str(ii).zfill(4) + ".png",
                    spec_op,
                    vmin=0,
                    vmax=1.0,
                    cmap="plasma",
                )
        else:
            spec_op = spec.copy()
            if ii > 0:
                spec_op[:, int(col)] = max_val
            plt.imsave(op_dir_tmp + str(ii).zfill(4) + '.png', spec_op, vmin=0, vmax=max_val, cmap='plasma')
            plt.imsave(
                op_dir_tmp + str(ii).zfill(4) + ".png",
                spec_op,
                vmin=0,
                vmax=max_val,
                cmap="plasma",
            )

    print(' Creating video ...')
    op_vid_file = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '.avi')
    ffmpeg_cmd = 'ffmpeg -hide_banner -loglevel panic -y -r {} -f image2 -s {}x{} -i {}%04d.png -i {} -vcodec libx264 ' \
                 '-crf 25 -pix_fmt yuv420p -acodec copy {}'.format(fps, spec.shape[1], spec.shape[0], op_dir_tmp, op_audio_file, op_vid_file)
    print(" Creating video ...")
    op_vid_file = os.path.join(
        op_dir, os.path.basename(audio_file)[:-4] + op_str + ".avi"
    )
    ffmpeg_cmd = (
        "ffmpeg -hide_banner -loglevel panic -y -r {} -f image2 -s {}x{} -i {}%04d.png -i {} -vcodec libx264 "
        "-crf 25 -pix_fmt yuv420p -acodec copy {}".format(
            fps,
            spec.shape[1],
            spec.shape[0],
            op_dir_tmp,
            op_audio_file,
            op_vid_file,
        )
    )
    ffmpeg_cmd = ffmpeg_path + ffmpeg_cmd
    os.system(ffmpeg_cmd)

    print(' Deleting temporary files ...')
    print(" Deleting temporary files ...")
    if os.path.isdir(op_dir_tmp):
        shutil.rmtree(op_dir_tmp)
        shutil.rmtree(op_dir_tmp)
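For reference, with the script defaults above (fps = 24), an illustrative 512x256 spectrogram, and an input named example.wav, the ffmpeg format string resolves to a command of this shape (all sizes and paths here are illustrative, not taken from a real run):

/usr/bin/ffmpeg -hide_banner -loglevel panic -y -r 24 -f image2 -s 512x256 \
    -i generated_vids/op_tmp_vids/%04d.png -i generated_vids/example_output.wav \
    -vcodec libx264 -crf 25 -pix_fmt yuv420p -acodec copy generated_vids/example_output.avi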
@ -1,41 +1,70 @@
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
import os
import sys
sys.path.append(os.path.join('..'))

import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage

sys.path.append(os.path.join(".."))

import bat_detect.utils.audio_utils as au


def generate_spectrogram_data(audio, sampling_rate, params, norm_type='log', smooth_spec=False):
    max_freq = round(params['max_freq']*params['fft_win_length'])
    min_freq = round(params['min_freq']*params['fft_win_length'])
def generate_spectrogram_data(
    audio, sampling_rate, params, norm_type="log", smooth_spec=False
):
    max_freq = round(params["max_freq"] * params["fft_win_length"])
    min_freq = round(params["min_freq"] * params["fft_win_length"])

    # create spectrogram - numpy
    spec = au.gen_mag_spectrogram(audio, sampling_rate, params['fft_win_length'], params['fft_overlap'])
    #spec = au.gen_mag_spectrogram_pt(audio, sampling_rate, params['fft_win_length'], params['fft_overlap']).numpy()
    spec = au.gen_mag_spectrogram(
        audio, sampling_rate, params["fft_win_length"], params["fft_overlap"]
    )
    # spec = au.gen_mag_spectrogram_pt(audio, sampling_rate, params['fft_win_length'], params['fft_overlap']).numpy()
    if spec.shape[0] < max_freq:
        freq_pad = max_freq - spec.shape[0]
        spec = np.vstack((np.zeros((freq_pad, spec.shape[1]), dtype=np.float32), spec))
    spec = spec[-max_freq:spec.shape[0]-min_freq, :]
        spec = np.vstack(
            (np.zeros((freq_pad, spec.shape[1]), dtype=np.float32), spec)
        )
    spec = spec[-max_freq : spec.shape[0] - min_freq, :]

    if norm_type == 'log':
        log_scaling = 2.0 * (1.0 / sampling_rate) * (1.0/(np.abs(np.hanning(int(params['fft_win_length']*sampling_rate)))**2).sum())
    if norm_type == "log":
        log_scaling = (
            2.0
            * (1.0 / sampling_rate)
            * (
                1.0
                / (
                    np.abs(
                        np.hanning(
                            int(params["fft_win_length"] * sampling_rate)
                        )
                    )
                    ** 2
                ).sum()
            )
        )
        ##log_scaling = 0.01
        spec = np.log(1.0 + log_scaling*spec).astype(np.float32)
    elif norm_type == 'pcen':
        spec = np.log(1.0 + log_scaling * spec).astype(np.float32)
    elif norm_type == "pcen":
        spec = au.pcen(spec, sampling_rate)
    else:
        pass

    if smooth_spec:
        spec = ndimage.gaussian_filter(spec, 1)
        spec = ndimage.gaussian_filter(spec, 1)

    return spec


def load_data(anns, params, class_names, smooth_spec=False, norm_type='log', extract_bg=False):
def load_data(
    anns,
    params,
    class_names,
    smooth_spec=False,
    norm_type="log",
    extract_bg=False,
):
    specs = []
    labels = []
    coords = []
@ -43,67 +72,106 @@ def load_data(anns, params, class_names, smooth_spec=False, norm_type='log', ext
    sampling_rates = []
    file_names = []
    for cur_file in anns:
        sampling_rate, audio_orig = au.load_audio_file(cur_file['file_path'], cur_file['time_exp'],
                                                       params['target_samp_rate'], params['scale_raw_audio'])
        sampling_rate, audio_orig = au.load_audio(
            cur_file["file_path"],
            cur_file["time_exp"],
            params["target_samp_rate"],
            params["scale_raw_audio"],
        )

        for ann in cur_file['annotation']:
            if ann['class'] not in params['classes_to_ignore'] and ann['class'] in class_names:
        for ann in cur_file["annotation"]:
            if (
                ann["class"] not in params["classes_to_ignore"]
                and ann["class"] in class_names
            ):
                # clip out of bounds
                if ann['low_freq'] < params['min_freq']:
                    ann['low_freq'] = params['min_freq']
                if ann['high_freq'] > params['max_freq']:
                    ann['high_freq'] = params['max_freq']
                if ann["low_freq"] < params["min_freq"]:
                    ann["low_freq"] = params["min_freq"]
                if ann["high_freq"] > params["max_freq"]:
                    ann["high_freq"] = params["max_freq"]

                # load cropped audio
                start_samp_diff = int(sampling_rate*ann['start_time']) - int(sampling_rate*params['aud_pad'])
                start_samp_diff = int(sampling_rate * ann["start_time"]) - int(
                    sampling_rate * params["aud_pad"]
                )
                start_samp = np.maximum(0, start_samp_diff)
                end_samp = np.minimum(audio_orig.shape[0], int(sampling_rate*ann['end_time'])*2 + int(sampling_rate*params['aud_pad']))
                end_samp = np.minimum(
                    audio_orig.shape[0],
                    int(sampling_rate * ann["end_time"]) * 2
                    + int(sampling_rate * params["aud_pad"]),
                )
                audio = audio_orig[start_samp:end_samp]
                if start_samp_diff < 0:
                    # need to pad at start if the call is at the very beginning
                    audio = np.hstack((np.zeros(-start_samp_diff, dtype=np.float32), audio))
                    audio = np.hstack(
                        (np.zeros(-start_samp_diff, dtype=np.float32), audio)
                    )

                nfft = int(params['fft_win_length']*sampling_rate)
                noverlap = int(params['fft_overlap']*nfft)
                max_samps = params['spec_width']*(nfft - noverlap) + noverlap
                nfft = int(params["fft_win_length"] * sampling_rate)
                noverlap = int(params["fft_overlap"] * nfft)
                max_samps = params["spec_width"] * (nfft - noverlap) + noverlap

                if max_samps > audio.shape[0]:
                    audio = np.hstack((audio, np.zeros(max_samps - audio.shape[0])))
                    audio = np.hstack(
                        (audio, np.zeros(max_samps - audio.shape[0]))
                    )
                audio = audio[:max_samps].astype(np.float32)

                audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'],
                                     params['fft_overlap'], params['resize_factor'],
                                     params['spec_divide_factor'])
                audio = au.pad_audio(
                    audio,
                    sampling_rate,
                    params["fft_win_length"],
                    params["fft_overlap"],
                    params["resize_factor"],
                    params["spec_divide_factor"],
                )

                # generate spectrogram
                spec = generate_spectrogram_data(audio, sampling_rate, params, norm_type, smooth_spec)[:, :params['spec_width']]
                spec = generate_spectrogram_data(
                    audio, sampling_rate, params, norm_type, smooth_spec
                )[:, : params["spec_width"]]

                specs.append(spec[np.newaxis, ...])
                labels.append(ann['class'])
                labels.append(ann["class"])

                audios.append(audio)
                sampling_rates.append(sampling_rate)
                file_names.append(cur_file['file_path'])
                file_names.append(cur_file["file_path"])

                # position in crop
                x1 = int(au.time_to_x_coords(np.array(params['aud_pad']), sampling_rate, params['fft_win_length'], params['fft_overlap']))
                y1 = (ann['low_freq'] - params['min_freq']) * params['fft_win_length']
                x1 = int(
                    au.time_to_x_coords(
                        np.array(params["aud_pad"]),
                        sampling_rate,
                        params["fft_win_length"],
                        params["fft_overlap"],
                    )
                )
                y1 = (ann["low_freq"] - params["min_freq"]) * params[
                    "fft_win_length"
                ]
                coords.append((y1, x1))

    _, file_ids = np.unique(file_names, return_inverse=True)
    labels = np.array([class_names.index(ll) for ll in labels])

    #return np.vstack(specs), labels, coords, audios, sampling_rates, file_ids, file_names
    # return np.vstack(specs), labels, coords, audios, sampling_rates, file_ids, file_names
    return np.vstack(specs), labels


def save_summary_image(specs, labels, species_names, params, op_file_name='plots/all_species.png', order=None):
def save_summary_image(
    specs,
    labels,
    species_names,
    params,
    op_file_name="plots/all_species.png",
    order=None,
):
    # takes the mean for each class and plots it on a grid
    mean_specs = []
    max_band = []
    for ii in range(len(species_names)):
        inds = np.where(labels==ii)[0]
        inds = np.where(labels == ii)[0]
        mu = specs[inds, :].mean(0)
        max_band.append(np.argmax(mu.sum(1)))
        mean_specs.append(mu)
@ -113,11 +181,21 @@ def save_summary_image(specs, labels, species_names, params, op_file_name='plots
        order = np.arange(len(species_names))

    max_cols = 6
    nrows = int(np.ceil(len(species_names)/max_cols))
    nrows = int(np.ceil(len(species_names) / max_cols))
    ncols = np.minimum(len(species_names), max_cols)

    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3.3, nrows*6), gridspec_kw = {'wspace':0, 'hspace':0.2})
    spec_min_max = (0, mean_specs[0].shape[1], params['min_freq']/1000, params['max_freq']/1000)
    fig, ax = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(ncols * 3.3, nrows * 6),
        gridspec_kw={"wspace": 0, "hspace": 0.2},
    )
    spec_min_max = (
        0,
        mean_specs[0].shape[1],
        params["min_freq"] / 1000,
        params["max_freq"] / 1000,
    )
    ii = 0
    for row in ax:
@ -126,17 +204,22 @@ def save_summary_image(specs, labels, species_names, params, op_file_name='plots

        for col in row:
            if ii >= len(species_names):
                col.axis('off')
                col.axis("off")
            else:
                inds = np.where(labels==order[ii])[0]
                col.imshow(mean_specs[order[ii]], extent=spec_min_max, cmap='plasma', aspect='equal')
                col.grid(color='w', alpha=0.3, linewidth=0.3)
                inds = np.where(labels == order[ii])[0]
                col.imshow(
                    mean_specs[order[ii]],
                    extent=spec_min_max,
                    cmap="plasma",
                    aspect="equal",
                )
                col.grid(color="w", alpha=0.3, linewidth=0.3)
                col.set_xticks([])
                col.title.set_text(str(ii+1) + ' ' + species_names[order[ii]])
                col.tick_params(axis='both', which='major', labelsize=7)
                col.title.set_text(str(ii + 1) + " " + species_names[order[ii]])
                col.tick_params(axis="both", which="major", labelsize=7)
                ii += 1

    #plt.tight_layout()
    #plt.show()
    # plt.tight_layout()
    # plt.show()
    plt.savefig(op_file_name)
    plt.close('all')
    plt.close("all")
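The heavily wrapped log_scaling expression above is the usual periodogram normalisation for a Hann window, 2 / (fs * sum(w[n]^2)), applied inside a log(1 + x) compression. A quick sanity check of its magnitude (fs and window length here are assumed from the default configuration asserted in the tests below):

import numpy as np

fs = 256000              # target_samp_rate
fft_win_length = 0.002   # seconds, i.e. a 512-sample Hann window
n = int(fft_win_length * fs)

log_scaling = 2.0 * (1.0 / fs) * (1.0 / (np.abs(np.hanning(n)) ** 2).sum())
print(log_scaling)  # ~4e-8: raw spectrogram magnitudes are tiny before compression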
0
tests/__init__.py
Normal file
253
tests/test_api.py
Normal file
@ -0,0 +1,253 @@
"""Test bat detect module API."""

import os
from glob import glob

import numpy as np
import torch
from torch import nn

from bat_detect import api

PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))


def test_load_model_with_default_params():
    """Test loading model with default parameters."""
    model, params = api.load_model()

    assert model is not None
    assert isinstance(model, nn.Module)

    assert params is not None
    assert isinstance(params, dict)

    assert "model_name" in params
    assert "num_filters" in params
    assert "emb_dim" in params
    assert "ip_height" in params

    assert params["model_name"] == "Net2DFast"
    assert params["num_filters"] == 128
    assert params["emb_dim"] == 0
    assert params["ip_height"] == 128
    assert params["resize_factor"] == 0.5
    assert len(params["class_names"]) == 17


def test_list_audio_files():
    """Test listing audio files."""
    audio_files = api.list_audio_files(TEST_DATA_DIR)

    assert len(audio_files) == 3
    assert all(path.endswith((".wav", ".WAV")) for path in audio_files)


def test_load_audio():
    """Test loading audio."""
    audio = api.load_audio(TEST_DATA[0])

    assert audio is not None
    assert isinstance(audio, np.ndarray)
    assert audio.shape == (128000,)


def test_generate_spectrogram():
    """Test generating spectrogram."""
    audio = api.load_audio(TEST_DATA[0])
    spectrogram = api.generate_spectrogram(audio)

    assert spectrogram is not None
    assert isinstance(spectrogram, torch.Tensor)
    assert spectrogram.shape == (1, 1, 128, 512)


def test_get_default_config():
    """Test getting default configuration."""
    config = api.get_config()

    assert config is not None
    assert isinstance(config, dict)

    assert config["target_samp_rate"] == 256000
    assert config["fft_win_length"] == 0.002
    assert config["fft_overlap"] == 0.75
    assert config["resize_factor"] == 0.5
    assert config["spec_divide_factor"] == 32
    assert config["spec_height"] == 256
    assert config["spec_scale"] == "pcen"
    assert config["denoise_spec_avg"] is True
    assert config["max_scale_spec"] is False
    assert config["scale_raw_audio"] is False
    assert len(config["class_names"]) == 0
    assert config["detection_threshold"] == 0.01
    assert config["time_expansion"] == 1
    assert config["top_n"] == 3
    assert config["return_raw_preds"] is False
    assert config["max_duration"] is None
    assert config["nms_kernel_size"] == 9
    assert config["max_freq"] == 120000
    assert config["min_freq"] == 10000
    assert config["nms_top_k_per_sec"] == 200
    assert config["quiet"] is True
    assert config["chunk_size"] == 3
    assert config["cnn_features"] is False
    assert config["spec_features"] is False
    assert config["spec_slices"] is False


def test_api_exposes_default_model():
    """Test that API exposes default model."""
    assert hasattr(api, "model")
    assert isinstance(api.model, nn.Module)
    assert type(api.model).__name__ == "Net2DFast"

    # Check that model has expected attributes
    assert api.model.num_classes == 17
    assert api.model.num_filts == 128
    assert api.model.emb_dim == 0
    assert api.model.ip_height_rs == 128
    assert api.model.resize_factor == 0.5


def test_api_exposes_default_config():
    """Test that API exposes default configuration."""
    assert hasattr(api, "config")
    assert isinstance(api.config, dict)

    assert api.config["target_samp_rate"] == 256000
    assert api.config["fft_win_length"] == 0.002
    assert api.config["fft_overlap"] == 0.75
    assert api.config["resize_factor"] == 0.5
    assert api.config["spec_divide_factor"] == 32
    assert api.config["spec_height"] == 256
    assert api.config["spec_scale"] == "pcen"
    assert api.config["denoise_spec_avg"] is True
    assert api.config["max_scale_spec"] is False
    assert api.config["scale_raw_audio"] is False
    assert len(api.config["class_names"]) == 17
    assert api.config["detection_threshold"] == 0.01
    assert api.config["time_expansion"] == 1
    assert api.config["top_n"] == 3
    assert api.config["return_raw_preds"] is False
    assert api.config["max_duration"] is None
    assert api.config["nms_kernel_size"] == 9
    assert api.config["max_freq"] == 120000
    assert api.config["min_freq"] == 10000
    assert api.config["nms_top_k_per_sec"] == 200
    assert api.config["quiet"] is True
    assert api.config["chunk_size"] == 3
    assert api.config["cnn_features"] is False
    assert api.config["spec_features"] is False
    assert api.config["spec_slices"] is False


def test_process_file_with_default_model():
    """Test processing file with model."""
    predictions = api.process_file(TEST_DATA[0])

    assert predictions is not None
    assert isinstance(predictions, dict)

    assert "pred_dict" in predictions

    # By default will not return other features
    assert "spec_feats" not in predictions
    assert "spec_feat_names" not in predictions
    assert "cnn_feats" not in predictions
    assert "cnn_feat_names" not in predictions
    assert "spec_slices" not in predictions

    # Check that predictions are returned
    assert isinstance(predictions["pred_dict"], dict)
    pred_dict = predictions["pred_dict"]
    assert pred_dict["id"] == os.path.basename(TEST_DATA[0])
    assert pred_dict["annotated"] is False
    assert pred_dict["issues"] is False
    assert pred_dict["notes"] == "Automatically generated."
    assert pred_dict["time_exp"] == 1
    assert pred_dict["duration"] == 0.5
    assert pred_dict["class_name"] is not None
    assert len(pred_dict["annotation"]) > 0


def test_process_spectrogram_with_default_model():
    """Test processing spectrogram with model."""
    audio = api.load_audio(TEST_DATA[0])
    spectrogram = api.generate_spectrogram(audio)
    predictions, features = api.process_spectrogram(spectrogram)

    assert predictions is not None
    assert isinstance(predictions, list)
    assert len(predictions) > 0
    sample_pred = predictions[0]
    assert isinstance(sample_pred, dict)
    assert "class" in sample_pred
    assert "class_prob" in sample_pred
    assert "det_prob" in sample_pred
    assert "start_time" in sample_pred
    assert "end_time" in sample_pred
    assert "low_freq" in sample_pred
    assert "high_freq" in sample_pred

    assert features is not None
    assert isinstance(features, list)
    assert len(features) == 1


def test_process_audio_with_default_model():
    """Test processing audio with model."""
    audio = api.load_audio(TEST_DATA[0])
    predictions, features, spec = api.process_audio(audio)

    assert predictions is not None
    assert isinstance(predictions, list)
    assert len(predictions) > 0
    sample_pred = predictions[0]
    assert isinstance(sample_pred, dict)
    assert "class" in sample_pred
    assert "class_prob" in sample_pred
    assert "det_prob" in sample_pred
    assert "start_time" in sample_pred
    assert "end_time" in sample_pred
    assert "low_freq" in sample_pred
    assert "high_freq" in sample_pred

    assert features is not None
    assert isinstance(features, list)
    assert len(features) == 1

    assert spec is not None
    assert isinstance(spec, torch.Tensor)
    assert spec.shape == (1, 1, 128, 512)


def test_postprocess_model_outputs():
    """Test postprocessing model outputs."""
    # Load model outputs
    audio = api.load_audio(TEST_DATA[1])
    spec = api.generate_spectrogram(audio)
    model_outputs = api.model(spec)

    # Postprocess outputs
    predictions, features = api.postprocess(model_outputs)

    assert predictions is not None
    assert isinstance(predictions, list)
    assert len(predictions) > 0
    sample_pred = predictions[0]
    assert isinstance(sample_pred, dict)
    assert "class" in sample_pred
    assert "class_prob" in sample_pred
    assert "det_prob" in sample_pred
    assert "start_time" in sample_pred
    assert "end_time" in sample_pred
    assert "low_freq" in sample_pred
    assert "high_freq" in sample_pred

    assert features is not None
    assert isinstance(features, np.ndarray)
    assert features.shape[0] == len(predictions)
    assert features.shape[1] == 32
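Taken together these tests pin down a small end-to-end surface for the new bat_detect.api module. A usage sketch restricted to the calls exercised above:

from bat_detect import api

# list and load one of the bundled example recordings
audio_files = api.list_audio_files("example_data/audio")
audio = api.load_audio(audio_files[0])

# spectrogram route: tensor in, detections plus CNN features out
spectrogram = api.generate_spectrogram(audio)
predictions, features = api.process_spectrogram(spectrogram)

# or go straight from a file path to an annotation dictionary
results = api.process_file(audio_files[0])
for ann in results["pred_dict"]["annotation"]:
    print(ann["class"], ann["det_prob"], ann["start_time"], ann["end_time"])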
41
tests/test_cli.py
Normal file
@ -0,0 +1,41 @@
"""Test the command line interface."""
from click.testing import CliRunner

from bat_detect.cli import cli


def test_cli_base_command():
    runner = CliRunner()
    result = runner.invoke(cli, ["--help"])
    assert result.exit_code == 0
    assert "BatDetect2 - Bat Call Detection and Classification" in result.output


def test_cli_detect_command_help():
    runner = CliRunner()
    result = runner.invoke(cli, ["detect", "--help"])
    assert result.exit_code == 0
    assert "Detect bat calls in files in AUDIO_DIR" in result.output


def test_cli_detect_command_on_test_audio(tmp_path):
    results_dir = tmp_path / "results"

    # Remove results dir if it exists
    if results_dir.exists():
        results_dir.rmdir()

    runner = CliRunner()
    result = runner.invoke(
        cli,
        [
            "detect",
            "example_data/audio",
            str(results_dir),
            "0.3",
        ],
    )
    assert result.exit_code == 0
    assert results_dir.exists()
    assert len(list(results_dir.glob("*.csv"))) == 3
    assert len(list(results_dir.glob("*.json"))) == 3