Compare commits

...

23 Commits

Author SHA1 Message Date
mbsantiago  e8db1d4050  Fix hyperparameter saving                       2025-06-26 19:53:19 -06:00
mbsantiago  b396d4908a  Fix trainer init                                2025-06-26 19:40:27 -06:00
mbsantiago  ce15a0f152  Fix trainer init                                2025-06-26 19:35:23 -06:00
mbsantiago  16febed792  Add test that ensures example_config is valid   2025-06-26 19:21:46 -06:00
mbsantiago  d67ae9be05  Update config to remove optimiser level         2025-06-26 18:42:53 -06:00
mbsantiago  15de168a20  Update build trainer                            2025-06-26 17:43:56 -06:00
mbsantiago  0c8fae4a72  Instantiate lightning module from config        2025-06-26 17:39:50 -06:00
mbsantiago  6d57f96c07  update gitignore                                2025-06-26 16:19:33 -06:00
mbsantiago  a0622aa9a4  Fix train issues                                2025-06-26 16:16:29 -06:00
mbsantiago  587742b41e  Change train to use full config                 2025-06-26 16:02:41 -06:00
mbsantiago  6d91153a56  use just instead of make                        2025-06-26 13:29:23 -06:00
mbsantiago  70c96b6844  Move configs to example_data folder             2025-06-26 13:29:13 -06:00
mbsantiago  22f7d46f46  Improve logging of train preprocessing          2025-06-26 13:08:44 -06:00
mbsantiago  1384c549f7  Create TrainPreprocessConfig                    2025-06-26 12:30:16 -06:00
mbsantiago  4b6acd5e6e  Add manual logging of hyperparams               2025-06-26 11:59:33 -06:00
mbsantiago  2ac968d65b  Test with dvc live                              2025-06-26 10:05:01 -06:00
mbsantiago  166dad20bd  Rename BlockGroupConfig to LayerGroupConfig     2025-06-26 10:04:42 -06:00
mbsantiago  cbb02cf69e  Add dvclive as an optional group                2025-06-26 10:04:16 -06:00
mbsantiago  647468123e  Update model config in example config           2025-06-26 10:04:00 -06:00
mbsantiago  363eb9fb2f  Add example-train to makefile                   2025-06-26 10:03:52 -06:00
mbsantiago  b21c224985  Add example preprocess to make file             2025-06-26 07:55:40 -06:00
mbsantiago  152a577511  Add comprehensive conf file                     2025-06-26 07:55:33 -06:00
mbsantiago  136949c4e7  Add logging config                              2025-06-26 07:55:24 -06:00
25 changed files with 745 additions and 567 deletions

.dvc/.gitignore  vendored  new file (+3)

@@ -0,0 +1,3 @@
/config.local
/tmp
/cache

.dvc/config  new file (empty)

.dvcignore  new file (+3)

@@ -0,0 +1,3 @@
# Add patterns of files dvc should ignore, which could improve
# the performance. Learn more at
# https://dvc.org/doc/user-guide/dvcignore

.gitignore  vendored  modified (+2)

@@ -114,3 +114,5 @@ experiments/*
 notebooks/lightning_logs
 example_data/preprocessed
 .aider*
+DvcLiveLogger/checkpoints
+logs

Makefile  deleted (-94)

@@ -1,94 +0,0 @@
# Variables
SOURCE_DIR = src
TESTS_DIR = tests
PYTHON_DIRS = src tests
DOCS_SOURCE = docs/source
DOCS_BUILD = docs/build
HTML_COVERAGE_DIR = htmlcov

# Default target (optional, often 'help' or 'all')
.DEFAULT_GOAL := help

# Phony targets (targets that don't produce a file with the same name)
.PHONY: help test coverage coverage-html coverage-serve docs docs-serve format format-check lint lint-fix typecheck check clean clean-pyc clean-test clean-docs clean-build

help:
	@echo "Makefile Targets:"
	@echo "  help            Show this help message."
	@echo "  test            Run tests using pytest."
	@echo "  coverage        Run tests and generate coverage data (.coverage, coverage.xml)."
	@echo "  coverage-html   Generate an HTML coverage report in $(HTML_COVERAGE_DIR)/."
	@echo "  coverage-serve  Serve the HTML coverage report locally."
	@echo "  docs            Build documentation using Sphinx."
	@echo "  docs-serve      Serve documentation with live reload using sphinx-autobuild."
	@echo "  format          Format code using ruff."
	@echo "  format-check    Check code formatting using ruff."
	@echo "  lint            Lint code using ruff."
	@echo "  lint-fix        Lint code using ruff and apply automatic fixes."
	@echo "  typecheck       Type check code using pyright."
	@echo "  check           Run all checks (format-check, lint, typecheck)."
	@echo "  clean           Remove all build, test, documentation, and Python artifacts."
	@echo "  clean-pyc       Remove Python bytecode and cache."
	@echo "  clean-test      Remove test and coverage artifacts."
	@echo "  clean-docs      Remove built documentation."
	@echo "  clean-build     Remove package build artifacts."

# Testing & Coverage
test:
	pytest tests

coverage:
	pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml tests

coverage-html: coverage
	@echo "Generating HTML coverage report..."
	coverage html -d $(HTML_COVERAGE_DIR)
	@echo "HTML coverage report generated in $(HTML_COVERAGE_DIR)/"

coverage-serve: coverage-html
	@echo "Serving report at http://localhost:8000/ ..."
	python -m http.server --directory $(HTML_COVERAGE_DIR) 8000

# Documentation
docs:
	sphinx-build -b html $(DOCS_SOURCE) $(DOCS_BUILD)

docs-serve:
	sphinx-autobuild $(DOCS_SOURCE) $(DOCS_BUILD) --watch $(SOURCE_DIR) --open-browser

# Formatting & Linting
format:
	ruff format $(PYTHON_DIRS)

format-check:
	ruff format --check $(PYTHON_DIRS)

lint:
	ruff check $(PYTHON_DIRS)

lint-fix:
	ruff check --fix $(PYTHON_DIRS)

# Type Checking
typecheck:
	pyright $(PYTHON_DIRS)

# Combined Checks
check: format-check lint typecheck test

# Cleaning tasks
clean-pyc:
	find . -type f -name "*.py[co]" -delete
	find . -type d -name "__pycache__" -exec rm -rf {} +

clean-test:
	rm -f .coverage coverage.xml
	rm -rf .pytest_cache htmlcov/

clean-docs:
	rm -rf $(DOCS_BUILD)

clean-build:
	rm -rf build/ dist/ *.egg-info/

clean: clean-build clean-pyc clean-test clean-docs

example_data/config.yaml  new file (+146)

@@ -0,0 +1,146 @@
targets:
  classes:
    classes:
      - name: myomys
        tags:
          - value: Myotis mystacinus
      - name: pippip
        tags:
          - value: Pipistrellus pipistrellus
      - name: eptser
        tags:
          - value: Eptesicus serotinus
      - name: rhifer
        tags:
          - value: Rhinolophus ferrumequinum
    generic_class:
      - key: class
        value: Bat
  filtering:
    rules:
      - match_type: all
        tags:
          - key: event
            value: Echolocation
      - match_type: exclude
        tags:
          - key: class
            value: Unknown
preprocess:
  audio:
    resample:
      samplerate: 256000
      method: "poly"
    scale: false
    center: true
    duration: null
  spectrogram:
    stft:
      window_duration: 0.002
      window_overlap: 0.75
      window_fn: hann
    frequencies:
      max_freq: 120000
      min_freq: 10000
    pcen:
      time_constant: 0.4
      gain: 0.98
      bias: 2
      power: 0.5
    scale: "amplitude"
    size:
      height: 128
      resize_factor: 0.5
    spectral_mean_substraction: true
    peak_normalize: false
postprocess:
  nms_kernel_size: 9
  detection_threshold: 0.01
  min_freq: 10000
  max_freq: 120000
  top_k_per_sec: 200
labels:
  sigma: 3
model:
  input_height: 128
  in_channels: 1
  out_channels: 32
  encoder:
    layers:
      - block_type: FreqCoordConvDown
        out_channels: 32
      - block_type: FreqCoordConvDown
        out_channels: 64
      - block_type: LayerGroup
        layers:
          - block_type: FreqCoordConvDown
            out_channels: 128
          - block_type: ConvBlock
            out_channels: 256
  bottleneck:
    channels: 256
    self_attention: true
  decoder:
    layers:
      - block_type: FreqCoordConvUp
        out_channels: 64
      - block_type: FreqCoordConvUp
        out_channels: 32
      - block_type: LayerGroup
        layers:
          - block_type: FreqCoordConvUp
            out_channels: 32
          - block_type: ConvBlock
            out_channels: 32
train:
  batch_size: 8
  learning_rate: 0.001
  t_max: 100
  loss:
    detection:
      weight: 1.0
      focal:
        beta: 4
        alpha: 2
    classification:
      weight: 2.0
      focal:
        beta: 4
        alpha: 2
    size:
      weight: 0.1
  logger:
    logger_type: dvclive
  augmentations:
    steps:
      - augmentation_type: mix_audio
        probability: 0.2
        min_weight: 0.3
        max_weight: 0.7
      - augmentation_type: add_echo
        probability: 0.2
        max_delay: 0.005
        min_weight: 0.0
        max_weight: 1.0
      - augmentation_type: scale_volume
        probability: 0.2
        min_scaling: 0.0
        max_scaling: 2.0
      - augmentation_type: warp
        probability: 0.2
        delta: 0.04
      - augmentation_type: mask_time
        probability: 0.2
        max_perc: 0.05
        max_masks: 3
      - augmentation_type: mask_freq
        probability: 0.2
        max_perc: 0.10
        max_masks: 3

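For reference, this whole file is meant to validate against the new FullTrainingConfig schema introduced later in this diff (the new test near the end does exactly this). A minimal sketch, assuming the file sits at example_data/config.yaml in an installed checkout:

from batdetect2.train import FullTrainingConfig, load_full_training_config

conf = load_full_training_config("example_data/config.yaml")
assert isinstance(conf, FullTrainingConfig)
print(conf.train.batch_size)  # 8, from the train section above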
justfile  new file (+110)

@@ -0,0 +1,110 @@
# Default command, runs if no recipe is specified.
default:
    just --list

# Variables
SOURCE_DIR := "src"
TESTS_DIR := "tests"
PYTHON_DIRS := "src tests"
DOCS_SOURCE := "docs/source"
DOCS_BUILD := "docs/build"
HTML_COVERAGE_DIR := "htmlcov"

# Show available commands
help:
    @just --list

# Testing & Coverage

# Run tests using pytest.
test:
    pytest {{TESTS_DIR}}

# Run tests and generate coverage data.
coverage:
    pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}}

# Generate an HTML coverage report.
coverage-html: coverage
    @echo "Generating HTML coverage report..."
    coverage html -d {{HTML_COVERAGE_DIR}}
    @echo "HTML coverage report generated in {{HTML_COVERAGE_DIR}}/"

# Serve the HTML coverage report locally.
coverage-serve: coverage-html
    @echo "Serving report at http://localhost:8000/ ..."
    python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000

# Documentation

# Build documentation using Sphinx.
docs:
    sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}

# Serve documentation with live reload.
docs-serve:
    sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser

# Formatting & Linting

# Format code using ruff.
format:
    ruff format {{PYTHON_DIRS}}

# Check code formatting using ruff.
format-check:
    ruff format --check {{PYTHON_DIRS}}

# Lint code using ruff.
lint:
    ruff check {{PYTHON_DIRS}}

# Lint code using ruff and apply automatic fixes.
lint-fix:
    ruff check --fix {{PYTHON_DIRS}}

# Type Checking

# Type check code using pyright.
typecheck:
    pyright {{PYTHON_DIRS}}

# Combined Checks

# Run all checks (format-check, lint, typecheck).
check: format-check lint typecheck test

# Cleaning tasks

# Remove Python bytecode and cache.
clean-pyc:
    find . -type f -name "*.py[co]" -delete
    find . -type d -name "__pycache__" -exec rm -rf {} +

# Remove test and coverage artifacts.
clean-test:
    rm -f .coverage coverage.xml
    rm -rf .pytest_cache htmlcov/

# Remove built documentation.
clean-docs:
    rm -rf {{DOCS_BUILD}}

# Remove package build artifacts.
clean-build:
    rm -rf build/ dist/ *.egg-info/

# Remove all build, test, documentation, and Python artifacts.
clean: clean-build clean-pyc clean-test clean-docs

# Examples

# Preprocess example data.
example-preprocess OPTIONS="":
    batdetect2 preprocess \
        --base-dir . \
        --dataset-field datasets.train \
        --config example_data/config.yaml \
        {{OPTIONS}} \
        example_data/datasets.yaml example_data/preprocessed

# Train on example data.
example-train OPTIONS="":
    batdetect2 train \
        --val-dir example_data/preprocessed \
        --config example_data/config.yaml \
        {{OPTIONS}} \
        example_data/preprocessed

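For anyone who prefers to drive the same commands from Python rather than just, here is a hypothetical equivalent of the example-train recipe using click's standard test runner (the cli group and the option names are taken from the CLI diffs below; the paths mirror the recipe above):

from click.testing import CliRunner

from batdetect2.cli import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "train",
        "--val-dir", "example_data/preprocessed",
        "--config", "example_data/config.yaml",
        "example_data/preprocessed",
    ],
)
print(result.output)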
pyproject.toml  modified

@@ -85,6 +85,11 @@ dev = [
     "sphinx-book-theme>=1.1.4",
     "autodoc-pydantic>=2.2.0",
     "pytest-cov>=6.1.1",
+    "ty>=0.0.1a12",
+    "rust-just>=1.40.0",
+]
+dvclive = [
+    "dvclive>=3.48.2",
 ]
 
 [tool.ruff]

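With this change, DVCLive support ships as an optional dependency group; assuming the published package name matches the project, it would be installed with `pip install "batdetect2[dvclive]"` (or `uv sync --extra dvclive` in a uv-managed checkout), while the logger itself stays import-guarded, as the new logging module below shows.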
src/batdetect2/cli/__init__.py  modified

@@ -2,13 +2,13 @@ from batdetect2.cli.base import cli
 from batdetect2.cli.compat import detect
 from batdetect2.cli.data import data
 from batdetect2.cli.preprocess import preprocess
-from batdetect2.cli.train import train
+from batdetect2.cli.train import train_command
 
 __all__ = [
     "cli",
     "detect",
     "data",
-    "train",
+    "train_command",
     "preprocess",
 ]

src/batdetect2/cli/preprocess.py  modified

@@ -1,15 +1,18 @@
+import sys
 from pathlib import Path
 from typing import Optional
 
 import click
+import yaml
 from loguru import logger
 
 from batdetect2.cli.base import cli
 from batdetect2.data import load_dataset_from_config
-from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
-from batdetect2.targets import build_targets, load_target_config
-from batdetect2.train import load_label_config, preprocess_annotations
-from batdetect2.train.labels import build_clip_labeler
+from batdetect2.train.preprocess import (
+    TrainPreprocessConfig,
+    load_train_preprocessing_config,
+    preprocess_dataset,
+)
 
 __all__ = ["preprocess"]
@@ -44,16 +47,16 @@ __all__ = ["preprocess"]
     ),
 )
 @click.option(
-    "--preprocess-config",
+    "--config",
     type=click.Path(exists=True),
     help=(
-        "Path to the preprocessing configuration file. This file tells "
+        "Path to the configuration file. This file tells "
         "the program how to prepare your audio data before training, such "
         "as resampling or applying filters."
     ),
 )
 @click.option(
-    "--preprocess-config-field",
+    "--config-field",
     type=str,
     help=(
         "If the preprocessing settings are inside a nested dictionary "
@@ -62,41 +65,6 @@ __all__ = ["preprocess"]
         "top level, you don't need to specify this."
     ),
 )
-@click.option(
-    "--label-config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the label generation configuration file. This file "
-        "contains settings for how to create labels from your "
-        "annotations, which the model uses to learn."
-    ),
-)
-@click.option(
-    "--label-config-field",
-    type=str,
-    help=(
-        "If the label generation settings are inside a nested dictionary "
-        "within the label configuration file, specify the key here. If "
-        "the settings are at the top level, leave this blank."
-    ),
-)
-@click.option(
-    "--target-config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the training target configuration file. This file "
-        "specifies what sounds the model should learn to predict."
-    ),
-)
-@click.option(
-    "--target-config-field",
-    type=str,
-    help=(
-        "If the target settings are inside a nested dictionary "
-        "within the target configuration file, specify the key here. "
-        "If the settings are at the top level, you don't need to specify this."
-    ),
-)
 @click.option(
     "--force",
     is_flag=True,
@@ -117,20 +85,32 @@ __all__ = ["preprocess"]
         "the program will use all available cores."
     ),
 )
+@click.option(
+    "-v",
+    "--verbose",
+    count=True,
+    help="Increase verbosity. -v for INFO, -vv for DEBUG.",
+)
 def preprocess(
     dataset_config: Path,
     output: Path,
-    target_config: Optional[Path] = None,
     base_dir: Optional[Path] = None,
-    preprocess_config: Optional[Path] = None,
-    label_config: Optional[Path] = None,
+    config: Optional[Path] = None,
+    config_field: Optional[str] = None,
     force: bool = False,
     num_workers: Optional[int] = None,
-    target_config_field: Optional[str] = None,
-    preprocess_config_field: Optional[str] = None,
-    label_config_field: Optional[str] = None,
     dataset_field: Optional[str] = None,
+    verbose: int = 0,
 ):
+    logger.remove()
+
+    if verbose == 0:
+        log_level = "WARNING"
+    elif verbose == 1:
+        log_level = "INFO"
+    else:
+        log_level = "DEBUG"
+
+    logger.add(sys.stderr, level=log_level)
+
     logger.info("Starting preprocessing.")
 
     output = Path(output)
@@ -139,31 +119,19 @@ def preprocess(
     base_dir = base_dir or Path.cwd()
     logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
 
-    preprocess = (
-        load_preprocessing_config(
-            preprocess_config,
-            field=preprocess_config_field,
-        )
-        if preprocess_config
-        else None
-    )
+    if config:
+        logger.info(
+            "Loading preprocessing config from: {config}", config=config
+        )
 
-    target = (
-        load_target_config(
-            target_config,
-            field=target_config_field,
-        )
-        if target_config
-        else None
-    )
+    conf = (
+        load_train_preprocessing_config(config, field=config_field)
+        if config is not None
+        else TrainPreprocessConfig()
+    )
 
-    label = (
-        load_label_config(
-            label_config,
-            field=label_config_field,
-        )
-        if label_config
-        else None
-    )
+    logger.debug(
+        "Preprocessing config:\n{conf}",
+        conf=yaml.dump(conf.model_dump()),
+    )
 
     dataset = load_dataset_from_config(
@@ -177,20 +145,10 @@ def preprocess(
         num_examples=len(dataset),
     )
 
-    targets = build_targets(config=target)
-    preprocessor = build_preprocessor(config=preprocess)
-    labeller = build_clip_labeler(targets, config=label)
-
-    if not output.exists():
-        logger.debug("Creating directory {directory}", directory=output)
-        output.mkdir(parents=True)
-
-    logger.info("Will start preprocessing")
-
-    preprocess_annotations(
+    preprocess_dataset(
         dataset,
-        output_dir=output,
-        preprocessor=preprocessor,
-        labeller=labeller,
-        replace=force,
+        conf,
+        output=output,
+        force=force,
         max_workers=num_workers,
     )

src/batdetect2/cli/train.py  modified

@@ -5,236 +5,53 @@ import click
 from loguru import logger
 
 from batdetect2.cli.base import cli
-from batdetect2.evaluate.metrics import (
-    ClassificationAccuracy,
-    ClassificationMeanAveragePrecision,
-    DetectionAveragePrecision,
-)
-from batdetect2.models import build_model
-from batdetect2.models.backbones import load_backbone_config
-from batdetect2.postprocess import build_postprocessor, load_postprocess_config
-from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
-from batdetect2.targets import build_targets, load_target_config
-from batdetect2.train import train
-from batdetect2.train.callbacks import ValidationMetrics
-from batdetect2.train.config import TrainingConfig, load_train_config
+from batdetect2.train import (
+    FullTrainingConfig,
+    load_full_training_config,
+    train,
+)
 from batdetect2.train.dataset import list_preprocessed_files
 
 __all__ = [
     "train_command",
 ]
 
-DEFAULT_CONFIG_FILE = Path("config.yaml")
-
 
 @cli.command(name="train")
-@click.option(
-    "--train-examples",
-    type=click.Path(exists=True),
-    required=True,
-)
-@click.option("--val-examples", type=click.Path(exists=True))
-@click.option(
-    "--model-path",
-    type=click.Path(exists=True),
-)
-@click.option(
-    "--train-config",
-    type=click.Path(exists=True),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--train-config-field",
-    type=str,
-    default="train",
-)
-@click.option(
-    "--preprocess-config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the preprocessing configuration file. This file tells "
-        "the program how to prepare your audio data before training, such "
-        "as resampling or applying filters."
-    ),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--preprocess-config-field",
-    type=str,
-    help=(
-        "If the preprocessing settings are inside a nested dictionary "
-        "within the preprocessing configuration file, specify the key "
-        "here to access them. If the preprocessing settings are at the "
-        "top level, you don't need to specify this."
-    ),
-    default="preprocess",
-)
-@click.option(
-    "--target-config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the training target configuration file. This file "
-        "specifies what sounds the model should learn to predict."
-    ),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--target-config-field",
-    type=str,
-    help=(
-        "If the target settings are inside a nested dictionary "
-        "within the target configuration file, specify the key here. "
-        "If the settings are at the top level, you don't need to specify this."
-    ),
-    default="targets",
-)
-@click.option(
-    "--postprocess-config",
-    type=click.Path(exists=True),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--postprocess-config-field",
-    type=str,
-    default="postprocess",
-)
-@click.option(
-    "--model-config",
-    type=click.Path(exists=True),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--model-config-field",
-    type=str,
-    default="model",
-)
-@click.option(
-    "--train-workers",
-    type=int,
-    default=0,
-)
-@click.option(
-    "--val-workers",
-    type=int,
-    default=0,
-)
+@click.argument("train_dir", type=click.Path(exists=True))
+@click.option("--val-dir", type=click.Path(exists=True))
+@click.option("--model-path", type=click.Path(exists=True))
+@click.option("--config", type=click.Path(exists=True))
+@click.option("--config-field", type=str)
+@click.option("--train-workers", type=int, default=0)
+@click.option("--val-workers", type=int, default=0)
 def train_command(
-    train_examples: Path,
-    val_examples: Optional[Path] = None,
+    train_dir: Path,
+    val_dir: Optional[Path] = None,
     model_path: Optional[Path] = None,
-    train_config: Path = DEFAULT_CONFIG_FILE,
-    train_config_field: str = "train",
-    preprocess_config: Path = DEFAULT_CONFIG_FILE,
-    preprocess_config_field: str = "preprocess",
-    target_config: Path = DEFAULT_CONFIG_FILE,
-    target_config_field: str = "targets",
-    postprocess_config: Path = DEFAULT_CONFIG_FILE,
-    postprocess_config_field: str = "postprocess",
-    model_config: Path = DEFAULT_CONFIG_FILE,
-    model_config_field: str = "model",
+    config: Optional[Path] = None,
+    config_field: Optional[str] = None,
     train_workers: int = 0,
     val_workers: int = 0,
 ):
     logger.info("Starting training!")
 
-    try:
-        target_config_loaded = load_target_config(
-            path=target_config,
-            field=target_config_field,
-        )
-        targets = build_targets(config=target_config_loaded)
-        logger.debug(
-            "Loaded targets info from config file {path}", path=target_config
-        )
-    except IOError:
-        logger.debug(
-            "Could not load target info from config file, using default"
-        )
-        targets = build_targets()
-
-    try:
-        preprocess_config_loaded = load_preprocessing_config(
-            path=preprocess_config,
-            field=preprocess_config_field,
-        )
-        preprocessor = build_preprocessor(preprocess_config_loaded)
-        logger.debug(
-            "Loaded preprocessor from config file {path}", path=target_config
-        )
-    except IOError:
-        logger.debug(
-            "Could not load preprocessor from config file, using default"
-        )
-        preprocessor = build_preprocessor()
-
-    try:
-        model_config_loaded = load_backbone_config(
-            path=model_config, field=model_config_field
-        )
-        model = build_model(
-            num_classes=len(targets.class_names),
-            config=model_config_loaded,
-        )
-    except IOError:
-        model = build_model(num_classes=len(targets.class_names))
-
-    try:
-        postprocess_config_loaded = load_postprocess_config(
-            path=postprocess_config,
-            field=postprocess_config_field,
-        )
-        postprocessor = build_postprocessor(
-            targets=targets,
-            config=postprocess_config_loaded,
-        )
-        logger.debug(
-            "Loaded postprocessor from file {path}", path=postprocess_config
-        )
-    except IOError:
-        logger.debug(
-            "Could not load postprocessor config from file. Using default"
-        )
-        postprocessor = build_postprocessor(targets=targets)
-
-    try:
-        train_config_loaded = load_train_config(
-            path=train_config, field=train_config_field
-        )
-        logger.debug(
-            "Loaded training config from file {path}",
-            path=train_config,
-        )
-    except IOError:
-        train_config_loaded = TrainingConfig()
-        logger.debug("Could not load training config from file. Using default")
-
-    train_files = list_preprocessed_files(train_examples)
-    val_files = (
-        None if val_examples is None else list_preprocessed_files(val_examples)
-    )
+    conf = (
+        load_full_training_config(config, field=config_field)
+        if config is not None
+        else FullTrainingConfig()
+    )
 
-    return train(
-        detector=model,
-        train_examples=train_files,  # type: ignore
-        val_examples=val_files,  # type: ignore
+    train_examples = list_preprocessed_files(train_dir)
+    val_examples = (
+        list_preprocessed_files(val_dir) if val_dir is not None else None
+    )
+
+    train(
+        train_examples=train_examples,
+        val_examples=val_examples,
+        config=conf,
         model_path=model_path,
-        preprocessor=preprocessor,
-        postprocessor=postprocessor,
-        targets=targets,
-        config=train_config_loaded,
-        callbacks=[
-            ValidationMetrics(
-                metrics=[
-                    DetectionAveragePrecision(),
-                    ClassificationMeanAveragePrecision(
-                        class_names=targets.class_names,
-                    ),
-                    ClassificationAccuracy(class_names=targets.class_names),
-                ]
-            )
-        ],
         train_workers=train_workers,
         val_workers=val_workers,
     )

src/batdetect2/configs.py  modified

@@ -25,20 +25,9 @@ class BaseConfig(BaseModel):
     Inherits from Pydantic's `BaseModel` to provide data validation, parsing,
     and serialization capabilities.
-
-    It sets `extra='forbid'` in its model configuration, meaning that any
-    fields provided in a configuration file that are *not* explicitly defined
-    in the specific configuration schema will raise a validation error. This
-    helps catch typos and ensures configurations strictly adhere to the expected
-    structure.
-
-    Attributes
-    ----------
-    model_config : ConfigDict
-        Pydantic model configuration dictionary. Set to forbid extra fields.
     """
 
-    model_config = ConfigDict(extra="forbid")
+    model_config = ConfigDict(extra="ignore")
 
 
 T = TypeVar("T", bound=BaseModel)

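The switch from extra="forbid" to extra="ignore" is what lets several schemas read their own sections out of the single example_data/config.yaml above. A small self-contained illustration using plain Pydantic (the class and field names here are illustrative, not from the codebase):

from pydantic import BaseModel, ConfigDict

class ExampleConfig(BaseModel):
    model_config = ConfigDict(extra="ignore")
    samplerate: int = 256000

# With extra="forbid" this raised a ValidationError; with extra="ignore"
# unknown keys are silently dropped, so one shared YAML file can feed
# several schemas that each only read their own section.
conf = ExampleConfig.model_validate({"samplerate": 192000, "unknown": 1})
print(conf.samplerate)  # 192000; "unknown" is discarded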
src/batdetect2/models/backbones.py  modified

@@ -27,9 +27,23 @@ from torch import nn
 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.models.blocks import ConvBlock
-from batdetect2.models.bottleneck import BottleneckConfig, build_bottleneck
-from batdetect2.models.decoder import Decoder, DecoderConfig, build_decoder
-from batdetect2.models.encoder import Encoder, EncoderConfig, build_encoder
+from batdetect2.models.bottleneck import (
+    DEFAULT_BOTTLENECK_CONFIG,
+    BottleneckConfig,
+    build_bottleneck,
+)
+from batdetect2.models.decoder import (
+    DEFAULT_DECODER_CONFIG,
+    Decoder,
+    DecoderConfig,
+    build_decoder,
+)
+from batdetect2.models.encoder import (
+    DEFAULT_ENCODER_CONFIG,
+    Encoder,
+    EncoderConfig,
+    build_encoder,
+)
 from batdetect2.models.types import BackboneModel
 
 __all__ = [
@@ -186,9 +200,9 @@ class BackboneConfig(BaseConfig):
     input_height: int = 128
     in_channels: int = 1
-    encoder: Optional[EncoderConfig] = None
-    bottleneck: Optional[BottleneckConfig] = None
-    decoder: Optional[DecoderConfig] = None
+    encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
+    bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
+    decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
     out_channels: int = 32

src/batdetect2/models/blocks.py  modified

@@ -38,7 +38,7 @@ from batdetect2.configs import BaseConfig
 __all__ = [
     "ConvBlock",
-    "BlockGroupConfig",
+    "LayerGroupConfig",
     "VerticalConv",
     "FreqCoordConvDownBlock",
     "StandardConvDownBlock",
@@ -654,16 +654,16 @@ LayerConfig = Annotated[
         StandardConvDownConfig,
         FreqCoordConvUpConfig,
         StandardConvUpConfig,
-        "BlockGroupConfig",
+        "LayerGroupConfig",
     ],
     Field(discriminator="block_type"),
 ]
 """Type alias for the discriminated union of block configuration models."""
 
 
-class BlockGroupConfig(BaseConfig):
-    block_type: Literal["group"] = "group"
-    blocks: List[LayerConfig]
+class LayerGroupConfig(BaseConfig):
+    block_type: Literal["LayerGroup"] = "LayerGroup"
+    layers: List[LayerConfig]
 
 
 def build_layer_from_config(
@@ -769,13 +769,13 @@ def build_layer_from_config(
             input_height * 2,
         )
 
-    if config.block_type == "group":
+    if config.block_type == "LayerGroup":
         current_channels = in_channels
         current_height = input_height
 
         blocks = []
-        for block_config in config.blocks:
+        for block_config in config.layers:
             block, current_channels, current_height = build_layer_from_config(
                 input_height=current_height,
                 in_channels=current_channels,

src/batdetect2/models/decoder.py  modified

@@ -26,7 +26,7 @@ from torch import nn
 from batdetect2.configs import BaseConfig
 from batdetect2.models.blocks import (
-    BlockGroupConfig,
+    LayerGroupConfig,
     ConvConfig,
     FreqCoordConvUpConfig,
     StandardConvUpConfig,
@@ -45,7 +45,7 @@ DecoderLayerConfig = Annotated[
         ConvConfig,
         FreqCoordConvUpConfig,
         StandardConvUpConfig,
-        BlockGroupConfig,
+        LayerGroupConfig,
     ],
     Field(discriminator="block_type"),
 ]
@@ -197,8 +197,8 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
     layers=[
         FreqCoordConvUpConfig(out_channels=64),
         FreqCoordConvUpConfig(out_channels=32),
-        BlockGroupConfig(
-            blocks=[
+        LayerGroupConfig(
+            layers=[
                 FreqCoordConvUpConfig(out_channels=32),
                 ConvConfig(out_channels=32),
             ]

src/batdetect2/models/encoder.py  modified

@@ -28,7 +28,7 @@ from torch import nn
 from batdetect2.configs import BaseConfig
 from batdetect2.models.blocks import (
-    BlockGroupConfig,
+    LayerGroupConfig,
     ConvConfig,
     FreqCoordConvDownConfig,
     StandardConvDownConfig,
@@ -47,7 +47,7 @@ EncoderLayerConfig = Annotated[
         ConvConfig,
         FreqCoordConvDownConfig,
         StandardConvDownConfig,
-        BlockGroupConfig,
+        LayerGroupConfig,
     ],
     Field(discriminator="block_type"),
 ]
@@ -230,8 +230,8 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
     layers=[
         FreqCoordConvDownConfig(out_channels=32),
         FreqCoordConvDownConfig(out_channels=64),
-        BlockGroupConfig(
-            blocks=[
+        LayerGroupConfig(
+            layers=[
                 FreqCoordConvDownConfig(out_channels=128),
                 ConvConfig(out_channels=256),
             ]

src/batdetect2/train/__init__.py  modified

@@ -15,8 +15,10 @@ from batdetect2.train.augmentations import (
 )
 from batdetect2.train.clips import build_clipper, select_subclip
 from batdetect2.train.config import (
-    TrainerConfig,
+    FullTrainingConfig,
+    PLTrainerConfig,
     TrainingConfig,
+    load_full_training_config,
     load_train_config,
 )
 from batdetect2.train.dataset import (
@@ -26,6 +28,7 @@ from batdetect2.train.dataset import (
     list_preprocessed_files,
 )
 from batdetect2.train.labels import build_clip_labeler, load_label_config
+from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.losses import (
     ClassificationLossConfig,
     DetectionLossConfig,
@@ -40,7 +43,10 @@ from batdetect2.train.preprocess import (
 )
 from batdetect2.train.train import (
     build_train_dataset,
+    build_train_loader,
+    build_trainer,
     build_val_dataset,
+    build_val_loader,
     train,
 )
@@ -50,15 +56,17 @@ __all__ = [
     "DetectionLossConfig",
     "EchoAugmentationConfig",
     "FrequencyMaskAugmentationConfig",
+    "FullTrainingConfig",
     "LabeledDataset",
     "LossConfig",
     "LossFunction",
+    "PLTrainerConfig",
     "RandomExampleSource",
     "SizeLossConfig",
     "TimeMaskAugmentationConfig",
     "TrainExample",
-    "TrainerConfig",
     "TrainingConfig",
+    "TrainingModule",
     "VolumeAugmentationConfig",
     "WarpAugmentationConfig",
     "add_echo",
@@ -67,9 +75,13 @@ __all__ = [
     "build_clipper",
     "build_loss",
     "build_train_dataset",
+    "build_train_loader",
+    "build_trainer",
     "build_val_dataset",
+    "build_val_loader",
     "generate_train_example",
     "list_preprocessed_files",
+    "load_full_training_config",
     "load_label_config",
     "load_train_config",
     "mask_frequency",
@@ -79,6 +91,5 @@ __all__ = [
     "scale_volume",
     "select_subclip",
     "train",
-    "train",
     "warp_spectrogram",
 ]

src/batdetect2/train/config.py  modified

@@ -4,26 +4,27 @@ from pydantic import Field
 from soundevent import data
 
 from batdetect2.configs import BaseConfig, load_config
+from batdetect2.models import BackboneConfig
+from batdetect2.postprocess import PostprocessConfig
+from batdetect2.preprocess import PreprocessingConfig
+from batdetect2.targets import TargetConfig
 from batdetect2.train.augmentations import (
     DEFAULT_AUGMENTATION_CONFIG,
     AugmentationsConfig,
 )
 from batdetect2.train.clips import ClipingConfig
+from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
 from batdetect2.train.losses import LossConfig
 
 __all__ = [
-    "OptimizerConfig",
     "TrainingConfig",
     "load_train_config",
+    "FullTrainingConfig",
+    "load_full_training_config",
 ]
 
 
-class OptimizerConfig(BaseConfig):
-    learning_rate: float = 1e-3
-    t_max: int = 100
-
-
-class TrainerConfig(BaseConfig):
+class PLTrainerConfig(BaseConfig):
     accelerator: str = "auto"
     accumulate_grad_batches: int = 1
     deterministic: bool = True
@@ -44,20 +45,17 @@ class TrainerConfig(BaseConfig):
     val_check_interval: Optional[Union[int, float]] = None
 
 
-class TrainingConfig(BaseConfig):
+class TrainingConfig(PLTrainerConfig):
     batch_size: int = 8
+    learning_rate: float = 1e-3
+    t_max: int = 100
     loss: LossConfig = Field(default_factory=LossConfig)
-    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
     augmentations: AugmentationsConfig = Field(
         default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
     )
     cliping: ClipingConfig = Field(default_factory=ClipingConfig)
-    trainer: TrainerConfig = Field(default_factory=TrainerConfig)
+    trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
+    logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
 
 
 def load_train_config(
@@ -65,3 +63,23 @@ def load_train_config(
     field: Optional[str] = None,
 ) -> TrainingConfig:
     return load_config(path, schema=TrainingConfig, field=field)
+
+
+class FullTrainingConfig(BaseConfig):
+    """Full training configuration."""
+
+    train: TrainingConfig = Field(default_factory=TrainingConfig)
+    targets: TargetConfig = Field(default_factory=TargetConfig)
+    model: BackboneConfig = Field(default_factory=BackboneConfig)
+    preprocess: PreprocessingConfig = Field(
+        default_factory=PreprocessingConfig
+    )
+    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
+
+
+def load_full_training_config(
+    path: data.PathLike,
+    field: Optional[str] = None,
+) -> FullTrainingConfig:
+    """Load the full training configuration."""
+    return load_config(path, schema=FullTrainingConfig, field=field)

src/batdetect2/train/lightning.py  modified

@@ -1,17 +1,16 @@
 import lightning as L
 import torch
+from pydantic import BaseModel
 from torch.optim.adam import Adam
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
-from batdetect2.models import (
-    DetectionModel,
-    ModelOutput,
-)
-from batdetect2.postprocess.types import PostprocessorProtocol
-from batdetect2.preprocess.types import PreprocessorProtocol
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.models import ModelOutput, build_model
+from batdetect2.postprocess import build_postprocessor
+from batdetect2.preprocess import build_preprocessor
+from batdetect2.targets import build_targets
 from batdetect2.train import TrainExample
-from batdetect2.train.types import LossProtocol
+from batdetect2.train.config import FullTrainingConfig
+from batdetect2.train.losses import build_loss
 
 __all__ = [
     "TrainingModule",
@@ -19,28 +18,29 @@ __all__ = [
 
 class TrainingModule(L.LightningModule):
-    def __init__(
-        self,
-        detector: DetectionModel,
-        loss: LossProtocol,
-        targets: TargetProtocol,
-        preprocessor: PreprocessorProtocol,
-        postprocessor: PostprocessorProtocol,
-        learning_rate: float = 0.001,
-        t_max: int = 100,
-    ):
+    def __init__(self, config: FullTrainingConfig):
         super().__init__()
 
-        self.loss = loss
-        self.detector = detector
-        self.preprocessor = preprocessor
-        self.targets = targets
-        self.postprocessor = postprocessor
-
-        self.learning_rate = learning_rate
-        self.t_max = t_max
-
-        self.save_hyperparameters()
+        # NOTE: Need to convert to vanilla python object so that DVCLive can
+        # store it.
+        self._config = (
+            config.model_dump() if isinstance(config, BaseModel) else config
+        )
+        self.save_hyperparameters({"config": self._config})
+
+        self.config = FullTrainingConfig.model_validate(self._config)
+        self.loss = build_loss(self.config.train.loss)
+        self.targets = build_targets(self.config.targets)
+        self.detector = build_model(
+            num_classes=len(self.targets.class_names),
+            config=self.config.model,
+        )
+        self.preprocessor = build_preprocessor(self.config.preprocess)
+        self.postprocessor = build_postprocessor(
+            self.targets,
+            min_freq=self.preprocessor.min_freq,
+            max_freq=self.preprocessor.max_freq,
+        )
 
     def forward(self, spec: torch.Tensor) -> ModelOutput:
         return self.detector(spec)
@@ -70,6 +70,6 @@ class TrainingModule(L.LightningModule):
         return outputs
 
     def configure_optimizers(self):
-        optimizer = Adam(self.parameters(), lr=self.learning_rate)
-        scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
+        optimizer = Adam(self.parameters(), lr=self.config.train.learning_rate)
+        scheduler = CosineAnnealingLR(optimizer, T_max=self.config.train.t_max)
         return [optimizer], [scheduler]

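A sketch of the hyperparameter round-trip this rewrite introduces: the Pydantic config is dumped to plain dicts before save_hyperparameters (so DVCLive can serialize it, per the NOTE above) and re-validated into a typed object on the way back. The dict-style hparams access below is standard Lightning behavior, shown here as an assumption rather than code from the repo:

from batdetect2.train import FullTrainingConfig, TrainingModule

module = TrainingModule(FullTrainingConfig())
# The stored hyperparameters hold only vanilla Python objects...
assert isinstance(module.hparams["config"], dict)
# ...and a checkpoint restores the typed config by re-validating that dict.
restored = FullTrainingConfig.model_validate(module.hparams["config"])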
src/batdetect2/train/logging.py  new file (+103)

@@ -0,0 +1,103 @@
from typing import Annotated, Literal, Optional, Union

from lightning.pytorch.loggers import Logger
from pydantic import Field

from batdetect2.configs import BaseConfig

DEFAULT_LOGS_DIR: str = "logs"


class DVCLiveConfig(BaseConfig):
    logger_type: Literal["dvclive"] = "dvclive"
    dir: str = DEFAULT_LOGS_DIR
    run_name: Optional[str] = None
    prefix: str = ""
    log_model: Union[bool, Literal["all"]] = False
    monitor_system: bool = False


class CSVLoggerConfig(BaseConfig):
    logger_type: Literal["csv"] = "csv"
    save_dir: str = DEFAULT_LOGS_DIR
    name: Optional[str] = "logs"
    version: Optional[str] = None
    flush_logs_every_n_steps: int = 100


class TensorBoardLoggerConfig(BaseConfig):
    logger_type: Literal["tensorboard"] = "tensorboard"
    save_dir: str = DEFAULT_LOGS_DIR
    name: Optional[str] = "default"
    version: Optional[str] = None
    log_graph: bool = False
    flush_logs_every_n_steps: Optional[int] = None


LoggerConfig = Annotated[
    Union[DVCLiveConfig, CSVLoggerConfig, TensorBoardLoggerConfig],
    Field(discriminator="logger_type"),
]


def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
    try:
        from dvclive.lightning import DVCLiveLogger  # type: ignore
    except ImportError as error:
        raise ValueError(
            "DVCLive is not installed and cannot be used for logging. "
            "Make sure you have it installed by running `pip install dvclive` "
            "or `uv add dvclive`."
        ) from error

    return DVCLiveLogger(
        dir=config.dir,
        run_name=config.run_name,
        prefix=config.prefix,
        log_model=config.log_model,
        monitor_system=config.monitor_system,
    )


def create_csv_logger(config: CSVLoggerConfig) -> Logger:
    from lightning.pytorch.loggers import CSVLogger

    return CSVLogger(
        save_dir=config.save_dir,
        name=config.name,
        version=config.version,
        flush_logs_every_n_steps=config.flush_logs_every_n_steps,
    )


def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
    from lightning.pytorch.loggers import TensorBoardLogger

    return TensorBoardLogger(
        save_dir=config.save_dir,
        name=config.name,
        version=config.version,
        log_graph=config.log_graph,
        flush_logs_every_n_steps=config.flush_logs_every_n_steps,
    )


LOGGER_FACTORY = {
    "dvclive": create_dvclive_logger,
    "csv": create_csv_logger,
    "tensorboard": create_tensorboard_logger,
}


def build_logger(config: LoggerConfig) -> Logger:
    """
    Creates a logger instance from a validated Pydantic config object.
    """
    logger_type = config.logger_type

    if logger_type not in LOGGER_FACTORY:
        raise ValueError(f"Unknown logger type: {logger_type}")

    creation_func = LOGGER_FACTORY[logger_type]
    return creation_func(config)

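A minimal usage sketch of this factory: the logger_type discriminator selects which config class Pydantic validates, and build_logger dispatches on the same field. The run name below is made up for illustration:

from batdetect2.train.logging import CSVLoggerConfig, build_logger

logger = build_logger(CSVLoggerConfig(save_dir="logs", name="example-run"))
# In the YAML config the same choice is written as:
#   logger:
#     logger_type: csv
#     save_dir: logs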
src/batdetect2/train/preprocess.py  modified

@@ -27,22 +27,71 @@ from typing import Callable, Optional, Sequence
 import xarray as xr
 from loguru import logger
+from pydantic import Field
 from soundevent import data
 from tqdm.auto import tqdm
 
+from batdetect2.configs import BaseConfig, load_config
+from batdetect2.data.datasets import Dataset
+from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
 from batdetect2.preprocess.types import PreprocessorProtocol
+from batdetect2.targets import TargetConfig, build_targets
+from batdetect2.train.labels import LabelConfig, build_clip_labeler
 from batdetect2.train.types import ClipLabeller
 
 __all__ = [
     "preprocess_annotations",
     "preprocess_single_annotation",
     "generate_train_example",
+    "preprocess_dataset",
+    "TrainPreprocessConfig",
+    "load_train_preprocessing_config",
 ]
 
 FilenameFn = Callable[[data.ClipAnnotation], str]
 """Type alias for a function that generates an output filename."""
 
 
+class TrainPreprocessConfig(BaseConfig):
+    preprocess: PreprocessingConfig = Field(
+        default_factory=PreprocessingConfig
+    )
+    targets: TargetConfig = Field(default_factory=TargetConfig)
+    labels: LabelConfig = Field(default_factory=LabelConfig)
+
+
+def load_train_preprocessing_config(
+    path: data.PathLike,
+    field: Optional[str] = None,
+) -> TrainPreprocessConfig:
+    return load_config(path=path, schema=TrainPreprocessConfig, field=field)
+
+
+def preprocess_dataset(
+    dataset: Dataset,
+    config: TrainPreprocessConfig,
+    output: Path,
+    force: bool = False,
+    max_workers: Optional[int] = None,
+) -> None:
+    targets = build_targets(config=config.targets)
+    preprocessor = build_preprocessor(config=config.preprocess)
+    labeller = build_clip_labeler(targets, config=config.labels)
+
+    if not output.exists():
+        logger.debug("Creating directory {directory}", directory=output)
+        output.mkdir(parents=True)
+
+    preprocess_annotations(
+        dataset,
+        output_dir=output,
+        preprocessor=preprocessor,
+        labeller=labeller,
+        replace=force,
+        max_workers=max_workers,
+    )
+
+
 def generate_train_example(
     clip_annotation: data.ClipAnnotation,
     preprocessor: PreprocessorProtocol,
@@ -207,8 +256,17 @@ def preprocess_annotations(
     output_dir = Path(output_dir)
 
     if not output_dir.is_dir():
+        logger.info(
+            "Creating output directory: {output_dir}", output_dir=output_dir
+        )
         output_dir.mkdir(parents=True)
 
+    logger.info(
+        "Starting preprocessing of {num_annotations} annotations with {max_workers} workers.",
+        num_annotations=len(clip_annotations),
+        max_workers=max_workers or "all available",
+    )
+
     with Pool(max_workers) as pool:
         list(
             tqdm(
@@ -224,8 +282,10 @@ def preprocess_annotations(
                     clip_annotations,
                 ),
                 total=len(clip_annotations),
+                desc="Preprocessing annotations",
             )
         )
+    logger.info("Finished preprocessing.")
 
 
 def preprocess_single_annotation(
@@ -264,11 +324,15 @@ def preprocess_single_annotation(
     path = output_dir / filename
 
     if path.is_file() and not replace:
+        logger.debug("Skipping existing file: {path}", path=path)
         return
 
     if path.is_file() and replace:
+        logger.debug("Removing existing file: {path}", path=path)
         path.unlink()
 
+    logger.debug("Processing annotation {uuid}", uuid=clip_annotation.uuid)
+
     try:
         sample = generate_train_example(
             clip_annotation,
@@ -277,8 +341,9 @@ def preprocess_single_annotation(
         )
     except Exception as error:
         logger.error(
-            "Failed to process annotation: {uuid}. Error {error}",
+            "Failed to process annotation {uuid} to {path}. Error: {error}",
             uuid=clip_annotation.uuid,
+            path=path,
             error=error,
         )
         return

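A sketch of the new single-config preprocessing entry point added above; the paths mirror the just recipes, and the keyword name passed to the dataset loader is an assumption since its signature is not shown in this diff. The preprocess, targets, and labels sections of example_data/config.yaml are validated together into one TrainPreprocessConfig:

from pathlib import Path

from batdetect2.data import load_dataset_from_config
from batdetect2.train.preprocess import (
    load_train_preprocessing_config,
    preprocess_dataset,
)

conf = load_train_preprocessing_config("example_data/config.yaml")
# The field kwarg below is assumed to select datasets.train, as the CLI's
# --dataset-field option does.
dataset = load_dataset_from_config(
    "example_data/datasets.yaml", field="datasets.train"
)
preprocess_dataset(dataset, conf, output=Path("example_data/preprocessed"))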
src/batdetect2/train/train.py  modified

@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from typing import List, Optional
 
 from lightning import Trainer
@@ -5,109 +6,74 @@ from lightning.pytorch.callbacks import Callback
 from soundevent import data
 from torch.utils.data import DataLoader
 
-from batdetect2.models.types import DetectionModel
-from batdetect2.postprocess import build_postprocessor
-from batdetect2.postprocess.types import PostprocessorProtocol
-from batdetect2.preprocess import build_preprocessor
-from batdetect2.preprocess.types import PreprocessorProtocol
-from batdetect2.targets import build_targets
-from batdetect2.targets.types import TargetProtocol
-from batdetect2.train.augmentations import (
-    build_augmentations,
+from batdetect2.evaluate.metrics import (
+    ClassificationAccuracy,
+    ClassificationMeanAveragePrecision,
+    DetectionAveragePrecision,
 )
+from batdetect2.preprocess import (
+    PreprocessorProtocol,
+)
+from batdetect2.targets import TargetProtocol
+from batdetect2.train.augmentations import build_augmentations
+from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
-from batdetect2.train.config import TrainingConfig
+from batdetect2.train.config import (
+    FullTrainingConfig,
+    PLTrainerConfig,
+    TrainingConfig,
+)
 from batdetect2.train.dataset import (
     LabeledDataset,
     RandomExampleSource,
     collate_fn,
 )
 from batdetect2.train.lightning import TrainingModule
-from batdetect2.train.losses import build_loss
+from batdetect2.train.logging import build_logger
 
 __all__ = [
-    "train",
-    "build_val_dataset",
     "build_train_dataset",
+    "build_train_loader",
+    "build_trainer",
+    "build_val_dataset",
+    "build_val_loader",
+    "train",
 ]
 
 
 def train(
-    detector: DetectionModel,
-    train_examples: List[data.PathLike],
-    targets: Optional[TargetProtocol] = None,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    postprocessor: Optional[PostprocessorProtocol] = None,
-    val_examples: Optional[List[data.PathLike]] = None,
-    config: Optional[TrainingConfig] = None,
-    callbacks: Optional[List[Callback]] = None,
+    train_examples: Sequence[data.PathLike],
+    val_examples: Optional[Sequence[data.PathLike]] = None,
+    config: Optional[FullTrainingConfig] = None,
     model_path: Optional[data.PathLike] = None,
     train_workers: int = 0,
    val_workers: int = 0,
-    **trainer_kwargs,
-) -> None:
-    config = config or TrainingConfig()
-
-    if model_path is None:
-        if preprocessor is None:
-            preprocessor = build_preprocessor()
-
-        if targets is None:
-            targets = build_targets()
-
-        if postprocessor is None:
-            postprocessor = build_postprocessor(
-                targets,
-                min_freq=preprocessor.min_freq,
-                max_freq=preprocessor.max_freq,
-            )
-
-        loss = build_loss(config.loss)
-
-        module = TrainingModule(
-            detector=detector,
-            loss=loss,
-            targets=targets,
-            preprocessor=preprocessor,
-            postprocessor=postprocessor,
-            learning_rate=config.optimizer.learning_rate,
-            t_max=config.optimizer.t_max,
-        )
-    else:
+):
+    conf = config or FullTrainingConfig()
+
+    if model_path is not None:
         module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
+    else:
+        module = TrainingModule(conf)
 
-    train_dataset = build_train_dataset(
+    trainer = build_trainer(conf, targets=module.targets)
+
+    train_dataloader = build_train_loader(
         train_examples,
         preprocessor=module.preprocessor,
-        config=config,
-    )
-
-    trainer = Trainer(
-        **config.trainer.model_dump(exclude_none=True),
-        callbacks=callbacks,
-        **trainer_kwargs,
-    )
-
-    train_dataloader = DataLoader(
-        train_dataset,
-        batch_size=config.batch_size,
-        shuffle=True,
+        config=conf.train,
         num_workers=train_workers,
-        collate_fn=collate_fn,
     )
 
-    val_dataloader = None
-    if val_examples:
-        val_dataset = build_val_dataset(
+    val_dataloader = (
+        build_val_loader(
             val_examples,
-            config=config,
-        )
-
-        val_dataloader = DataLoader(
-            val_dataset,
-            batch_size=config.batch_size,
-            shuffle=False,
+            config=conf.train,
             num_workers=val_workers,
-            collate_fn=collate_fn,
         )
+        if val_examples is not None
+        else None
+    )
 
     trainer.fit(
         module,
@@ -116,8 +82,76 @@ def train(
     )
 
 
+def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
+    return [
+        ValidationMetrics(
+            metrics=[
+                DetectionAveragePrecision(),
+                ClassificationMeanAveragePrecision(
+                    class_names=targets.class_names
+                ),
+                ClassificationAccuracy(class_names=targets.class_names),
+            ]
+        )
+    ]
+
+
+def build_trainer(
+    conf: FullTrainingConfig,
+    targets: TargetProtocol,
+) -> Trainer:
+    trainer_conf = PLTrainerConfig.model_validate(
+        conf.train.model_dump(mode="python")
+    )
+    return Trainer(
+        **trainer_conf.model_dump(exclude_none=True),
+        val_check_interval=conf.train.val_check_interval,
+        logger=build_logger(conf.train.logger),
+        callbacks=build_trainer_callbacks(targets),
+    )
+
+
+def build_train_loader(
+    train_examples: Sequence[data.PathLike],
+    preprocessor: PreprocessorProtocol,
+    config: TrainingConfig,
+    num_workers: Optional[int] = None,
+) -> DataLoader:
+    train_dataset = build_train_dataset(
+        train_examples,
+        preprocessor=preprocessor,
+        config=config,
+    )
+    return DataLoader(
+        train_dataset,
+        batch_size=config.batch_size,
+        shuffle=True,
+        num_workers=num_workers or 0,
+        collate_fn=collate_fn,
+    )
+
+
+def build_val_loader(
+    val_examples: Sequence[data.PathLike],
+    config: TrainingConfig,
+    num_workers: Optional[int] = None,
+):
+    val_dataset = build_val_dataset(
+        val_examples,
+        config=config,
+    )
+    return DataLoader(
+        val_dataset,
+        batch_size=config.batch_size,
+        shuffle=False,
+        num_workers=num_workers or 0,
+        collate_fn=collate_fn,
+    )
+
+
 def build_train_dataset(
-    examples: List[data.PathLike],
+    examples: Sequence[data.PathLike],
     preprocessor: PreprocessorProtocol,
     config: Optional[TrainingConfig] = None,
 ) -> LabeledDataset:
@@ -126,7 +160,7 @@ def build_train_dataset(
     clipper = build_clipper(config.cliping, random=True)
 
     random_example_source = RandomExampleSource(
-        examples,
+        list(examples),
         clipper=clipper,
     )
@@ -144,7 +178,7 @@ def build_train_dataset(
 
 def build_val_dataset(
-    examples: List[data.PathLike],
+    examples: Sequence[data.PathLike],
     config: Optional[TrainingConfig] = None,
     train: bool = True,
 ) -> LabeledDataset:

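A sketch of the slimmed-down train() entry point after this refactor: everything that used to be passed piecemeal (model, loss, targets, pre/postprocessor, callbacks) is now derived from one FullTrainingConfig inside TrainingModule and build_trainer. The directory path below is illustrative:

from batdetect2.train import FullTrainingConfig, train
from batdetect2.train.dataset import list_preprocessed_files

conf = FullTrainingConfig()  # defaults; or load_full_training_config(...)
train(
    train_examples=list_preprocessed_files("example_data/preprocessed"),
    config=conf,
)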
[new test file]  (+10)

@@ -0,0 +1,10 @@
from batdetect2.configs import load_config
from batdetect2.train import FullTrainingConfig


def test_example_config_is_valid(example_data_dir):
    conf = load_config(
        example_data_dir / "config.yaml",
        schema=FullTrainingConfig,
    )
    assert isinstance(conf, FullTrainingConfig)

[modified test file]

@@ -5,27 +5,11 @@ import torch
 import xarray as xr
 from soundevent import data
 
-from batdetect2.models import build_model
-from batdetect2.postprocess import build_postprocessor
-from batdetect2.preprocess import build_preprocessor
-from batdetect2.targets import build_targets
-from batdetect2.train.lightning import TrainingModule
-from batdetect2.train.losses import build_loss
+from batdetect2.train import FullTrainingConfig, TrainingModule
 
 
 def build_default_module():
-    loss = build_loss()
-    targets = build_targets()
-    detector = build_model(num_classes=len(targets.class_names))
-    preprocessor = build_preprocessor()
-    postprocessor = build_postprocessor(targets)
-    return TrainingModule(
-        detector=detector,
-        loss=loss,
-        targets=targets,
-        preprocessor=preprocessor,
-        postprocessor=postprocessor,
-    )
+    return TrainingModule(FullTrainingConfig())
 
 
 def test_can_initialize_default_module():