Mirror of https://github.com/macaodha/batdetect2.git
Synced 2025-06-30 07:02:01 +02:00

Comparing commits a462beaeb8 ... e8db1d4050 (23 commits)
Commits in this range:
e8db1d4050, b396d4908a, ce15a0f152, 16febed792, d67ae9be05,
15de168a20, 0c8fae4a72, 6d57f96c07, a0622aa9a4, 587742b41e,
6d91153a56, 70c96b6844, 22f7d46f46, 1384c549f7, 4b6acd5e6e,
2ac968d65b, 166dad20bd, cbb02cf69e, 647468123e, 363eb9fb2f,
b21c224985, 152a577511, 136949c4e7
.dvc/.gitignore (new file, vendored, 3 lines)
@@ -0,0 +1,3 @@
+/config.local
+/tmp
+/cache
.dvc/config (new empty file)
.dvcignore (new file, 3 lines)
@@ -0,0 +1,3 @@
+# Add patterns of files dvc should ignore, which could improve
+# the performance. Learn more at
+# https://dvc.org/doc/user-guide/dvcignore
.gitignore (vendored, 2 lines added)
@@ -114,3 +114,5 @@ experiments/*
 notebooks/lightning_logs
 example_data/preprocessed
 .aider*
+DvcLiveLogger/checkpoints
+logs
Makefile (deleted, 94 lines)
@@ -1,94 +0,0 @@
-# Variables
-SOURCE_DIR = src
-TESTS_DIR = tests
-PYTHON_DIRS = src tests
-DOCS_SOURCE = docs/source
-DOCS_BUILD = docs/build
-HTML_COVERAGE_DIR = htmlcov
-
-# Default target (optional, often 'help' or 'all')
-.DEFAULT_GOAL := help
-
-# Phony targets (targets that don't produce a file with the same name)
-.PHONY: help test coverage coverage-html coverage-serve docs docs-serve format format-check lint lint-fix typecheck check clean clean-pyc clean-test clean-docs clean-build
-
-help:
-	@echo "Makefile Targets:"
-	@echo "  help            Show this help message."
-	@echo "  test            Run tests using pytest."
-	@echo "  coverage        Run tests and generate coverage data (.coverage, coverage.xml)."
-	@echo "  coverage-html   Generate an HTML coverage report in $(HTML_COVERAGE_DIR)/."
-	@echo "  coverage-serve  Serve the HTML coverage report locally."
-	@echo "  docs            Build documentation using Sphinx."
-	@echo "  docs-serve      Serve documentation with live reload using sphinx-autobuild."
-	@echo "  format          Format code using ruff."
-	@echo "  format-check    Check code formatting using ruff."
-	@echo "  lint            Lint code using ruff."
-	@echo "  lint-fix        Lint code using ruff and apply automatic fixes."
-	@echo "  typecheck       Type check code using pyright."
-	@echo "  check           Run all checks (format-check, lint, typecheck)."
-	@echo "  clean           Remove all build, test, documentation, and Python artifacts."
-	@echo "  clean-pyc       Remove Python bytecode and cache."
-	@echo "  clean-test      Remove test and coverage artifacts."
-	@echo "  clean-docs      Remove built documentation."
-	@echo "  clean-build     Remove package build artifacts."
-
-# Testing & Coverage
-test:
-	pytest tests
-
-coverage:
-	pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml tests
-
-coverage-html: coverage
-	@echo "Generating HTML coverage report..."
-	coverage html -d $(HTML_COVERAGE_DIR)
-	@echo "HTML coverage report generated in $(HTML_COVERAGE_DIR)/"
-
-coverage-serve: coverage-html
-	@echo "Serving report at http://localhost:8000/ ..."
-	python -m http.server --directory $(HTML_COVERAGE_DIR) 8000
-
-# Documentation
-docs:
-	sphinx-build -b html $(DOCS_SOURCE) $(DOCS_BUILD)
-
-docs-serve:
-	sphinx-autobuild $(DOCS_SOURCE) $(DOCS_BUILD) --watch $(SOURCE_DIR) --open-browser
-
-# Formatting & Linting
-format:
-	ruff format $(PYTHON_DIRS)
-
-format-check:
-	ruff format --check $(PYTHON_DIRS)
-
-lint:
-	ruff check $(PYTHON_DIRS)
-
-lint-fix:
-	ruff check --fix $(PYTHON_DIRS)
-
-# Type Checking
-typecheck:
-	pyright $(PYTHON_DIRS)
-
-# Combined Checks
-check: format-check lint typecheck test
-
-# Cleaning tasks
-clean-pyc:
-	find . -type f -name "*.py[co]" -delete
-	find . -type d -name "__pycache__" -exec rm -rf {} +
-
-clean-test:
-	rm -f .coverage coverage.xml
-	rm -rf .pytest_cache htmlcov/
-
-clean-docs:
-	rm -rf $(DOCS_BUILD)
-
-clean-build:
-	rm -rf build/ dist/ *.egg-info/
-
-clean: clean-build clean-pyc clean-test clean-docs
example_data/config.yaml (new file, 146 lines)
@@ -0,0 +1,146 @@
+targets:
+  classes:
+    classes:
+      - name: myomys
+        tags:
+          - value: Myotis mystacinus
+      - name: pippip
+        tags:
+          - value: Pipistrellus pipistrellus
+      - name: eptser
+        tags:
+          - value: Eptesicus serotinus
+      - name: rhifer
+        tags:
+          - value: Rhinolophus ferrumequinum
+    generic_class:
+      - key: class
+        value: Bat
+
+  filtering:
+    rules:
+      - match_type: all
+        tags:
+          - key: event
+            value: Echolocation
+      - match_type: exclude
+        tags:
+          - key: class
+            value: Unknown
+
+preprocess:
+  audio:
+    resample:
+      samplerate: 256000
+      method: "poly"
+    scale: false
+    center: true
+    duration: null
+
+  spectrogram:
+    stft:
+      window_duration: 0.002
+      window_overlap: 0.75
+      window_fn: hann
+    frequencies:
+      max_freq: 120000
+      min_freq: 10000
+    pcen:
+      time_constant: 0.4
+      gain: 0.98
+      bias: 2
+      power: 0.5
+    scale: "amplitude"
+    size:
+      height: 128
+      resize_factor: 0.5
+    spectral_mean_substraction: true
+    peak_normalize: false
+
+postprocess:
+  nms_kernel_size: 9
+  detection_threshold: 0.01
+  min_freq: 10000
+  max_freq: 120000
+  top_k_per_sec: 200
+
+labels:
+  sigma: 3
+
+model:
+  input_height: 128
+  in_channels: 1
+  out_channels: 32
+  encoder:
+    layers:
+      - block_type: FreqCoordConvDown
+        out_channels: 32
+      - block_type: FreqCoordConvDown
+        out_channels: 64
+      - block_type: LayerGroup
+        layers:
+          - block_type: FreqCoordConvDown
+            out_channels: 128
+          - block_type: ConvBlock
+            out_channels: 256
+  bottleneck:
+    channels: 256
+    self_attention: true
+  decoder:
+    layers:
+      - block_type: FreqCoordConvUp
+        out_channels: 64
+      - block_type: FreqCoordConvUp
+        out_channels: 32
+      - block_type: LayerGroup
+        layers:
+          - block_type: FreqCoordConvUp
+            out_channels: 32
+          - block_type: ConvBlock
+            out_channels: 32
+
+train:
+  batch_size: 8
+  learning_rate: 0.001
+  t_max: 100
+  loss:
+    detection:
+      weight: 1.0
+      focal:
+        beta: 4
+        alpha: 2
+    classification:
+      weight: 2.0
+      focal:
+        beta: 4
+        alpha: 2
+    size:
+      weight: 0.1
+  logger:
+    logger_type: dvclive
+  augmentations:
+    steps:
+      - augmentation_type: mix_audio
+        probability: 0.2
+        min_weight: 0.3
+        max_weight: 0.7
+      - augmentation_type: add_echo
+        probability: 0.2
+        max_delay: 0.005
+        min_weight: 0.0
+        max_weight: 1.0
+      - augmentation_type: scale_volume
+        probability: 0.2
+        min_scaling: 0.0
+        max_scaling: 2.0
+      - augmentation_type: warp
+        probability: 0.2
+        delta: 0.04
+      - augmentation_type: mask_time
+        probability: 0.2
+        max_perc: 0.05
+        max_masks: 3
+      - augmentation_type: mask_freq
+        probability: 0.2
+        max_perc: 0.10
+        max_masks: 3
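Since the BaseConfig change further down makes unknown keys ignored rather than forbidden, this one file can back both the preprocess and train commands: each loader validates only the sections its schema defines. A minimal loading sketch in Python, assuming the package and the example file are available locally:

    from batdetect2.train import load_full_training_config

    conf = load_full_training_config("example_data/config.yaml")
    print(conf.train.batch_size)          # 8
    print(conf.model.input_height)        # 128
    print(conf.train.logger.logger_type)  # "dvclive"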
justfile (new file, 110 lines)
@@ -0,0 +1,110 @@
+# Default command, runs if no recipe is specified.
+default:
+    just --list
+
+# Variables
+SOURCE_DIR := "src"
+TESTS_DIR := "tests"
+PYTHON_DIRS := "src tests"
+DOCS_SOURCE := "docs/source"
+DOCS_BUILD := "docs/build"
+HTML_COVERAGE_DIR := "htmlcov"
+
+# Show available commands
+help:
+    @just --list
+
+# Testing & Coverage
+
+# Run tests using pytest.
+test:
+    pytest {{TESTS_DIR}}
+
+# Run tests and generate coverage data.
+coverage:
+    pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}}
+
+# Generate an HTML coverage report.
+coverage-html: coverage
+    @echo "Generating HTML coverage report..."
+    coverage html -d {{HTML_COVERAGE_DIR}}
+    @echo "HTML coverage report generated in {{HTML_COVERAGE_DIR}}/"
+
+# Serve the HTML coverage report locally.
+coverage-serve: coverage-html
+    @echo "Serving report at http://localhost:8000/ ..."
+    python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000
+
+# Documentation
+
+# Build documentation using Sphinx.
+docs:
+    sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}
+
+# Serve documentation with live reload.
+docs-serve:
+    sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser
+
+# Formatting & Linting
+
+# Format code using ruff.
+format:
+    ruff format {{PYTHON_DIRS}}
+
+# Check code formatting using ruff.
+format-check:
+    ruff format --check {{PYTHON_DIRS}}
+
+# Lint code using ruff.
+lint:
+    ruff check {{PYTHON_DIRS}}
+
+# Lint code using ruff and apply automatic fixes.
+lint-fix:
+    ruff check --fix {{PYTHON_DIRS}}
+
+# Type Checking
+
+# Type check code using pyright.
+typecheck:
+    pyright {{PYTHON_DIRS}}
+
+# Combined Checks
+
+# Run all checks (format-check, lint, typecheck).
+check: format-check lint typecheck test
+
+# Cleaning tasks
+
+# Remove Python bytecode and cache.
+clean-pyc:
+    find . -type f -name "*.py[co]" -delete
+    find . -type d -name "__pycache__" -exec rm -rf {} +
+
+# Remove test and coverage artifacts.
+clean-test:
+    rm -f .coverage coverage.xml
+    rm -rf .pytest_cache htmlcov/
+
+# Remove built documentation.
+clean-docs:
+    rm -rf {{DOCS_BUILD}}
+
+# Remove package build artifacts.
+clean-build:
+    rm -rf build/ dist/ *.egg-info/
+
+# Remove all build, test, documentation, and Python artifacts.
+clean: clean-build clean-pyc clean-test clean-docs
+
+# Examples
+
+# Preprocess example data.
+example-preprocess OPTIONS="":
+    batdetect2 preprocess \
+        --base-dir . \
+        --dataset-field datasets.train \
+        --config example_data/config.yaml \
+        {{OPTIONS}} \
+        example_data/datasets.yaml example_data/preprocessed
+
+# Train on example data.
+example-train OPTIONS="":
+    batdetect2 train \
+        --val-dir example_data/preprocessed \
+        --config example_data/config.yaml \
+        {{OPTIONS}} \
+        example_data/preprocessed
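With just installed (the rust-just dev dependency added below provides it), running `just` with no arguments lists every recipe, and `just example-preprocess` followed by `just example-train` exercises the new single-config CLI end to end on the bundled example data.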
@@ -85,6 +85,11 @@ dev = [
     "sphinx-book-theme>=1.1.4",
     "autodoc-pydantic>=2.2.0",
     "pytest-cov>=6.1.1",
+    "ty>=0.0.1a12",
+    "rust-just>=1.40.0",
+]
+dvclive = [
+    "dvclive>=3.48.2",
 ]

 [tool.ruff]
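This hunk appears to come from pyproject.toml; if so, `dvclive` is declared as an optional dependency group, so DVCLive-backed logging would be installed with something like `pip install "batdetect2[dvclive]"`, while the base install stays free of it (the logger factory further down raises a helpful error when the package is missing).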
@@ -2,13 +2,13 @@ from batdetect2.cli.base import cli
 from batdetect2.cli.compat import detect
 from batdetect2.cli.data import data
 from batdetect2.cli.preprocess import preprocess
-from batdetect2.cli.train import train
+from batdetect2.cli.train import train_command

 __all__ = [
     "cli",
     "detect",
     "data",
-    "train",
+    "train_command",
     "preprocess",
 ]
|
@ -1,15 +1,18 @@
|
|||||||
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
import yaml
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
from batdetect2.data import load_dataset_from_config
|
from batdetect2.data import load_dataset_from_config
|
||||||
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
from batdetect2.train.preprocess import (
|
||||||
from batdetect2.targets import build_targets, load_target_config
|
TrainPreprocessConfig,
|
||||||
from batdetect2.train import load_label_config, preprocess_annotations
|
load_train_preprocessing_config,
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
preprocess_dataset,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = ["preprocess"]
|
__all__ = ["preprocess"]
|
||||||
|
|
||||||
@ -44,16 +47,16 @@ __all__ = ["preprocess"]
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--preprocess-config",
|
"--config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help=(
|
help=(
|
||||||
"Path to the preprocessing configuration file. This file tells "
|
"Path to the configuration file. This file tells "
|
||||||
"the program how to prepare your audio data before training, such "
|
"the program how to prepare your audio data before training, such "
|
||||||
"as resampling or applying filters."
|
"as resampling or applying filters."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--preprocess-config-field",
|
"--config-field",
|
||||||
type=str,
|
type=str,
|
||||||
help=(
|
help=(
|
||||||
"If the preprocessing settings are inside a nested dictionary "
|
"If the preprocessing settings are inside a nested dictionary "
|
||||||
@ -62,41 +65,6 @@ __all__ = ["preprocess"]
|
|||||||
"top level, you don't need to specify this."
|
"top level, you don't need to specify this."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
|
||||||
"--label-config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help=(
|
|
||||||
"Path to the label generation configuration file. This file "
|
|
||||||
"contains settings for how to create labels from your "
|
|
||||||
"annotations, which the model uses to learn."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--label-config-field",
|
|
||||||
type=str,
|
|
||||||
help=(
|
|
||||||
"If the label generation settings are inside a nested dictionary "
|
|
||||||
"within the label configuration file, specify the key here. If "
|
|
||||||
"the settings are at the top level, leave this blank."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--target-config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help=(
|
|
||||||
"Path to the training target configuration file. This file "
|
|
||||||
"specifies what sounds the model should learn to predict."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--target-config-field",
|
|
||||||
type=str,
|
|
||||||
help=(
|
|
||||||
"If the target settings are inside a nested dictionary "
|
|
||||||
"within the target configuration file, specify the key here. "
|
|
||||||
"If the settings are at the top level, you don't need to specify this."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
@click.option(
|
||||||
"--force",
|
"--force",
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
@ -117,20 +85,32 @@ __all__ = ["preprocess"]
|
|||||||
"the program will use all available cores."
|
"the program will use all available cores."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"-v",
|
||||||
|
"--verbose",
|
||||||
|
count=True,
|
||||||
|
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
||||||
|
)
|
||||||
def preprocess(
|
def preprocess(
|
||||||
dataset_config: Path,
|
dataset_config: Path,
|
||||||
output: Path,
|
output: Path,
|
||||||
target_config: Optional[Path] = None,
|
|
||||||
base_dir: Optional[Path] = None,
|
base_dir: Optional[Path] = None,
|
||||||
preprocess_config: Optional[Path] = None,
|
config: Optional[Path] = None,
|
||||||
label_config: Optional[Path] = None,
|
config_field: Optional[str] = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: Optional[int] = None,
|
||||||
target_config_field: Optional[str] = None,
|
|
||||||
preprocess_config_field: Optional[str] = None,
|
|
||||||
label_config_field: Optional[str] = None,
|
|
||||||
dataset_field: Optional[str] = None,
|
dataset_field: Optional[str] = None,
|
||||||
|
verbose: int = 0,
|
||||||
):
|
):
|
||||||
|
logger.remove()
|
||||||
|
if verbose == 0:
|
||||||
|
log_level = "WARNING"
|
||||||
|
elif verbose == 1:
|
||||||
|
log_level = "INFO"
|
||||||
|
else:
|
||||||
|
log_level = "DEBUG"
|
||||||
|
logger.add(sys.stderr, level=log_level)
|
||||||
|
|
||||||
logger.info("Starting preprocessing.")
|
logger.info("Starting preprocessing.")
|
||||||
|
|
||||||
output = Path(output)
|
output = Path(output)
|
||||||
@ -139,31 +119,19 @@ def preprocess(
|
|||||||
base_dir = base_dir or Path.cwd()
|
base_dir = base_dir or Path.cwd()
|
||||||
logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
|
logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
|
||||||
|
|
||||||
preprocess = (
|
if config:
|
||||||
load_preprocessing_config(
|
logger.info(
|
||||||
preprocess_config,
|
"Loading preprocessing config from: {config}", config=config
|
||||||
field=preprocess_config_field,
|
|
||||||
)
|
)
|
||||||
if preprocess_config
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
target = (
|
conf = (
|
||||||
load_target_config(
|
load_train_preprocessing_config(config, field=config_field)
|
||||||
target_config,
|
if config is not None
|
||||||
field=target_config_field,
|
else TrainPreprocessConfig()
|
||||||
)
|
|
||||||
if target_config
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
|
logger.debug(
|
||||||
label = (
|
"Preprocessing config:\n{conf}",
|
||||||
load_label_config(
|
conf=yaml.dump(conf.model_dump()),
|
||||||
label_config,
|
|
||||||
field=label_config_field,
|
|
||||||
)
|
|
||||||
if label_config
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = load_dataset_from_config(
|
dataset = load_dataset_from_config(
|
||||||
@ -177,20 +145,10 @@ def preprocess(
|
|||||||
num_examples=len(dataset),
|
num_examples=len(dataset),
|
||||||
)
|
)
|
||||||
|
|
||||||
targets = build_targets(config=target)
|
preprocess_dataset(
|
||||||
preprocessor = build_preprocessor(config=preprocess)
|
|
||||||
labeller = build_clip_labeler(targets, config=label)
|
|
||||||
|
|
||||||
if not output.exists():
|
|
||||||
logger.debug("Creating directory {directory}", directory=output)
|
|
||||||
output.mkdir(parents=True)
|
|
||||||
|
|
||||||
logger.info("Will start preprocessing")
|
|
||||||
preprocess_annotations(
|
|
||||||
dataset,
|
dataset,
|
||||||
output_dir=output,
|
conf,
|
||||||
preprocessor=preprocessor,
|
output=output,
|
||||||
labeller=labeller,
|
force=force,
|
||||||
replace=force,
|
|
||||||
max_workers=num_workers,
|
max_workers=num_workers,
|
||||||
)
|
)
|
||||||
|
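The many per-topic config flags collapse into a single --config/--config-field pair plus a -v/--verbose switch. The justfile's example-preprocess recipe above shows the resulting invocation: batdetect2 preprocess --base-dir . --dataset-field datasets.train --config example_data/config.yaml example_data/datasets.yaml example_data/preprocessed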
@@ -5,236 +5,53 @@ import click
 from loguru import logger

 from batdetect2.cli.base import cli
-from batdetect2.evaluate.metrics import (
-    ClassificationAccuracy,
-    ClassificationMeanAveragePrecision,
-    DetectionAveragePrecision,
-)
-from batdetect2.models import build_model
-from batdetect2.models.backbones import load_backbone_config
-from batdetect2.postprocess import build_postprocessor, load_postprocess_config
-from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
-from batdetect2.targets import build_targets, load_target_config
-from batdetect2.train import train
-from batdetect2.train.callbacks import ValidationMetrics
-from batdetect2.train.config import TrainingConfig, load_train_config
+from batdetect2.train import (
+    FullTrainingConfig,
+    load_full_training_config,
+    train,
+)
 from batdetect2.train.dataset import list_preprocessed_files

 __all__ = [
     "train_command",
 ]

-DEFAULT_CONFIG_FILE = Path("config.yaml")
-

 @cli.command(name="train")
-@click.option(
-    "--train-examples",
-    type=click.Path(exists=True),
-    required=True,
-)
-@click.option("--val-examples", type=click.Path(exists=True))
-@click.option(
-    "--model-path",
-    type=click.Path(exists=True),
-)
-@click.option(
-    "--train-config",
-    type=click.Path(exists=True),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--train-config-field",
-    type=str,
-    default="train",
-)
-@click.option(
-    "--preprocess-config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the preprocessing configuration file. This file tells "
-        "the program how to prepare your audio data before training, such "
-        "as resampling or applying filters."
-    ),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--preprocess-config-field",
-    type=str,
-    help=(
-        "If the preprocessing settings are inside a nested dictionary "
-        "within the preprocessing configuration file, specify the key "
-        "here to access them. If the preprocessing settings are at the "
-        "top level, you don't need to specify this."
-    ),
-    default="preprocess",
-)
-@click.option(
-    "--target-config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the training target configuration file. This file "
-        "specifies what sounds the model should learn to predict."
-    ),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--target-config-field",
-    type=str,
-    help=(
-        "If the target settings are inside a nested dictionary "
-        "within the target configuration file, specify the key here. "
-        "If the settings are at the top level, you don't need to specify this."
-    ),
-    default="targets",
-)
-@click.option(
-    "--postprocess-config",
-    type=click.Path(exists=True),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--postprocess-config-field",
-    type=str,
-    default="postprocess",
-)
-@click.option(
-    "--model-config",
-    type=click.Path(exists=True),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--model-config-field",
-    type=str,
-    default="model",
-)
-@click.option(
-    "--train-workers",
-    type=int,
-    default=0,
-)
-@click.option(
-    "--val-workers",
-    type=int,
-    default=0,
-)
+@click.argument("train_dir", type=click.Path(exists=True))
+@click.option("--val-dir", type=click.Path(exists=True))
+@click.option("--model-path", type=click.Path(exists=True))
+@click.option("--config", type=click.Path(exists=True))
+@click.option("--config-field", type=str)
+@click.option("--train-workers", type=int, default=0)
+@click.option("--val-workers", type=int, default=0)
 def train_command(
-    train_examples: Path,
-    val_examples: Optional[Path] = None,
+    train_dir: Path,
+    val_dir: Optional[Path] = None,
     model_path: Optional[Path] = None,
-    train_config: Path = DEFAULT_CONFIG_FILE,
-    train_config_field: str = "train",
-    preprocess_config: Path = DEFAULT_CONFIG_FILE,
-    preprocess_config_field: str = "preprocess",
-    target_config: Path = DEFAULT_CONFIG_FILE,
-    target_config_field: str = "targets",
-    postprocess_config: Path = DEFAULT_CONFIG_FILE,
-    postprocess_config_field: str = "postprocess",
-    model_config: Path = DEFAULT_CONFIG_FILE,
-    model_config_field: str = "model",
+    config: Optional[Path] = None,
+    config_field: Optional[str] = None,
     train_workers: int = 0,
     val_workers: int = 0,
 ):
     logger.info("Starting training!")

-    try:
-        target_config_loaded = load_target_config(
-            path=target_config,
-            field=target_config_field,
-        )
-        targets = build_targets(config=target_config_loaded)
-        logger.debug(
-            "Loaded targets info from config file {path}", path=target_config
-        )
-    except IOError:
-        logger.debug(
-            "Could not load target info from config file, using default"
-        )
-        targets = build_targets()
-
-    try:
-        preprocess_config_loaded = load_preprocessing_config(
-            path=preprocess_config,
-            field=preprocess_config_field,
-        )
-        preprocessor = build_preprocessor(preprocess_config_loaded)
-        logger.debug(
-            "Loaded preprocessor from config file {path}", path=target_config
-        )
-    except IOError:
-        logger.debug(
-            "Could not load preprocessor from config file, using default"
-        )
-        preprocessor = build_preprocessor()
-
-    try:
-        model_config_loaded = load_backbone_config(
-            path=model_config, field=model_config_field
-        )
-        model = build_model(
-            num_classes=len(targets.class_names),
-            config=model_config_loaded,
-        )
-    except IOError:
-        model = build_model(num_classes=len(targets.class_names))
-
-    try:
-        postprocess_config_loaded = load_postprocess_config(
-            path=postprocess_config,
-            field=postprocess_config_field,
-        )
-        postprocessor = build_postprocessor(
-            targets=targets,
-            config=postprocess_config_loaded,
-        )
-        logger.debug(
-            "Loaded postprocessor from file {path}", path=postprocess_config
-        )
-    except IOError:
-        logger.debug(
-            "Could not load postprocessor config from file. Using default"
-        )
-        postprocessor = build_postprocessor(targets=targets)
-
-    try:
-        train_config_loaded = load_train_config(
-            path=train_config, field=train_config_field
-        )
-        logger.debug(
-            "Loaded training config from file {path}",
-            path=train_config,
-        )
-    except IOError:
-        train_config_loaded = TrainingConfig()
-        logger.debug("Could not load training config from file. Using default")
-
-    train_files = list_preprocessed_files(train_examples)
-
-    val_files = (
-        None if val_examples is None else list_preprocessed_files(val_examples)
-    )
+    conf = (
+        load_full_training_config(config, field=config_field)
+        if config is not None
+        else FullTrainingConfig()
+    )

-    return train(
-        detector=model,
-        train_examples=train_files,  # type: ignore
-        val_examples=val_files,  # type: ignore
+    train_examples = list_preprocessed_files(train_dir)
+    val_examples = (
+        list_preprocessed_files(val_dir) if val_dir is not None else None
+    )
+
+    train(
+        train_examples=train_examples,
+        val_examples=val_examples,
+        config=conf,
         model_path=model_path,
-        preprocessor=preprocessor,
-        postprocessor=postprocessor,
-        targets=targets,
-        config=train_config_loaded,
-        callbacks=[
-            ValidationMetrics(
-                metrics=[
-                    DetectionAveragePrecision(),
-                    ClassificationMeanAveragePrecision(
-                        class_names=targets.class_names,
-                    ),
-                    ClassificationAccuracy(class_names=targets.class_names),
-                ]
-            )
-        ],
         train_workers=train_workers,
         val_workers=val_workers,
     )
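The train command follows the same pattern: the training-example directory becomes a positional argument and all per-topic config files merge into one --config. Per the example-train recipe above: batdetect2 train --val-dir example_data/preprocessed --config example_data/config.yaml example_data/preprocessed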
@@ -25,20 +25,9 @@ class BaseConfig(BaseModel):

     Inherits from Pydantic's `BaseModel` to provide data validation, parsing,
     and serialization capabilities.
-
-    It sets `extra='forbid'` in its model configuration, meaning that any
-    fields provided in a configuration file that are *not* explicitly defined
-    in the specific configuration schema will raise a validation error. This
-    helps catch typos and ensures configurations strictly adhere to the expected
-    structure.
-
-    Attributes
-    ----------
-    model_config : ConfigDict
-        Pydantic model configuration dictionary. Set to forbid extra fields.
     """

-    model_config = ConfigDict(extra="forbid")
+    model_config = ConfigDict(extra="ignore")


 T = TypeVar("T", bound=BaseModel)
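A short sketch of what the extra switch changes in practice, using a hypothetical minimal schema rather than one from the codebase:

    from pydantic import BaseModel, ConfigDict

    class BaseConfig(BaseModel):
        model_config = ConfigDict(extra="ignore")

    class StftConfig(BaseConfig):
        window_duration: float = 0.002

    # Under extra="forbid" the unknown key raised a ValidationError;
    # under extra="ignore" it is silently dropped instead:
    conf = StftConfig.model_validate({"window_duration": 0.004, "typo_key": 1})
    print(conf.window_duration)  # 0.004

This is what lets several loaders read different sections of the same config.yaml, at the cost of no longer catching misspelled keys.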
@@ -27,9 +27,23 @@ from torch import nn

 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.models.blocks import ConvBlock
-from batdetect2.models.bottleneck import BottleneckConfig, build_bottleneck
-from batdetect2.models.decoder import Decoder, DecoderConfig, build_decoder
-from batdetect2.models.encoder import Encoder, EncoderConfig, build_encoder
+from batdetect2.models.bottleneck import (
+    DEFAULT_BOTTLENECK_CONFIG,
+    BottleneckConfig,
+    build_bottleneck,
+)
+from batdetect2.models.decoder import (
+    DEFAULT_DECODER_CONFIG,
+    Decoder,
+    DecoderConfig,
+    build_decoder,
+)
+from batdetect2.models.encoder import (
+    DEFAULT_ENCODER_CONFIG,
+    Encoder,
+    EncoderConfig,
+    build_encoder,
+)
 from batdetect2.models.types import BackboneModel

 __all__ = [

@@ -186,9 +200,9 @@ class BackboneConfig(BaseConfig):

     input_height: int = 128
     in_channels: int = 1
-    encoder: Optional[EncoderConfig] = None
-    bottleneck: Optional[BottleneckConfig] = None
-    decoder: Optional[DecoderConfig] = None
+    encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
+    bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
+    decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
     out_channels: int = 32
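Because the sub-configs are no longer Optional, a bare BackboneConfig() now describes the complete default architecture instead of leaving holes for builders to fill. A sketch, assuming the defaults shown in the encoder/decoder hunks below:

    from batdetect2.models.backbones import BackboneConfig

    conf = BackboneConfig()
    # encoder is DEFAULT_ENCODER_CONFIG rather than None, so callers no
    # longer need an `if config is None` fallback:
    print(len(conf.encoder.layers))  # 3 top-level layers in the default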
@@ -38,7 +38,7 @@ from batdetect2.configs import BaseConfig

 __all__ = [
     "ConvBlock",
-    "BlockGroupConfig",
+    "LayerGroupConfig",
     "VerticalConv",
     "FreqCoordConvDownBlock",
     "StandardConvDownBlock",

@@ -654,16 +654,16 @@ LayerConfig = Annotated[
         StandardConvDownConfig,
         FreqCoordConvUpConfig,
         StandardConvUpConfig,
-        "BlockGroupConfig",
+        "LayerGroupConfig",
     ],
     Field(discriminator="block_type"),
 ]
 """Type alias for the discriminated union of block configuration models."""


-class BlockGroupConfig(BaseConfig):
-    block_type: Literal["group"] = "group"
-    blocks: List[LayerConfig]
+class LayerGroupConfig(BaseConfig):
+    block_type: Literal["LayerGroup"] = "LayerGroup"
+    layers: List[LayerConfig]


 def build_layer_from_config(

@@ -769,13 +769,13 @@ def build_layer_from_config(
             input_height * 2,
         )

-    if config.block_type == "group":
+    if config.block_type == "LayerGroup":
         current_channels = in_channels
         current_height = input_height

         blocks = []

-        for block_config in config.blocks:
+        for block_config in config.layers:
             block, current_channels, current_height = build_layer_from_config(
                 input_height=current_height,
                 in_channels=current_channels,
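Under the rename, a grouped block is declared with block_type "LayerGroup" and a layers list, which is exactly the shape used in example_data/config.yaml above. A sketch assuming the config classes as shown in this diff:

    from batdetect2.models.blocks import (
        ConvConfig,
        FreqCoordConvDownConfig,
        LayerGroupConfig,
    )

    group = LayerGroupConfig(
        layers=[
            FreqCoordConvDownConfig(out_channels=128),
            ConvConfig(out_channels=256),
        ]
    )
    print(group.block_type)  # "LayerGroup", the new discriminator value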
@@ -26,7 +26,7 @@ from torch import nn

 from batdetect2.configs import BaseConfig
 from batdetect2.models.blocks import (
-    BlockGroupConfig,
+    LayerGroupConfig,
     ConvConfig,
     FreqCoordConvUpConfig,
     StandardConvUpConfig,

@@ -45,7 +45,7 @@ DecoderLayerConfig = Annotated[
         ConvConfig,
         FreqCoordConvUpConfig,
         StandardConvUpConfig,
-        BlockGroupConfig,
+        LayerGroupConfig,
     ],
     Field(discriminator="block_type"),
 ]

@@ -197,8 +197,8 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
     layers=[
         FreqCoordConvUpConfig(out_channels=64),
         FreqCoordConvUpConfig(out_channels=32),
-        BlockGroupConfig(
-            blocks=[
+        LayerGroupConfig(
+            layers=[
                 FreqCoordConvUpConfig(out_channels=32),
                 ConvConfig(out_channels=32),
             ]
@@ -28,7 +28,7 @@ from torch import nn

 from batdetect2.configs import BaseConfig
 from batdetect2.models.blocks import (
-    BlockGroupConfig,
+    LayerGroupConfig,
     ConvConfig,
     FreqCoordConvDownConfig,
     StandardConvDownConfig,

@@ -47,7 +47,7 @@ EncoderLayerConfig = Annotated[
         ConvConfig,
         FreqCoordConvDownConfig,
         StandardConvDownConfig,
-        BlockGroupConfig,
+        LayerGroupConfig,
     ],
     Field(discriminator="block_type"),
 ]

@@ -230,8 +230,8 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
     layers=[
         FreqCoordConvDownConfig(out_channels=32),
         FreqCoordConvDownConfig(out_channels=64),
-        BlockGroupConfig(
-            blocks=[
+        LayerGroupConfig(
+            layers=[
                 FreqCoordConvDownConfig(out_channels=128),
                 ConvConfig(out_channels=256),
             ]
@@ -15,8 +15,10 @@ from batdetect2.train.augmentations import (
 )
 from batdetect2.train.clips import build_clipper, select_subclip
 from batdetect2.train.config import (
-    TrainerConfig,
+    FullTrainingConfig,
+    PLTrainerConfig,
     TrainingConfig,
+    load_full_training_config,
     load_train_config,
 )
 from batdetect2.train.dataset import (

@@ -26,6 +28,7 @@ from batdetect2.train.dataset import (
     list_preprocessed_files,
 )
 from batdetect2.train.labels import build_clip_labeler, load_label_config
+from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.losses import (
     ClassificationLossConfig,
     DetectionLossConfig,

@@ -40,7 +43,10 @@ from batdetect2.train.preprocess import (
 )
 from batdetect2.train.train import (
     build_train_dataset,
+    build_train_loader,
+    build_trainer,
     build_val_dataset,
+    build_val_loader,
     train,
 )

@@ -50,15 +56,17 @@ __all__ = [
     "DetectionLossConfig",
     "EchoAugmentationConfig",
     "FrequencyMaskAugmentationConfig",
+    "FullTrainingConfig",
     "LabeledDataset",
     "LossConfig",
     "LossFunction",
+    "PLTrainerConfig",
     "RandomExampleSource",
     "SizeLossConfig",
     "TimeMaskAugmentationConfig",
     "TrainExample",
-    "TrainerConfig",
     "TrainingConfig",
+    "TrainingModule",
     "VolumeAugmentationConfig",
     "WarpAugmentationConfig",
     "add_echo",

@@ -67,9 +75,13 @@ __all__ = [
     "build_clipper",
     "build_loss",
     "build_train_dataset",
+    "build_train_loader",
+    "build_trainer",
     "build_val_dataset",
+    "build_val_loader",
     "generate_train_example",
     "list_preprocessed_files",
+    "load_full_training_config",
     "load_label_config",
     "load_train_config",
     "mask_frequency",

@@ -79,6 +91,5 @@ __all__ = [
     "scale_volume",
     "select_subclip",
     "train",
-    "train",
     "warp_spectrogram",
 ]
@@ -4,26 +4,27 @@ from pydantic import Field
 from soundevent import data

 from batdetect2.configs import BaseConfig, load_config
+from batdetect2.models import BackboneConfig
+from batdetect2.postprocess import PostprocessConfig
+from batdetect2.preprocess import PreprocessingConfig
+from batdetect2.targets import TargetConfig
 from batdetect2.train.augmentations import (
     DEFAULT_AUGMENTATION_CONFIG,
     AugmentationsConfig,
 )
 from batdetect2.train.clips import ClipingConfig
+from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
 from batdetect2.train.losses import LossConfig

 __all__ = [
-    "OptimizerConfig",
     "TrainingConfig",
     "load_train_config",
+    "FullTrainingConfig",
+    "load_full_training_config",
 ]


-class OptimizerConfig(BaseConfig):
-    learning_rate: float = 1e-3
-    t_max: int = 100
-
-
-class TrainerConfig(BaseConfig):
+class PLTrainerConfig(BaseConfig):
     accelerator: str = "auto"
     accumulate_grad_batches: int = 1
     deterministic: bool = True

@@ -44,20 +45,17 @@ class TrainerConfig(BaseConfig):
     val_check_interval: Optional[Union[int, float]] = None


-class TrainingConfig(BaseConfig):
+class TrainingConfig(PLTrainerConfig):
     batch_size: int = 8
+    learning_rate: float = 1e-3
+    t_max: int = 100
     loss: LossConfig = Field(default_factory=LossConfig)
-
-    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
-
     augmentations: AugmentationsConfig = Field(
         default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
     )
-
     cliping: ClipingConfig = Field(default_factory=ClipingConfig)
-    trainer: TrainerConfig = Field(default_factory=TrainerConfig)
+    trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
+    logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)


 def load_train_config(

@@ -65,3 +63,23 @@ def load_train_config(
     field: Optional[str] = None,
 ) -> TrainingConfig:
     return load_config(path, schema=TrainingConfig, field=field)
+
+
+class FullTrainingConfig(BaseConfig):
+    """Full training configuration."""
+
+    train: TrainingConfig = Field(default_factory=TrainingConfig)
+    targets: TargetConfig = Field(default_factory=TargetConfig)
+    model: BackboneConfig = Field(default_factory=BackboneConfig)
+    preprocess: PreprocessingConfig = Field(
+        default_factory=PreprocessingConfig
+    )
+    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
+
+
+def load_full_training_config(
+    path: data.PathLike,
+    field: Optional[str] = None,
+) -> FullTrainingConfig:
+    """Load the full training configuration."""
+    return load_config(path, schema=FullTrainingConfig, field=field)
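A sketch of the resulting hierarchy, using only the defaults declared in this hunk:

    from batdetect2.train.config import FullTrainingConfig

    conf = FullTrainingConfig()  # every section now has a default
    print(conf.train.batch_size)           # 8
    print(conf.train.learning_rate)        # 0.001, moved from OptimizerConfig
    print(conf.train.trainer.accelerator)  # "auto"
    print(conf.train.logger.logger_type)   # "csv", the CSVLoggerConfig default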
@@ -1,17 +1,16 @@
 import lightning as L
 import torch
+from pydantic import BaseModel
 from torch.optim.adam import Adam
 from torch.optim.lr_scheduler import CosineAnnealingLR

-from batdetect2.models import (
-    DetectionModel,
-    ModelOutput,
-)
-from batdetect2.postprocess.types import PostprocessorProtocol
-from batdetect2.preprocess.types import PreprocessorProtocol
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.models import ModelOutput, build_model
+from batdetect2.postprocess import build_postprocessor
+from batdetect2.preprocess import build_preprocessor
+from batdetect2.targets import build_targets
 from batdetect2.train import TrainExample
-from batdetect2.train.types import LossProtocol
+from batdetect2.train.config import FullTrainingConfig
+from batdetect2.train.losses import build_loss

 __all__ = [
     "TrainingModule",

@@ -19,28 +18,29 @@ __all__ = [


 class TrainingModule(L.LightningModule):
-    def __init__(
-        self,
-        detector: DetectionModel,
-        loss: LossProtocol,
-        targets: TargetProtocol,
-        preprocessor: PreprocessorProtocol,
-        postprocessor: PostprocessorProtocol,
-        learning_rate: float = 0.001,
-        t_max: int = 100,
-    ):
+    def __init__(self, config: FullTrainingConfig):
         super().__init__()

-        self.loss = loss
-        self.detector = detector
-        self.preprocessor = preprocessor
-        self.targets = targets
-        self.postprocessor = postprocessor
-
-        self.learning_rate = learning_rate
-        self.t_max = t_max
-
-        self.save_hyperparameters()
+        # NOTE: Need to convert to vanilla python object so that DVCLive can
+        # store it.
+        self._config = (
+            config.model_dump() if isinstance(config, BaseModel) else config
+        )
+        self.save_hyperparameters({"config": self._config})
+
+        self.config = FullTrainingConfig.model_validate(self._config)
+        self.loss = build_loss(self.config.train.loss)
+        self.targets = build_targets(self.config.targets)
+        self.detector = build_model(
+            num_classes=len(self.targets.class_names),
+            config=self.config.model,
+        )
+        self.preprocessor = build_preprocessor(self.config.preprocess)
+        self.postprocessor = build_postprocessor(
+            self.targets,
+            min_freq=self.preprocessor.min_freq,
+            max_freq=self.preprocessor.max_freq,
+        )

     def forward(self, spec: torch.Tensor) -> ModelOutput:
         return self.detector(spec)

@@ -70,6 +70,6 @@ class TrainingModule(L.LightningModule):
         return outputs

     def configure_optimizers(self):
-        optimizer = Adam(self.parameters(), lr=self.learning_rate)
-        scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
+        optimizer = Adam(self.parameters(), lr=self.config.train.learning_rate)
+        scheduler = CosineAnnealingLR(optimizer, T_max=self.config.train.t_max)
         return [optimizer], [scheduler]
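The module is now reconstructible from its hyperparameters alone, since every component is derived from the single config dict. A sketch (the checkpoint path is illustrative, not from the repository):

    from batdetect2.train.config import FullTrainingConfig
    from batdetect2.train.lightning import TrainingModule

    module = TrainingModule(FullTrainingConfig())

    # Because __init__ takes only `config`, stored via save_hyperparameters()
    # as a plain dict, standard Lightning checkpoint loading should suffice:
    # module = TrainingModule.load_from_checkpoint("checkpoints/last.ckpt")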
src/batdetect2/train/logging.py (new file, 103 lines)
@@ -0,0 +1,103 @@
+from typing import Annotated, Literal, Optional, Union
+
+from lightning.pytorch.loggers import Logger
+from pydantic import Field
+
+from batdetect2.configs import BaseConfig
+
+DEFAULT_LOGS_DIR: str = "logs"
+
+
+class DVCLiveConfig(BaseConfig):
+    logger_type: Literal["dvclive"] = "dvclive"
+    dir: str = DEFAULT_LOGS_DIR
+    run_name: Optional[str] = None
+    prefix: str = ""
+    log_model: Union[bool, Literal["all"]] = False
+    monitor_system: bool = False
+
+
+class CSVLoggerConfig(BaseConfig):
+    logger_type: Literal["csv"] = "csv"
+    save_dir: str = DEFAULT_LOGS_DIR
+    name: Optional[str] = "logs"
+    version: Optional[str] = None
+    flush_logs_every_n_steps: int = 100
+
+
+class TensorBoardLoggerConfig(BaseConfig):
+    logger_type: Literal["tensorboard"] = "tensorboard"
+    save_dir: str = DEFAULT_LOGS_DIR
+    name: Optional[str] = "default"
+    version: Optional[str] = None
+    log_graph: bool = False
+    flush_logs_every_n_steps: Optional[int] = None
+
+
+LoggerConfig = Annotated[
+    Union[DVCLiveConfig, CSVLoggerConfig, TensorBoardLoggerConfig],
+    Field(discriminator="logger_type"),
+]
+
+
+def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
+    try:
+        from dvclive.lightning import DVCLiveLogger  # type: ignore
+    except ImportError as error:
+        raise ValueError(
+            "DVCLive is not installed and cannot be used for logging. "
+            "Make sure you have it installed by running `pip install dvclive` "
+            "or `uv add dvclive`"
+        ) from error
+
+    return DVCLiveLogger(
+        dir=config.dir,
+        run_name=config.run_name,
+        prefix=config.prefix,
+        log_model=config.log_model,
+        monitor_system=config.monitor_system,
+    )
+
+
+def create_csv_logger(config: CSVLoggerConfig) -> Logger:
+    from lightning.pytorch.loggers import CSVLogger
+
+    return CSVLogger(
+        save_dir=config.save_dir,
+        name=config.name,
+        version=config.version,
+        flush_logs_every_n_steps=config.flush_logs_every_n_steps,
+    )
+
+
+def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
+    from lightning.pytorch.loggers import TensorBoardLogger
+
+    return TensorBoardLogger(
+        save_dir=config.save_dir,
+        name=config.name,
+        version=config.version,
+        log_graph=config.log_graph,
+        flush_logs_every_n_steps=config.flush_logs_every_n_steps,
+    )
+
+
+LOGGER_FACTORY = {
+    "dvclive": create_dvclive_logger,
+    "csv": create_csv_logger,
+    "tensorboard": create_tensorboard_logger,
+}
+
+
+def build_logger(config: LoggerConfig) -> Logger:
+    """
+    Creates a logger instance from a validated Pydantic config object.
+    """
+    logger_type = config.logger_type
+
+    if logger_type not in LOGGER_FACTORY:
+        raise ValueError(f"Unknown logger type: {logger_type}")
+
+    creation_func = LOGGER_FACTORY[logger_type]
+
+    return creation_func(config)
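A usage sketch for the factory; the YAML form is what the train.logger section of example_data/config.yaml feeds through the logger_type discriminator:

    from batdetect2.train.logging import CSVLoggerConfig, build_logger

    logger = build_logger(CSVLoggerConfig(save_dir="logs", name="my-run"))

    # Equivalent YAML, dispatched on `logger_type`:
    #   logger:
    #     logger_type: csv
    #     save_dir: logs
    #     name: my-run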
@ -27,22 +27,71 @@ from typing import Callable, Optional, Sequence
|
|||||||
|
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.data.datasets import Dataset
|
||||||
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
from batdetect2.targets import TargetConfig, build_targets
|
||||||
|
from batdetect2.train.labels import LabelConfig, build_clip_labeler
|
||||||
from batdetect2.train.types import ClipLabeller
|
from batdetect2.train.types import ClipLabeller
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"preprocess_annotations",
|
"preprocess_annotations",
|
||||||
"preprocess_single_annotation",
|
"preprocess_single_annotation",
|
||||||
"generate_train_example",
|
"generate_train_example",
|
||||||
|
"preprocess_dataset",
|
||||||
|
"TrainPreprocessConfig",
|
||||||
|
"load_train_preprocessing_config",
|
||||||
]
|
]
|
||||||
|
|
||||||
FilenameFn = Callable[[data.ClipAnnotation], str]
|
FilenameFn = Callable[[data.ClipAnnotation], str]
|
||||||
"""Type alias for a function that generates an output filename."""
|
"""Type alias for a function that generates an output filename."""
|
||||||
|
|
||||||
|
|
||||||
|
class TrainPreprocessConfig(BaseConfig):
|
||||||
|
preprocess: PreprocessingConfig = Field(
|
||||||
|
default_factory=PreprocessingConfig
|
||||||
|
)
|
||||||
|
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
|
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||||
|
|
||||||
|
|
||||||
|
def load_train_preprocessing_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> TrainPreprocessConfig:
|
||||||
|
return load_config(path=path, schema=TrainPreprocessConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_dataset(
|
||||||
|
dataset: Dataset,
|
||||||
|
config: TrainPreprocessConfig,
|
||||||
|
output: Path,
|
||||||
|
force: bool = False,
|
||||||
|
max_workers: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
targets = build_targets(config=config.targets)
|
||||||
|
preprocessor = build_preprocessor(config=config.preprocess)
|
||||||
|
labeller = build_clip_labeler(targets, config=config.labels)
|
||||||
|
|
||||||
|
if not output.exists():
|
||||||
|
logger.debug("Creating directory {directory}", directory=output)
|
||||||
|
output.mkdir(parents=True)
|
||||||
|
|
||||||
|
preprocess_annotations(
|
||||||
|
dataset,
|
||||||
|
output_dir=output,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
labeller=labeller,
|
||||||
|
replace=force,
|
||||||
|
max_workers=max_workers,
|
||||||
|
)
|
||||||
|
|
||||||
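A usage sketch of the new config plumbing, using only names introduced above (the YAML path is hypothetical; because every field has a default_factory, an absent section simply falls back to its defaults):

# Programmatic construction with all defaults.
config = TrainPreprocessConfig()

# Equivalent, loading from disk; the top-level keys are "preprocess",
# "targets", and "labels". The filename is a placeholder.
config = load_train_preprocessing_config("train_preprocess.yaml")

# preprocess_dataset(dataset, config=config, output=Path("preprocessed"))
# would then derive targets, preprocessor, and labeller from the config.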


def generate_train_example(
    clip_annotation: data.ClipAnnotation,
    preprocessor: PreprocessorProtocol,
@ -207,8 +256,17 @@ def preprocess_annotations(
    output_dir = Path(output_dir)

    if not output_dir.is_dir():
        logger.info(
            "Creating output directory: {output_dir}", output_dir=output_dir
        )
        output_dir.mkdir(parents=True)

    logger.info(
        "Starting preprocessing of {num_annotations} annotations with {max_workers} workers.",
        num_annotations=len(clip_annotations),
        max_workers=max_workers or "all available",
    )

    with Pool(max_workers) as pool:
        list(
            tqdm(
@ -224,8 +282,10 @@ def preprocess_annotations(
                    clip_annotations,
                ),
                total=len(clip_annotations),
                desc="Preprocessing annotations",
            )
        )
    logger.info("Finished preprocessing.")
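The lines elided by the hunk above map a worker over the annotations, presumably with pool.imap and a partially applied processing function; the surrounding list(tqdm(...)) forces the lazy iterator and drives the progress bar, and Pool(None) falls back to os.cpu_count(), which is why the log line reports "all available". A self-contained sketch of the same pattern with a toy worker (not the real preprocessing function):

from functools import partial
from multiprocessing import Pool

from tqdm.auto import tqdm


def scale(x: int, factor: int = 2) -> int:
    return x * factor


if __name__ == "__main__":
    items = list(range(1000))
    with Pool(4) as pool:
        # list() consumes the lazy imap iterator so tqdm can count items.
        results = list(
            tqdm(
                pool.imap(partial(scale, factor=3), items),
                total=len(items),
                desc="Processing",
            )
        )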


def preprocess_single_annotation(
@ -264,11 +324,15 @@ def preprocess_single_annotation(
    path = output_dir / filename

    if path.is_file() and not replace:
        logger.debug("Skipping existing file: {path}", path=path)
        return

    if path.is_file() and replace:
        logger.debug("Removing existing file: {path}", path=path)
        path.unlink()

    logger.debug("Processing annotation {uuid}", uuid=clip_annotation.uuid)

    try:
        sample = generate_train_example(
            clip_annotation,
@ -277,8 +341,9 @@ def preprocess_single_annotation(
        )
    except Exception as error:
        logger.error(
            "Failed to process annotation {uuid} to {path}. Error: {error}",
            uuid=clip_annotation.uuid,
            path=path,
            error=error,
        )
        return
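Note the loguru idiom used throughout this module: the {uuid}, {path}, and {error} placeholders are filled str.format-style from the keyword arguments, which loguru also attaches to the record's extra dict, and formatting is skipped when the level is filtered out. A minimal sketch with placeholder values:

from loguru import logger

logger.error(
    "Failed to process annotation {uuid} to {path}. Error: {error}",
    uuid="example-uuid",  # placeholder values for illustration
    path="preprocessed/example.nc",
    error=ValueError("bad clip"),
)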
@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import List, Optional

from lightning import Trainer
@ -5,109 +6,74 @@ from lightning.pytorch.callbacks import Callback
from soundevent import data
from torch.utils.data import DataLoader

from batdetect2.evaluate.metrics import (
    ClassificationAccuracy,
    ClassificationMeanAveragePrecision,
    DetectionAveragePrecision,
)
from batdetect2.preprocess import (
    PreprocessorProtocol,
)
from batdetect2.targets import TargetProtocol
from batdetect2.train.augmentations import build_augmentations
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.clips import build_clipper
from batdetect2.train.config import (
    FullTrainingConfig,
    PLTrainerConfig,
    TrainingConfig,
)
from batdetect2.train.dataset import (
    LabeledDataset,
    RandomExampleSource,
    collate_fn,
)
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import build_logger

__all__ = [
    "build_train_dataset",
    "build_train_loader",
    "build_trainer",
    "build_val_dataset",
    "build_val_loader",
    "train",
]


def train(
    train_examples: Sequence[data.PathLike],
    val_examples: Optional[Sequence[data.PathLike]] = None,
    config: Optional[FullTrainingConfig] = None,
    model_path: Optional[data.PathLike] = None,
    train_workers: int = 0,
    val_workers: int = 0,
):
    conf = config or FullTrainingConfig()

    if model_path is not None:
        module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
    else:
        module = TrainingModule(conf)

    trainer = build_trainer(conf, targets=module.targets)

    train_dataloader = build_train_loader(
        train_examples,
        preprocessor=module.preprocessor,
        config=conf.train,
        num_workers=train_workers,
    )

    val_dataloader = (
        build_val_loader(
            val_examples,
            config=conf.train,
            num_workers=val_workers,
        )
        if val_examples is not None
        else None
    )

    trainer.fit(
        module,
@ -116,8 +82,76 @@ def train(
    )
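A usage sketch of the slimmed-down entry point (file paths are hypothetical, and the import path assumes train is re-exported from batdetect2.train, as the tests below do for FullTrainingConfig):

from batdetect2.train import FullTrainingConfig, train

train(
    train_examples=["preprocessed/clip_0001.nc", "preprocessed/clip_0002.nc"],
    val_examples=["preprocessed/clip_0003.nc"],
    config=FullTrainingConfig(),
    train_workers=4,
    val_workers=2,
)

# Resuming from a checkpoint instead (hypothetical path):
# train(train_examples=[...], model_path="checkpoints/last.ckpt")

The model, loss, optimizer, and post-processing are no longer passed in by the caller; they are all derived from the single FullTrainingConfig inside TrainingModule.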


def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
    return [
        ValidationMetrics(
            metrics=[
                DetectionAveragePrecision(),
                ClassificationMeanAveragePrecision(
                    class_names=targets.class_names
                ),
                ClassificationAccuracy(class_names=targets.class_names),
            ]
        )
    ]


def build_trainer(
    conf: FullTrainingConfig,
    targets: TargetProtocol,
) -> Trainer:
    trainer_conf = PLTrainerConfig.model_validate(
        conf.train.model_dump(mode="python")
    )
    return Trainer(
        **trainer_conf.model_dump(exclude_none=True),
        val_check_interval=conf.train.val_check_interval,
        logger=build_logger(conf.train.logger),
        callbacks=build_trainer_callbacks(targets),
    )
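The dump/validate round-trip in build_trainer is what filters the arguments: re-validating the full training section against the slimmer PLTrainerConfig schema keeps only the fields Lightning's Trainer understands, since pydantic models ignore unknown keys by default. A self-contained sketch of that mechanism with toy models (not the real config schemas):

from typing import Optional

from pydantic import BaseModel


class ToyTrainingConfig(BaseModel):
    max_epochs: Optional[int] = 100  # a Trainer argument
    batch_size: int = 32  # training-only; Trainer must not receive it


class ToyPLTrainerConfig(BaseModel):
    max_epochs: Optional[int] = None


conf = ToyTrainingConfig(max_epochs=200, batch_size=8)
trainer_conf = ToyPLTrainerConfig.model_validate(conf.model_dump(mode="python"))
assert trainer_conf.model_dump(exclude_none=True) == {"max_epochs": 200}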


def build_train_loader(
    train_examples: Sequence[data.PathLike],
    preprocessor: PreprocessorProtocol,
    config: TrainingConfig,
    num_workers: Optional[int] = None,
) -> DataLoader:
    train_dataset = build_train_dataset(
        train_examples,
        preprocessor=preprocessor,
        config=config,
    )

    return DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=num_workers or 0,
        collate_fn=collate_fn,
    )


def build_val_loader(
    val_examples: Sequence[data.PathLike],
    config: TrainingConfig,
    num_workers: Optional[int] = None,
):
    val_dataset = build_val_dataset(
        val_examples,
        config=config,
    )
    return DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=num_workers or 0,
        collate_fn=collate_fn,
    )


def build_train_dataset(
    examples: Sequence[data.PathLike],
    preprocessor: PreprocessorProtocol,
    config: Optional[TrainingConfig] = None,
) -> LabeledDataset:
@ -126,7 +160,7 @@ def build_train_dataset(
    clipper = build_clipper(config.cliping, random=True)

    random_example_source = RandomExampleSource(
        list(examples),
        clipper=clipper,
    )
@ -144,7 +178,7 @@ def build_train_dataset(
def build_val_dataset(
    examples: Sequence[data.PathLike],
    config: Optional[TrainingConfig] = None,
    train: bool = True,
) -> LabeledDataset:
10
tests/test_train/test_config.py
Normal file
@ -0,0 +1,10 @@
from batdetect2.configs import load_config
from batdetect2.train import FullTrainingConfig


def test_example_config_is_valid(example_data_dir):
    conf = load_config(
        example_data_dir / "config.yaml",
        schema=FullTrainingConfig,
    )
    assert isinstance(conf, FullTrainingConfig)
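load_config also accepts an optional field argument (used by load_train_preprocessing_config earlier in this change), which appears to let the same schema be read from a nested section of a larger YAML file. A sketch, with a hypothetical "training" key:

conf = load_config(
    example_data_dir / "config.yaml",
    schema=FullTrainingConfig,
    field="training",  # hypothetical nested key; omit to read the top level
)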
@ -5,27 +5,11 @@ import torch
import xarray as xr
from soundevent import data

from batdetect2.train import FullTrainingConfig, TrainingModule


def build_default_module():
    return TrainingModule(FullTrainingConfig())


def test_can_initialize_default_module():