Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-30 07:02:01 +02:00)

Compare commits: 23 commits, a462beaeb8 ... e8db1d4050

Commits:
e8db1d4050
b396d4908a
ce15a0f152
16febed792
d67ae9be05
15de168a20
0c8fae4a72
6d57f96c07
a0622aa9a4
587742b41e
6d91153a56
70c96b6844
22f7d46f46
1384c549f7
4b6acd5e6e
2ac968d65b
166dad20bd
cbb02cf69e
647468123e
363eb9fb2f
b21c224985
152a577511
136949c4e7
.dvc/.gitignore (vendored, new file, +3)

@@ -0,0 +1,3 @@
/config.local
/tmp
/cache
.dvc/config (new file, empty)
.dvcignore (new file, +3)

@@ -0,0 +1,3 @@
# Add patterns of files dvc should ignore, which could improve
# the performance. Learn more at
# https://dvc.org/doc/user-guide/dvcignore
.gitignore (vendored, +2)

@@ -114,3 +114,5 @@ experiments/*
 notebooks/lightning_logs
 example_data/preprocessed
 .aider*
+DvcLiveLogger/checkpoints
+logs
Makefile (deleted, -94)

@@ -1,94 +0,0 @@
# Variables
SOURCE_DIR = src
TESTS_DIR = tests
PYTHON_DIRS = src tests
DOCS_SOURCE = docs/source
DOCS_BUILD = docs/build
HTML_COVERAGE_DIR = htmlcov

# Default target (optional, often 'help' or 'all')
.DEFAULT_GOAL := help

# Phony targets (targets that don't produce a file with the same name)
.PHONY: help test coverage coverage-html coverage-serve docs docs-serve format format-check lint lint-fix typecheck check clean clean-pyc clean-test clean-docs clean-build

help:
	@echo "Makefile Targets:"
	@echo "  help            Show this help message."
	@echo "  test            Run tests using pytest."
	@echo "  coverage        Run tests and generate coverage data (.coverage, coverage.xml)."
	@echo "  coverage-html   Generate an HTML coverage report in $(HTML_COVERAGE_DIR)/."
	@echo "  coverage-serve  Serve the HTML coverage report locally."
	@echo "  docs            Build documentation using Sphinx."
	@echo "  docs-serve      Serve documentation with live reload using sphinx-autobuild."
	@echo "  format          Format code using ruff."
	@echo "  format-check    Check code formatting using ruff."
	@echo "  lint            Lint code using ruff."
	@echo "  lint-fix        Lint code using ruff and apply automatic fixes."
	@echo "  typecheck       Type check code using pyright."
	@echo "  check           Run all checks (format-check, lint, typecheck)."
	@echo "  clean           Remove all build, test, documentation, and Python artifacts."
	@echo "  clean-pyc       Remove Python bytecode and cache."
	@echo "  clean-test      Remove test and coverage artifacts."
	@echo "  clean-docs      Remove built documentation."
	@echo "  clean-build     Remove package build artifacts."

# Testing & Coverage
test:
	pytest tests

coverage:
	pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml tests

coverage-html: coverage
	@echo "Generating HTML coverage report..."
	coverage html -d $(HTML_COVERAGE_DIR)
	@echo "HTML coverage report generated in $(HTML_COVERAGE_DIR)/"

coverage-serve: coverage-html
	@echo "Serving report at http://localhost:8000/ ..."
	python -m http.server --directory $(HTML_COVERAGE_DIR) 8000

# Documentation
docs:
	sphinx-build -b html $(DOCS_SOURCE) $(DOCS_BUILD)

docs-serve:
	sphinx-autobuild $(DOCS_SOURCE) $(DOCS_BUILD) --watch $(SOURCE_DIR) --open-browser

# Formatting & Linting
format:
	ruff format $(PYTHON_DIRS)

format-check:
	ruff format --check $(PYTHON_DIRS)

lint:
	ruff check $(PYTHON_DIRS)

lint-fix:
	ruff check --fix $(PYTHON_DIRS)

# Type Checking
typecheck:
	pyright $(PYTHON_DIRS)

# Combined Checks
check: format-check lint typecheck test

# Cleaning tasks
clean-pyc:
	find . -type f -name "*.py[co]" -delete
	find . -type d -name "__pycache__" -exec rm -rf {} +

clean-test:
	rm -f .coverage coverage.xml
	rm -rf .pytest_cache htmlcov/

clean-docs:
	rm -rf $(DOCS_BUILD)

clean-build:
	rm -rf build/ dist/ *.egg-info/

clean: clean-build clean-pyc clean-test clean-docs
example_data/config.yaml (new file, +146)

@@ -0,0 +1,146 @@
targets:
  classes:
    classes:
      - name: myomys
        tags:
          - value: Myotis mystacinus
      - name: pippip
        tags:
          - value: Pipistrellus pipistrellus
      - name: eptser
        tags:
          - value: Eptesicus serotinus
      - name: rhifer
        tags:
          - value: Rhinolophus ferrumequinum
    generic_class:
      - key: class
        value: Bat

  filtering:
    rules:
      - match_type: all
        tags:
          - key: event
            value: Echolocation
      - match_type: exclude
        tags:
          - key: class
            value: Unknown

preprocess:
  audio:
    resample:
      samplerate: 256000
      method: "poly"
    scale: false
    center: true
    duration: null

  spectrogram:
    stft:
      window_duration: 0.002
      window_overlap: 0.75
      window_fn: hann
    frequencies:
      max_freq: 120000
      min_freq: 10000
    pcen:
      time_constant: 0.4
      gain: 0.98
      bias: 2
      power: 0.5
    scale: "amplitude"
    size:
      height: 128
      resize_factor: 0.5
    spectral_mean_substraction: true
    peak_normalize: false

postprocess:
  nms_kernel_size: 9
  detection_threshold: 0.01
  min_freq: 10000
  max_freq: 120000
  top_k_per_sec: 200

labels:
  sigma: 3

model:
  input_height: 128
  in_channels: 1
  out_channels: 32
  encoder:
    layers:
      - block_type: FreqCoordConvDown
        out_channels: 32
      - block_type: FreqCoordConvDown
        out_channels: 64
      - block_type: LayerGroup
        layers:
          - block_type: FreqCoordConvDown
            out_channels: 128
          - block_type: ConvBlock
            out_channels: 256
  bottleneck:
    channels: 256
    self_attention: true
  decoder:
    layers:
      - block_type: FreqCoordConvUp
        out_channels: 64
      - block_type: FreqCoordConvUp
        out_channels: 32
      - block_type: LayerGroup
        layers:
          - block_type: FreqCoordConvUp
            out_channels: 32
          - block_type: ConvBlock
            out_channels: 32

train:
  batch_size: 8
  learning_rate: 0.001
  t_max: 100
  loss:
    detection:
      weight: 1.0
      focal:
        beta: 4
        alpha: 2
    classification:
      weight: 2.0
      focal:
        beta: 4
        alpha: 2
    size:
      weight: 0.1
  logger:
    logger_type: dvclive
  augmentations:
    steps:
      - augmentation_type: mix_audio
        probability: 0.2
        min_weight: 0.3
        max_weight: 0.7
      - augmentation_type: add_echo
        probability: 0.2
        max_delay: 0.005
        min_weight: 0.0
        max_weight: 1.0
      - augmentation_type: scale_volume
        probability: 0.2
        min_scaling: 0.0
        max_scaling: 2.0
      - augmentation_type: warp
        probability: 0.2
        delta: 0.04
      - augmentation_type: mask_time
        probability: 0.2
        max_perc: 0.05
        max_masks: 3
      - augmentation_type: mask_freq
        probability: 0.2
        max_perc: 0.10
        max_masks: 3
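This single file is consumed by both the preprocessing and training entry points changed below. A minimal sketch of loading it (assuming the batdetect2 package from this changeset is installed):

from batdetect2.train import FullTrainingConfig, load_full_training_config

# Validate the unified example config. Sections used only by other tools
# (e.g. `labels`, read by the preprocessing schema) are tolerated now that
# BaseConfig ignores extra fields.
conf = load_full_training_config("example_data/config.yaml")
assert isinstance(conf, FullTrainingConfig)
print(conf.train.batch_size)    # 8, from the train: section
print(conf.model.input_height)  # 128, from the model: section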
justfile (new file, +110)

@@ -0,0 +1,110 @@
# Default command, runs if no recipe is specified.
default:
    just --list

# Variables
SOURCE_DIR := "src"
TESTS_DIR := "tests"
PYTHON_DIRS := "src tests"
DOCS_SOURCE := "docs/source"
DOCS_BUILD := "docs/build"
HTML_COVERAGE_DIR := "htmlcov"

# Show available commands
help:
    @just --list

# Testing & Coverage

# Run tests using pytest.
test:
    pytest {{TESTS_DIR}}

# Run tests and generate coverage data.
coverage:
    pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}}

# Generate an HTML coverage report.
coverage-html: coverage
    @echo "Generating HTML coverage report..."
    coverage html -d {{HTML_COVERAGE_DIR}}
    @echo "HTML coverage report generated in {{HTML_COVERAGE_DIR}}/"

# Serve the HTML coverage report locally.
coverage-serve: coverage-html
    @echo "Serving report at http://localhost:8000/ ..."
    python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000

# Documentation

# Build documentation using Sphinx.
docs:
    sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}

# Serve documentation with live reload.
docs-serve:
    sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser

# Formatting & Linting

# Format code using ruff.
format:
    ruff format {{PYTHON_DIRS}}

# Check code formatting using ruff.
format-check:
    ruff format --check {{PYTHON_DIRS}}

# Lint code using ruff.
lint:
    ruff check {{PYTHON_DIRS}}

# Lint code using ruff and apply automatic fixes.
lint-fix:
    ruff check --fix {{PYTHON_DIRS}}

# Type Checking

# Type check code using pyright.
typecheck:
    pyright {{PYTHON_DIRS}}

# Combined Checks

# Run all checks (format-check, lint, typecheck).
check: format-check lint typecheck test

# Cleaning tasks

# Remove Python bytecode and cache.
clean-pyc:
    find . -type f -name "*.py[co]" -delete
    find . -type d -name "__pycache__" -exec rm -rf {} +

# Remove test and coverage artifacts.
clean-test:
    rm -f .coverage coverage.xml
    rm -rf .pytest_cache htmlcov/

# Remove built documentation.
clean-docs:
    rm -rf {{DOCS_BUILD}}

# Remove package build artifacts.
clean-build:
    rm -rf build/ dist/ *.egg-info/

# Remove all build, test, documentation, and Python artifacts.
clean: clean-build clean-pyc clean-test clean-docs

# Examples

# Preprocess example data.
example-preprocess OPTIONS="":
    batdetect2 preprocess \
        --base-dir . \
        --dataset-field datasets.train \
        --config example_data/config.yaml \
        {{OPTIONS}} \
        example_data/datasets.yaml example_data/preprocessed

# Train on example data.
example-train OPTIONS="":
    batdetect2 train \
        --val-dir example_data/preprocessed \
        --config example_data/config.yaml \
        {{OPTIONS}} \
        example_data/preprocessed
pyproject.toml (+5)

@@ -85,6 +85,11 @@ dev = [
     "sphinx-book-theme>=1.1.4",
     "autodoc-pydantic>=2.2.0",
     "pytest-cov>=6.1.1",
+    "ty>=0.0.1a12",
+    "rust-just>=1.40.0",
 ]
+dvclive = [
+    "dvclive>=3.48.2",
+]
 
 [tool.ruff]
src/batdetect2/cli/__init__.py

@@ -2,13 +2,13 @@ from batdetect2.cli.base import cli
 from batdetect2.cli.compat import detect
 from batdetect2.cli.data import data
 from batdetect2.cli.preprocess import preprocess
-from batdetect2.cli.train import train
+from batdetect2.cli.train import train_command
 
 __all__ = [
     "cli",
     "detect",
     "data",
-    "train",
+    "train_command",
     "preprocess",
 ]
src/batdetect2/cli/preprocess.py

@@ -1,15 +1,18 @@
+import sys
 from pathlib import Path
 from typing import Optional
 
 import click
+import yaml
 from loguru import logger
 
 from batdetect2.cli.base import cli
 from batdetect2.data import load_dataset_from_config
-from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
-from batdetect2.targets import build_targets, load_target_config
-from batdetect2.train import load_label_config, preprocess_annotations
-from batdetect2.train.labels import build_clip_labeler
+from batdetect2.train.preprocess import (
+    TrainPreprocessConfig,
+    load_train_preprocessing_config,
+    preprocess_dataset,
+)
 
 __all__ = ["preprocess"]

@@ -44,16 +47,16 @@ __all__ = ["preprocess"]
     ),
 )
 @click.option(
-    "--preprocess-config",
+    "--config",
     type=click.Path(exists=True),
     help=(
-        "Path to the preprocessing configuration file. This file tells "
+        "Path to the configuration file. This file tells "
         "the program how to prepare your audio data before training, such "
         "as resampling or applying filters."
     ),
 )
 @click.option(
-    "--preprocess-config-field",
+    "--config-field",
     type=str,
     help=(
         "If the preprocessing settings are inside a nested dictionary "

@@ -62,41 +65,6 @@ __all__ = ["preprocess"]
         "top level, you don't need to specify this."
     ),
 )
-@click.option(
-    "--label-config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the label generation configuration file. This file "
-        "contains settings for how to create labels from your "
-        "annotations, which the model uses to learn."
-    ),
-)
-@click.option(
-    "--label-config-field",
-    type=str,
-    help=(
-        "If the label generation settings are inside a nested dictionary "
-        "within the label configuration file, specify the key here. If "
-        "the settings are at the top level, leave this blank."
-    ),
-)
-@click.option(
-    "--target-config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the training target configuration file. This file "
-        "specifies what sounds the model should learn to predict."
-    ),
-)
-@click.option(
-    "--target-config-field",
-    type=str,
-    help=(
-        "If the target settings are inside a nested dictionary "
-        "within the target configuration file, specify the key here. "
-        "If the settings are at the top level, you don't need to specify this."
-    ),
-)
 @click.option(
     "--force",
     is_flag=True,

@@ -117,20 +85,32 @@
         "the program will use all available cores."
     ),
 )
+@click.option(
+    "-v",
+    "--verbose",
+    count=True,
+    help="Increase verbosity. -v for INFO, -vv for DEBUG.",
+)
 def preprocess(
     dataset_config: Path,
     output: Path,
-    target_config: Optional[Path] = None,
     base_dir: Optional[Path] = None,
-    preprocess_config: Optional[Path] = None,
-    label_config: Optional[Path] = None,
+    config: Optional[Path] = None,
+    config_field: Optional[str] = None,
     force: bool = False,
     num_workers: Optional[int] = None,
-    target_config_field: Optional[str] = None,
-    preprocess_config_field: Optional[str] = None,
-    label_config_field: Optional[str] = None,
     dataset_field: Optional[str] = None,
+    verbose: int = 0,
 ):
+    logger.remove()
+    if verbose == 0:
+        log_level = "WARNING"
+    elif verbose == 1:
+        log_level = "INFO"
+    else:
+        log_level = "DEBUG"
+    logger.add(sys.stderr, level=log_level)
 
     logger.info("Starting preprocessing.")
 
     output = Path(output)

@@ -139,31 +119,19 @@ def preprocess(
     base_dir = base_dir or Path.cwd()
     logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
 
-    preprocess = (
-        load_preprocessing_config(
-            preprocess_config,
-            field=preprocess_config_field,
-        )
-        if preprocess_config
-        else None
-    )
-
-    target = (
-        load_target_config(
-            target_config,
-            field=target_config_field,
-        )
-        if target_config
-        else None
-    )
-
-    label = (
-        load_label_config(
-            label_config,
-            field=label_config_field,
-        )
-        if label_config
-        else None
-    )
+    if config:
+        logger.info(
+            "Loading preprocessing config from: {config}", config=config
+        )
+
+    conf = (
+        load_train_preprocessing_config(config, field=config_field)
+        if config is not None
+        else TrainPreprocessConfig()
+    )
+
+    logger.debug(
+        "Preprocessing config:\n{conf}",
+        conf=yaml.dump(conf.model_dump()),
+    )
 
     dataset = load_dataset_from_config(

@@ -177,20 +145,10 @@ def preprocess(
         num_examples=len(dataset),
     )
 
-    targets = build_targets(config=target)
-    preprocessor = build_preprocessor(config=preprocess)
-    labeller = build_clip_labeler(targets, config=label)
-
-    if not output.exists():
-        logger.debug("Creating directory {directory}", directory=output)
-        output.mkdir(parents=True)
-
-    logger.info("Will start preprocessing")
-    preprocess_annotations(
+    preprocess_dataset(
         dataset,
-        output_dir=output,
-        preprocessor=preprocessor,
-        labeller=labeller,
-        replace=force,
+        conf,
+        output=output,
+        force=force,
         max_workers=num_workers,
     )
src/batdetect2/cli/train.py

@@ -5,236 +5,53 @@ import click
 from loguru import logger
 
 from batdetect2.cli.base import cli
-from batdetect2.evaluate.metrics import (
-    ClassificationAccuracy,
-    ClassificationMeanAveragePrecision,
-    DetectionAveragePrecision,
-)
-from batdetect2.models import build_model
-from batdetect2.models.backbones import load_backbone_config
-from batdetect2.postprocess import build_postprocessor, load_postprocess_config
-from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
-from batdetect2.targets import build_targets, load_target_config
-from batdetect2.train import train
-from batdetect2.train.callbacks import ValidationMetrics
-from batdetect2.train.config import TrainingConfig, load_train_config
+from batdetect2.train import (
+    FullTrainingConfig,
+    load_full_training_config,
+    train,
+)
 from batdetect2.train.dataset import list_preprocessed_files
 
 __all__ = [
     "train_command",
 ]
 
-DEFAULT_CONFIG_FILE = Path("config.yaml")
-
 
 @cli.command(name="train")
-@click.option(
-    "--train-examples",
-    type=click.Path(exists=True),
-    required=True,
-)
-@click.option("--val-examples", type=click.Path(exists=True))
-@click.option(
-    "--model-path",
-    type=click.Path(exists=True),
-)
-@click.option(
-    "--train-config",
-    type=click.Path(exists=True),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--train-config-field",
-    type=str,
-    default="train",
-)
-@click.option(
-    "--preprocess-config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the preprocessing configuration file. This file tells "
-        "the program how to prepare your audio data before training, such "
-        "as resampling or applying filters."
-    ),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--preprocess-config-field",
-    type=str,
-    help=(
-        "If the preprocessing settings are inside a nested dictionary "
-        "within the preprocessing configuration file, specify the key "
-        "here to access them. If the preprocessing settings are at the "
-        "top level, you don't need to specify this."
-    ),
-    default="preprocess",
-)
-@click.option(
-    "--target-config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the training target configuration file. This file "
-        "specifies what sounds the model should learn to predict."
-    ),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--target-config-field",
-    type=str,
-    help=(
-        "If the target settings are inside a nested dictionary "
-        "within the target configuration file, specify the key here. "
-        "If the settings are at the top level, you don't need to specify this."
-    ),
-    default="targets",
-)
-@click.option(
-    "--postprocess-config",
-    type=click.Path(exists=True),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--postprocess-config-field",
-    type=str,
-    default="postprocess",
-)
-@click.option(
-    "--model-config",
-    type=click.Path(exists=True),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--model-config-field",
-    type=str,
-    default="model",
-)
-@click.option(
-    "--train-workers",
-    type=int,
-    default=0,
-)
-@click.option(
-    "--val-workers",
-    type=int,
-    default=0,
-)
+@click.argument("train_dir", type=click.Path(exists=True))
+@click.option("--val-dir", type=click.Path(exists=True))
+@click.option("--model-path", type=click.Path(exists=True))
+@click.option("--config", type=click.Path(exists=True))
+@click.option("--config-field", type=str)
+@click.option("--train-workers", type=int, default=0)
+@click.option("--val-workers", type=int, default=0)
 def train_command(
-    train_examples: Path,
-    val_examples: Optional[Path] = None,
+    train_dir: Path,
+    val_dir: Optional[Path] = None,
     model_path: Optional[Path] = None,
-    train_config: Path = DEFAULT_CONFIG_FILE,
-    train_config_field: str = "train",
-    preprocess_config: Path = DEFAULT_CONFIG_FILE,
-    preprocess_config_field: str = "preprocess",
-    target_config: Path = DEFAULT_CONFIG_FILE,
-    target_config_field: str = "targets",
-    postprocess_config: Path = DEFAULT_CONFIG_FILE,
-    postprocess_config_field: str = "postprocess",
-    model_config: Path = DEFAULT_CONFIG_FILE,
-    model_config_field: str = "model",
+    config: Optional[Path] = None,
+    config_field: Optional[str] = None,
     train_workers: int = 0,
     val_workers: int = 0,
 ):
     logger.info("Starting training!")
 
-    try:
-        target_config_loaded = load_target_config(
-            path=target_config,
-            field=target_config_field,
-        )
-        targets = build_targets(config=target_config_loaded)
-        logger.debug(
-            "Loaded targets info from config file {path}", path=target_config
-        )
-    except IOError:
-        logger.debug(
-            "Could not load target info from config file, using default"
-        )
-        targets = build_targets()
-
-    try:
-        preprocess_config_loaded = load_preprocessing_config(
-            path=preprocess_config,
-            field=preprocess_config_field,
-        )
-        preprocessor = build_preprocessor(preprocess_config_loaded)
-        logger.debug(
-            "Loaded preprocessor from config file {path}", path=target_config
-        )
-
-    except IOError:
-        logger.debug(
-            "Could not load preprocessor from config file, using default"
-        )
-        preprocessor = build_preprocessor()
-
-    try:
-        model_config_loaded = load_backbone_config(
-            path=model_config, field=model_config_field
-        )
-        model = build_model(
-            num_classes=len(targets.class_names),
-            config=model_config_loaded,
-        )
-    except IOError:
-        model = build_model(num_classes=len(targets.class_names))
-
-    try:
-        postprocess_config_loaded = load_postprocess_config(
-            path=postprocess_config,
-            field=postprocess_config_field,
-        )
-        postprocessor = build_postprocessor(
-            targets=targets,
-            config=postprocess_config_loaded,
-        )
-        logger.debug(
-            "Loaded postprocessor from file {path}", path=postprocess_config
-        )
-    except IOError:
-        logger.debug(
-            "Could not load postprocessor config from file. Using default"
-        )
-        postprocessor = build_postprocessor(targets=targets)
-
-    try:
-        train_config_loaded = load_train_config(
-            path=train_config, field=train_config_field
-        )
-        logger.debug(
-            "Loaded training config from file {path}",
-            path=train_config,
-        )
-    except IOError:
-        train_config_loaded = TrainingConfig()
-        logger.debug("Could not load training config from file. Using default")
-
-    train_files = list_preprocessed_files(train_examples)
-
-    val_files = (
-        None if val_examples is None else list_preprocessed_files(val_examples)
-    )
-
-    return train(
-        detector=model,
-        train_examples=train_files,  # type: ignore
-        val_examples=val_files,  # type: ignore
+    conf = (
+        load_full_training_config(config, field=config_field)
+        if config is not None
+        else FullTrainingConfig()
+    )
+
+    train_examples = list_preprocessed_files(train_dir)
+    val_examples = (
+        list_preprocessed_files(val_dir) if val_dir is not None else None
+    )
+
+    train(
+        train_examples=train_examples,
+        val_examples=val_examples,
+        config=conf,
         model_path=model_path,
-        preprocessor=preprocessor,
-        postprocessor=postprocessor,
-        targets=targets,
-        config=train_config_loaded,
-        callbacks=[
-            ValidationMetrics(
-                metrics=[
-                    DetectionAveragePrecision(),
-                    ClassificationMeanAveragePrecision(
-                        class_names=targets.class_names,
-                    ),
-                    ClassificationAccuracy(class_names=targets.class_names),
-                ]
-            )
-        ],
         train_workers=train_workers,
         val_workers=val_workers,
     )
src/batdetect2/configs.py

@@ -25,20 +25,9 @@ class BaseConfig(BaseModel):
 
     Inherits from Pydantic's `BaseModel` to provide data validation, parsing,
     and serialization capabilities.
-
-    It sets `extra='forbid'` in its model configuration, meaning that any
-    fields provided in a configuration file that are *not* explicitly defined
-    in the specific configuration schema will raise a validation error. This
-    helps catch typos and ensures configurations strictly adhere to the expected
-    structure.
-
-    Attributes
-    ----------
-    model_config : ConfigDict
-        Pydantic model configuration dictionary. Set to forbid extra fields.
     """
 
-    model_config = ConfigDict(extra="forbid")
+    model_config = ConfigDict(extra="ignore")
 
 
 T = TypeVar("T", bound=BaseModel)
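This switch from extra="forbid" to extra="ignore" is what lets the one combined config.yaml above validate against several partial schemas. A standalone pydantic sketch of the difference (illustrative only, not batdetect2 code):

from pydantic import BaseModel, ConfigDict, ValidationError

class Strict(BaseModel):
    model_config = ConfigDict(extra="forbid")
    batch_size: int = 8

class Lenient(BaseModel):
    model_config = ConfigDict(extra="ignore")
    batch_size: int = 8

payload = {"batch_size": 4, "model": {"input_height": 128}}

try:
    Strict.model_validate(payload)
except ValidationError:
    print("extra='forbid' rejects the unknown 'model' key")

# extra='ignore' silently drops the unknown key instead
print(Lenient.model_validate(payload))  # batch_size=4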
src/batdetect2/models/backbones.py

@@ -27,9 +27,23 @@ from torch import nn
 
 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.models.blocks import ConvBlock
-from batdetect2.models.bottleneck import BottleneckConfig, build_bottleneck
-from batdetect2.models.decoder import Decoder, DecoderConfig, build_decoder
-from batdetect2.models.encoder import Encoder, EncoderConfig, build_encoder
+from batdetect2.models.bottleneck import (
+    DEFAULT_BOTTLENECK_CONFIG,
+    BottleneckConfig,
+    build_bottleneck,
+)
+from batdetect2.models.decoder import (
+    DEFAULT_DECODER_CONFIG,
+    Decoder,
+    DecoderConfig,
+    build_decoder,
+)
+from batdetect2.models.encoder import (
+    DEFAULT_ENCODER_CONFIG,
+    Encoder,
+    EncoderConfig,
+    build_encoder,
+)
 from batdetect2.models.types import BackboneModel
 
 __all__ = [

@@ -186,9 +200,9 @@ class BackboneConfig(BaseConfig):
 
     input_height: int = 128
     in_channels: int = 1
-    encoder: Optional[EncoderConfig] = None
-    bottleneck: Optional[BottleneckConfig] = None
-    decoder: Optional[DecoderConfig] = None
+    encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
+    bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
+    decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
     out_channels: int = 32
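With the Optional fields replaced by concrete defaults, a bare BackboneConfig() now describes a complete architecture. A small sketch (assuming the package is importable):

from batdetect2.models import BackboneConfig

cfg = BackboneConfig()
# encoder, bottleneck and decoder are full config objects now, never None
print(type(cfg.encoder).__name__)     # EncoderConfig
print(type(cfg.bottleneck).__name__)  # BottleneckConfig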
src/batdetect2/models/blocks.py

@@ -38,7 +38,7 @@ from batdetect2.configs import BaseConfig
 
 __all__ = [
     "ConvBlock",
-    "BlockGroupConfig",
+    "LayerGroupConfig",
     "VerticalConv",
     "FreqCoordConvDownBlock",
     "StandardConvDownBlock",

@@ -654,16 +654,16 @@ LayerConfig = Annotated[
         StandardConvDownConfig,
         FreqCoordConvUpConfig,
         StandardConvUpConfig,
-        "BlockGroupConfig",
+        "LayerGroupConfig",
     ],
     Field(discriminator="block_type"),
 ]
 """Type alias for the discriminated union of block configuration models."""
 
 
-class BlockGroupConfig(BaseConfig):
-    block_type: Literal["group"] = "group"
-    blocks: List[LayerConfig]
+class LayerGroupConfig(BaseConfig):
+    block_type: Literal["LayerGroup"] = "LayerGroup"
+    layers: List[LayerConfig]
 
 
 def build_layer_from_config(

@@ -769,13 +769,13 @@ def build_layer_from_config(
             input_height * 2,
         )
 
-    if config.block_type == "group":
+    if config.block_type == "LayerGroup":
         current_channels = in_channels
         current_height = input_height
 
         blocks = []
 
-        for block_config in config.blocks:
+        for block_config in config.layers:
            block, current_channels, current_height = build_layer_from_config(
                input_height=current_height,
                in_channels=current_channels,
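The renamed discriminator in use, mirroring the LayerGroup entries of example_data/config.yaml above (a sketch; the config classes are assumed importable from batdetect2.models.blocks, as the encoder and decoder diffs below suggest):

from batdetect2.models.blocks import (
    ConvConfig,
    FreqCoordConvDownConfig,
    LayerGroupConfig,
)

group = LayerGroupConfig(
    layers=[
        FreqCoordConvDownConfig(out_channels=128),
        ConvConfig(out_channels=256),
    ]
)
print(group.block_type)  # "LayerGroup", the new discriminator value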
src/batdetect2/models/decoder.py

@@ -26,7 +26,7 @@ from torch import nn
 
 from batdetect2.configs import BaseConfig
 from batdetect2.models.blocks import (
-    BlockGroupConfig,
+    LayerGroupConfig,
     ConvConfig,
     FreqCoordConvUpConfig,
     StandardConvUpConfig,

@@ -45,7 +45,7 @@ DecoderLayerConfig = Annotated[
         ConvConfig,
         FreqCoordConvUpConfig,
         StandardConvUpConfig,
-        BlockGroupConfig,
+        LayerGroupConfig,
     ],
     Field(discriminator="block_type"),
 ]

@@ -197,8 +197,8 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
     layers=[
         FreqCoordConvUpConfig(out_channels=64),
         FreqCoordConvUpConfig(out_channels=32),
-        BlockGroupConfig(
-            blocks=[
+        LayerGroupConfig(
+            layers=[
                 FreqCoordConvUpConfig(out_channels=32),
                 ConvConfig(out_channels=32),
             ]
src/batdetect2/models/encoder.py

@@ -28,7 +28,7 @@ from torch import nn
 
 from batdetect2.configs import BaseConfig
 from batdetect2.models.blocks import (
-    BlockGroupConfig,
+    LayerGroupConfig,
     ConvConfig,
     FreqCoordConvDownConfig,
     StandardConvDownConfig,

@@ -47,7 +47,7 @@ EncoderLayerConfig = Annotated[
         ConvConfig,
         FreqCoordConvDownConfig,
         StandardConvDownConfig,
-        BlockGroupConfig,
+        LayerGroupConfig,
     ],
     Field(discriminator="block_type"),
 ]

@@ -230,8 +230,8 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
     layers=[
         FreqCoordConvDownConfig(out_channels=32),
         FreqCoordConvDownConfig(out_channels=64),
-        BlockGroupConfig(
-            blocks=[
+        LayerGroupConfig(
+            layers=[
                 FreqCoordConvDownConfig(out_channels=128),
                 ConvConfig(out_channels=256),
             ]
src/batdetect2/train/__init__.py

@@ -15,8 +15,10 @@ from batdetect2.train.augmentations import (
 )
 from batdetect2.train.clips import build_clipper, select_subclip
 from batdetect2.train.config import (
-    TrainerConfig,
+    FullTrainingConfig,
+    PLTrainerConfig,
     TrainingConfig,
+    load_full_training_config,
     load_train_config,
 )
 from batdetect2.train.dataset import (

@@ -26,6 +28,7 @@ from batdetect2.train.dataset import (
     list_preprocessed_files,
 )
 from batdetect2.train.labels import build_clip_labeler, load_label_config
+from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.losses import (
     ClassificationLossConfig,
     DetectionLossConfig,

@@ -40,7 +43,10 @@ from batdetect2.train.preprocess import (
 )
 from batdetect2.train.train import (
     build_train_dataset,
+    build_train_loader,
+    build_trainer,
     build_val_dataset,
+    build_val_loader,
     train,
 )
 

@@ -50,15 +56,17 @@ __all__ = [
     "DetectionLossConfig",
     "EchoAugmentationConfig",
     "FrequencyMaskAugmentationConfig",
+    "FullTrainingConfig",
     "LabeledDataset",
     "LossConfig",
     "LossFunction",
+    "PLTrainerConfig",
     "RandomExampleSource",
     "SizeLossConfig",
     "TimeMaskAugmentationConfig",
     "TrainExample",
-    "TrainerConfig",
     "TrainingConfig",
+    "TrainingModule",
     "VolumeAugmentationConfig",
     "WarpAugmentationConfig",
     "add_echo",

@@ -67,9 +75,13 @@ __all__ = [
     "build_clipper",
     "build_loss",
     "build_train_dataset",
+    "build_train_loader",
+    "build_trainer",
     "build_val_dataset",
+    "build_val_loader",
     "generate_train_example",
     "list_preprocessed_files",
+    "load_full_training_config",
     "load_label_config",
     "load_train_config",
     "mask_frequency",

@@ -79,6 +91,5 @@ __all__ = [
     "scale_volume",
     "select_subclip",
     "train",
-    "train",
     "warp_spectrogram",
 ]
src/batdetect2/train/config.py

@@ -4,26 +4,27 @@ from pydantic import Field
 from soundevent import data
 
 from batdetect2.configs import BaseConfig, load_config
+from batdetect2.models import BackboneConfig
+from batdetect2.postprocess import PostprocessConfig
+from batdetect2.preprocess import PreprocessingConfig
+from batdetect2.targets import TargetConfig
 from batdetect2.train.augmentations import (
     DEFAULT_AUGMENTATION_CONFIG,
     AugmentationsConfig,
 )
 from batdetect2.train.clips import ClipingConfig
+from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
 from batdetect2.train.losses import LossConfig
 
 __all__ = [
     "OptimizerConfig",
     "TrainingConfig",
     "load_train_config",
+    "FullTrainingConfig",
+    "load_full_training_config",
 ]
 
 
 class OptimizerConfig(BaseConfig):
     learning_rate: float = 1e-3
     t_max: int = 100
 
 
-class TrainerConfig(BaseConfig):
+class PLTrainerConfig(BaseConfig):
     accelerator: str = "auto"
     accumulate_grad_batches: int = 1
     deterministic: bool = True

@@ -44,20 +45,17 @@ class TrainerConfig(BaseConfig):
     val_check_interval: Optional[Union[int, float]] = None
 
 
-class TrainingConfig(BaseConfig):
+class TrainingConfig(PLTrainerConfig):
     batch_size: int = 8
-
+    learning_rate: float = 1e-3
+    t_max: int = 100
     loss: LossConfig = Field(default_factory=LossConfig)
 
-    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
-
     augmentations: AugmentationsConfig = Field(
         default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
     )
 
     cliping: ClipingConfig = Field(default_factory=ClipingConfig)
 
-    trainer: TrainerConfig = Field(default_factory=TrainerConfig)
+    trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
+    logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
 
 
 def load_train_config(

@@ -65,3 +63,23 @@ def load_train_config(
     field: Optional[str] = None,
 ) -> TrainingConfig:
     return load_config(path, schema=TrainingConfig, field=field)
+
+
+class FullTrainingConfig(BaseConfig):
+    """Full training configuration."""
+
+    train: TrainingConfig = Field(default_factory=TrainingConfig)
+    targets: TargetConfig = Field(default_factory=TargetConfig)
+    model: BackboneConfig = Field(default_factory=BackboneConfig)
+    preprocess: PreprocessingConfig = Field(
+        default_factory=PreprocessingConfig
+    )
+    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
+
+
+def load_full_training_config(
+    path: data.PathLike,
+    field: Optional[str] = None,
+) -> FullTrainingConfig:
+    """Load the full training configuration."""
+    return load_config(path, schema=FullTrainingConfig, field=field)
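Because TrainingConfig now inherits from PLTrainerConfig, Lightning Trainer flags sit directly on the train section, and FullTrainingConfig stitches every sub-config together. A sketch of the defaults:

from batdetect2.train.config import FullTrainingConfig

conf = FullTrainingConfig()
print(conf.train.accelerator)         # "auto", inherited from PLTrainerConfig
print(conf.train.batch_size)          # 8
print(conf.train.logger.logger_type)  # "csv", the default logger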
src/batdetect2/train/lightning.py

@@ -1,17 +1,16 @@
 import lightning as L
 import torch
+from pydantic import BaseModel
 from torch.optim.adam import Adam
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
-from batdetect2.models import (
-    DetectionModel,
-    ModelOutput,
-)
-from batdetect2.postprocess.types import PostprocessorProtocol
-from batdetect2.preprocess.types import PreprocessorProtocol
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.models import ModelOutput, build_model
+from batdetect2.postprocess import build_postprocessor
+from batdetect2.preprocess import build_preprocessor
+from batdetect2.targets import build_targets
 from batdetect2.train import TrainExample
-from batdetect2.train.types import LossProtocol
+from batdetect2.train.config import FullTrainingConfig
+from batdetect2.train.losses import build_loss
 
 __all__ = [
     "TrainingModule",

@@ -19,28 +18,29 @@ __all__ = [
 
 
 class TrainingModule(L.LightningModule):
-    def __init__(
-        self,
-        detector: DetectionModel,
-        loss: LossProtocol,
-        targets: TargetProtocol,
-        preprocessor: PreprocessorProtocol,
-        postprocessor: PostprocessorProtocol,
-        learning_rate: float = 0.001,
-        t_max: int = 100,
-    ):
+    def __init__(self, config: FullTrainingConfig):
         super().__init__()
 
-        self.loss = loss
-        self.detector = detector
-        self.preprocessor = preprocessor
-        self.targets = targets
-        self.postprocessor = postprocessor
-
-        self.learning_rate = learning_rate
-        self.t_max = t_max
-
-        self.save_hyperparameters()
+        # NOTE: Need to convert to vanilla python object so that DVCLive can
+        # store it.
+        self._config = (
+            config.model_dump() if isinstance(config, BaseModel) else config
+        )
+        self.save_hyperparameters({"config": self._config})
+
+        self.config = FullTrainingConfig.model_validate(self._config)
+        self.loss = build_loss(self.config.train.loss)
+        self.targets = build_targets(self.config.targets)
+        self.detector = build_model(
+            num_classes=len(self.targets.class_names),
+            config=self.config.model,
+        )
+        self.preprocessor = build_preprocessor(self.config.preprocess)
+        self.postprocessor = build_postprocessor(
+            self.targets,
+            min_freq=self.preprocessor.min_freq,
+            max_freq=self.preprocessor.max_freq,
+        )
 
     def forward(self, spec: torch.Tensor) -> ModelOutput:
         return self.detector(spec)

@@ -70,6 +70,6 @@ class TrainingModule(L.LightningModule):
         return outputs
 
     def configure_optimizers(self):
-        optimizer = Adam(self.parameters(), lr=self.learning_rate)
-        scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
+        optimizer = Adam(self.parameters(), lr=self.config.train.learning_rate)
+        scheduler = CosineAnnealingLR(optimizer, T_max=self.config.train.t_max)
         return [optimizer], [scheduler]
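Since the module now rebuilds the loss, targets, model and pre/postprocessors from a single config, construction takes one argument and checkpoints carry the whole recipe. A sketch:

from batdetect2.train.config import FullTrainingConfig
from batdetect2.train.lightning import TrainingModule

# The config is dumped to plain python objects before being saved as a
# hyperparameter, which is what lets DVCLive serialize checkpoints.
module = TrainingModule(FullTrainingConfig())
print(module.config.train.learning_rate)  # 0.001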
src/batdetect2/train/logging.py (new file, +103)

@@ -0,0 +1,103 @@
from typing import Annotated, Literal, Optional, Union

from lightning.pytorch.loggers import Logger
from pydantic import Field

from batdetect2.configs import BaseConfig

DEFAULT_LOGS_DIR: str = "logs"


class DVCLiveConfig(BaseConfig):
    logger_type: Literal["dvclive"] = "dvclive"
    dir: str = DEFAULT_LOGS_DIR
    run_name: Optional[str] = None
    prefix: str = ""
    log_model: Union[bool, Literal["all"]] = False
    monitor_system: bool = False


class CSVLoggerConfig(BaseConfig):
    logger_type: Literal["csv"] = "csv"
    save_dir: str = DEFAULT_LOGS_DIR
    name: Optional[str] = "logs"
    version: Optional[str] = None
    flush_logs_every_n_steps: int = 100


class TensorBoardLoggerConfig(BaseConfig):
    logger_type: Literal["tensorboard"] = "tensorboard"
    save_dir: str = DEFAULT_LOGS_DIR
    name: Optional[str] = "default"
    version: Optional[str] = None
    log_graph: bool = False
    flush_logs_every_n_steps: Optional[int] = None


LoggerConfig = Annotated[
    Union[DVCLiveConfig, CSVLoggerConfig, TensorBoardLoggerConfig],
    Field(discriminator="logger_type"),
]


def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
    try:
        from dvclive.lightning import DVCLiveLogger  # type: ignore
    except ImportError as error:
        raise ValueError(
            "DVCLive is not installed and cannot be used for logging. "
            "Make sure you have it installed by running `pip install dvclive` "
            "or `uv add dvclive`"
        ) from error

    return DVCLiveLogger(
        dir=config.dir,
        run_name=config.run_name,
        prefix=config.prefix,
        log_model=config.log_model,
        monitor_system=config.monitor_system,
    )


def create_csv_logger(config: CSVLoggerConfig) -> Logger:
    from lightning.pytorch.loggers import CSVLogger

    return CSVLogger(
        save_dir=config.save_dir,
        name=config.name,
        version=config.version,
        flush_logs_every_n_steps=config.flush_logs_every_n_steps,
    )


def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
    from lightning.pytorch.loggers import TensorBoardLogger

    return TensorBoardLogger(
        save_dir=config.save_dir,
        name=config.name,
        version=config.version,
        log_graph=config.log_graph,
        flush_logs_every_n_steps=config.flush_logs_every_n_steps,
    )


LOGGER_FACTORY = {
    "dvclive": create_dvclive_logger,
    "csv": create_csv_logger,
    "tensorboard": create_tensorboard_logger,
}


def build_logger(config: LoggerConfig) -> Logger:
    """
    Creates a logger instance from a validated Pydantic config object.
    """
    logger_type = config.logger_type

    if logger_type not in LOGGER_FACTORY:
        raise ValueError(f"Unknown logger type: {logger_type}")

    creation_func = LOGGER_FACTORY[logger_type]

    return creation_func(config)
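A sketch of selecting a logger through the new discriminated union (lightning must be installed; the dvclive variant additionally needs the dvclive package):

from batdetect2.train.logging import CSVLoggerConfig, build_logger

logger = build_logger(CSVLoggerConfig(save_dir="logs"))
print(type(logger).__name__)  # CSVLogger

In YAML the same choice is made with the logger_type key under train.logger, as in the example config above (logger_type: dvclive).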
src/batdetect2/train/preprocess.py

@@ -27,22 +27,71 @@ from typing import Callable, Optional, Sequence
 
 import xarray as xr
 from loguru import logger
+from pydantic import Field
 from soundevent import data
 from tqdm.auto import tqdm
 
+from batdetect2.configs import BaseConfig, load_config
+from batdetect2.data.datasets import Dataset
+from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
 from batdetect2.preprocess.types import PreprocessorProtocol
+from batdetect2.targets import TargetConfig, build_targets
+from batdetect2.train.labels import LabelConfig, build_clip_labeler
 from batdetect2.train.types import ClipLabeller
 
 __all__ = [
     "preprocess_annotations",
     "preprocess_single_annotation",
     "generate_train_example",
+    "preprocess_dataset",
+    "TrainPreprocessConfig",
+    "load_train_preprocessing_config",
 ]
 
 FilenameFn = Callable[[data.ClipAnnotation], str]
 """Type alias for a function that generates an output filename."""
 
 
+class TrainPreprocessConfig(BaseConfig):
+    preprocess: PreprocessingConfig = Field(
+        default_factory=PreprocessingConfig
+    )
+    targets: TargetConfig = Field(default_factory=TargetConfig)
+    labels: LabelConfig = Field(default_factory=LabelConfig)
+
+
+def load_train_preprocessing_config(
+    path: data.PathLike,
+    field: Optional[str] = None,
+) -> TrainPreprocessConfig:
+    return load_config(path=path, schema=TrainPreprocessConfig, field=field)
+
+
+def preprocess_dataset(
+    dataset: Dataset,
+    config: TrainPreprocessConfig,
+    output: Path,
+    force: bool = False,
+    max_workers: Optional[int] = None,
+) -> None:
+    targets = build_targets(config=config.targets)
+    preprocessor = build_preprocessor(config=config.preprocess)
+    labeller = build_clip_labeler(targets, config=config.labels)
+
+    if not output.exists():
+        logger.debug("Creating directory {directory}", directory=output)
+        output.mkdir(parents=True)
+
+    preprocess_annotations(
+        dataset,
+        output_dir=output,
+        preprocessor=preprocessor,
+        labeller=labeller,
+        replace=force,
+        max_workers=max_workers,
+    )
+
+
 def generate_train_example(
     clip_annotation: data.ClipAnnotation,
     preprocessor: PreprocessorProtocol,

@@ -207,8 +256,17 @@ def preprocess_annotations(
     output_dir = Path(output_dir)
 
     if not output_dir.is_dir():
+        logger.info(
+            "Creating output directory: {output_dir}", output_dir=output_dir
+        )
         output_dir.mkdir(parents=True)
 
+    logger.info(
+        "Starting preprocessing of {num_annotations} annotations with {max_workers} workers.",
+        num_annotations=len(clip_annotations),
+        max_workers=max_workers or "all available",
+    )
+
     with Pool(max_workers) as pool:
         list(
             tqdm(

@@ -224,8 +282,10 @@ def preprocess_annotations(
                 clip_annotations,
             ),
             total=len(clip_annotations),
+            desc="Preprocessing annotations",
         )
     )
+    logger.info("Finished preprocessing.")
 
 
 def preprocess_single_annotation(

@@ -264,11 +324,15 @@ def preprocess_single_annotation(
     path = output_dir / filename
 
     if path.is_file() and not replace:
+        logger.debug("Skipping existing file: {path}", path=path)
         return
 
     if path.is_file() and replace:
+        logger.debug("Removing existing file: {path}", path=path)
         path.unlink()
 
+    logger.debug("Processing annotation {uuid}", uuid=clip_annotation.uuid)
+
     try:
         sample = generate_train_example(
             clip_annotation,

@@ -277,8 +341,9 @@ def preprocess_single_annotation(
         )
     except Exception as error:
         logger.error(
-            "Failed to process annotation: {uuid}. Error {error}",
+            "Failed to process annotation {uuid} to {path}. Error: {error}",
             uuid=clip_annotation.uuid,
+            path=path,
             error=error,
         )
         return
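A sketch of the one-call preprocessing flow that the CLI now wraps. Paths match the example data; the keyword names passed to load_dataset_from_config follow the CLI flags (--dataset-field, --base-dir) and should be treated as assumptions:

from pathlib import Path

from batdetect2.data import load_dataset_from_config
from batdetect2.train.preprocess import (
    load_train_preprocessing_config,
    preprocess_dataset,
)

conf = load_train_preprocessing_config("example_data/config.yaml")
# Assumed keyword names, mirroring the CLI's --dataset-field/--base-dir
dataset = load_dataset_from_config(
    "example_data/datasets.yaml",
    field="datasets.train",
    base_dir=Path.cwd(),
)
preprocess_dataset(
    dataset,
    conf,
    output=Path("example_data/preprocessed"),
    force=False,
    max_workers=4,
)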
src/batdetect2/train/train.py

@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from typing import List, Optional
 
 from lightning import Trainer

@@ -5,109 +6,74 @@ from lightning.pytorch.callbacks import Callback
 from soundevent import data
 from torch.utils.data import DataLoader
 
-from batdetect2.models.types import DetectionModel
-from batdetect2.postprocess import build_postprocessor
-from batdetect2.postprocess.types import PostprocessorProtocol
-from batdetect2.preprocess import build_preprocessor
-from batdetect2.preprocess.types import PreprocessorProtocol
-from batdetect2.targets import build_targets
-from batdetect2.targets.types import TargetProtocol
-from batdetect2.train.augmentations import (
-    build_augmentations,
-)
+from batdetect2.evaluate.metrics import (
+    ClassificationAccuracy,
+    ClassificationMeanAveragePrecision,
+    DetectionAveragePrecision,
+)
+from batdetect2.preprocess import (
+    PreprocessorProtocol,
+)
+from batdetect2.targets import TargetProtocol
+from batdetect2.train.augmentations import build_augmentations
+from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
-from batdetect2.train.config import TrainingConfig
+from batdetect2.train.config import (
+    FullTrainingConfig,
+    PLTrainerConfig,
+    TrainingConfig,
+)
 from batdetect2.train.dataset import (
     LabeledDataset,
     RandomExampleSource,
     collate_fn,
 )
 from batdetect2.train.lightning import TrainingModule
-from batdetect2.train.losses import build_loss
+from batdetect2.train.logging import build_logger
 
 __all__ = [
-    "train",
-    "build_val_dataset",
     "build_train_dataset",
+    "build_train_loader",
+    "build_trainer",
+    "build_val_dataset",
+    "build_val_loader",
+    "train",
 ]
 
 
 def train(
-    detector: DetectionModel,
-    train_examples: List[data.PathLike],
-    targets: Optional[TargetProtocol] = None,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    postprocessor: Optional[PostprocessorProtocol] = None,
-    val_examples: Optional[List[data.PathLike]] = None,
-    config: Optional[TrainingConfig] = None,
-    callbacks: Optional[List[Callback]] = None,
+    train_examples: Sequence[data.PathLike],
+    val_examples: Optional[Sequence[data.PathLike]] = None,
+    config: Optional[FullTrainingConfig] = None,
     model_path: Optional[data.PathLike] = None,
     train_workers: int = 0,
     val_workers: int = 0,
-    **trainer_kwargs,
-) -> None:
-    config = config or TrainingConfig()
-    if model_path is None:
-        if preprocessor is None:
-            preprocessor = build_preprocessor()
-
-        if targets is None:
-            targets = build_targets()
-
-        if postprocessor is None:
-            postprocessor = build_postprocessor(
-                targets,
-                min_freq=preprocessor.min_freq,
-                max_freq=preprocessor.max_freq,
-            )
-
-        loss = build_loss(config.loss)
-
-        module = TrainingModule(
-            detector=detector,
-            loss=loss,
-            targets=targets,
-            preprocessor=preprocessor,
-            postprocessor=postprocessor,
-            learning_rate=config.optimizer.learning_rate,
-            t_max=config.optimizer.t_max,
-        )
-    else:
+):
+    conf = config or FullTrainingConfig()
+
+    if model_path is not None:
         module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
+    else:
+        module = TrainingModule(conf)
 
-    train_dataset = build_train_dataset(
+    trainer = build_trainer(conf, targets=module.targets)
+
+    train_dataloader = build_train_loader(
         train_examples,
         preprocessor=module.preprocessor,
-        config=config,
-    )
-
-    trainer = Trainer(
-        **config.trainer.model_dump(exclude_none=True),
-        callbacks=callbacks,
-        **trainer_kwargs,
-    )
-
-    train_dataloader = DataLoader(
-        train_dataset,
-        batch_size=config.batch_size,
-        shuffle=True,
+        config=conf.train,
         num_workers=train_workers,
-        collate_fn=collate_fn,
     )
 
-    val_dataloader = None
-    if val_examples:
-        val_dataset = build_val_dataset(
+    val_dataloader = (
+        build_val_loader(
             val_examples,
-            config=config,
-        )
-        val_dataloader = DataLoader(
-            val_dataset,
-            batch_size=config.batch_size,
-            shuffle=False,
+            config=conf.train,
             num_workers=val_workers,
-            collate_fn=collate_fn,
         )
+        if val_examples is not None
+        else None
+    )
 
     trainer.fit(
         module,

@@ -116,8 +82,76 @@ def train(
     )
 
 
+def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
+    return [
+        ValidationMetrics(
+            metrics=[
+                DetectionAveragePrecision(),
+                ClassificationMeanAveragePrecision(
+                    class_names=targets.class_names
+                ),
+                ClassificationAccuracy(class_names=targets.class_names),
+            ]
+        )
+    ]
+
+
+def build_trainer(
+    conf: FullTrainingConfig,
+    targets: TargetProtocol,
+) -> Trainer:
+    trainer_conf = PLTrainerConfig.model_validate(
+        conf.train.model_dump(mode="python")
+    )
+    return Trainer(
+        **trainer_conf.model_dump(exclude_none=True),
+        val_check_interval=conf.train.val_check_interval,
+        logger=build_logger(conf.train.logger),
+        callbacks=build_trainer_callbacks(targets),
+    )
+
+
+def build_train_loader(
+    train_examples: Sequence[data.PathLike],
+    preprocessor: PreprocessorProtocol,
+    config: TrainingConfig,
+    num_workers: Optional[int] = None,
+) -> DataLoader:
+    train_dataset = build_train_dataset(
+        train_examples,
+        preprocessor=preprocessor,
+        config=config,
+    )
+
+    return DataLoader(
+        train_dataset,
+        batch_size=config.batch_size,
+        shuffle=True,
+        num_workers=num_workers or 0,
+        collate_fn=collate_fn,
+    )
+
+
+def build_val_loader(
+    val_examples: Sequence[data.PathLike],
+    config: TrainingConfig,
+    num_workers: Optional[int] = None,
+):
+    val_dataset = build_val_dataset(
+        val_examples,
+        config=config,
+    )
+    return DataLoader(
+        val_dataset,
+        batch_size=config.batch_size,
+        shuffle=False,
+        num_workers=num_workers or 0,
+        collate_fn=collate_fn,
+    )
+
+
 def build_train_dataset(
-    examples: List[data.PathLike],
+    examples: Sequence[data.PathLike],
     preprocessor: PreprocessorProtocol,
     config: Optional[TrainingConfig] = None,
 ) -> LabeledDataset:

@@ -126,7 +160,7 @@ def build_train_dataset(
     clipper = build_clipper(config.cliping, random=True)
 
     random_example_source = RandomExampleSource(
-        examples,
+        list(examples),
         clipper=clipper,
    )
 

@@ -144,7 +178,7 @@ def build_train_dataset(
 
 
 def build_val_dataset(
-    examples: List[data.PathLike],
+    examples: Sequence[data.PathLike],
     config: Optional[TrainingConfig] = None,
     train: bool = True,
 ) -> LabeledDataset:
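End to end, training from preprocessed examples now takes one config object. A sketch, with directory names taken from the justfile recipes above (which also reuse the training split for validation):

from batdetect2.train import (
    FullTrainingConfig,
    load_full_training_config,
    train,
)
from batdetect2.train.dataset import list_preprocessed_files

conf = load_full_training_config("example_data/config.yaml")
examples = list_preprocessed_files("example_data/preprocessed")

train(
    train_examples=examples,
    val_examples=examples,  # the example recipe reuses the same directory
    config=conf,
    train_workers=2,
    val_workers=2,
)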
tests/test_train/test_config.py (new file, +10)

@@ -0,0 +1,10 @@
from batdetect2.configs import load_config
from batdetect2.train import FullTrainingConfig


def test_example_config_is_valid(example_data_dir):
    conf = load_config(
        example_data_dir / "config.yaml",
        schema=FullTrainingConfig,
    )
    assert isinstance(conf, FullTrainingConfig)
tests/test_train/test_lightning.py

@@ -5,27 +5,11 @@ import torch
 import xarray as xr
 from soundevent import data
 
-from batdetect2.models import build_model
-from batdetect2.postprocess import build_postprocessor
-from batdetect2.preprocess import build_preprocessor
-from batdetect2.targets import build_targets
-from batdetect2.train.lightning import TrainingModule
-from batdetect2.train.losses import build_loss
+from batdetect2.train import FullTrainingConfig, TrainingModule
 
 
 def build_default_module():
-    loss = build_loss()
-    targets = build_targets()
-    detector = build_model(num_classes=len(targets.class_names))
-    preprocessor = build_preprocessor()
-    postprocessor = build_postprocessor(targets)
-    return TrainingModule(
-        detector=detector,
-        loss=loss,
-        targets=targets,
-        preprocessor=preprocessor,
-        postprocessor=postprocessor,
-    )
+    return TrainingModule(FullTrainingConfig())
 
 
 def test_can_initialize_default_module():