mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-30 15:12:06 +02:00
Compare commits
No commits in common. "fdab0860fd9b9c32295baccdb963d9ad058681ea" and "4b4c3ecdf5551fba731bc9c0e00f807ddfd230a5" have entirely different histories.
fdab0860fd
...
4b4c3ecdf5
@ -1,196 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import click
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from batdetect2.cli.base import cli
|
|
||||||
from batdetect2.data import load_dataset_from_config
|
|
||||||
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
|
||||||
from batdetect2.targets import build_targets, load_target_config
|
|
||||||
from batdetect2.train import load_label_config, preprocess_annotations
|
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
|
||||||
|
|
||||||
__all__ = ["preprocess"]
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
|
||||||
@click.argument(
|
|
||||||
"dataset_config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
)
|
|
||||||
@click.argument(
|
|
||||||
"output",
|
|
||||||
type=click.Path(),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--dataset-field",
|
|
||||||
type=str,
|
|
||||||
help=(
|
|
||||||
"Specifies the key to access the dataset information within the "
|
|
||||||
"dataset configuration file, if the information is nested inside a "
|
|
||||||
"dictionary. If the dataset information is at the top level of the "
|
|
||||||
"config file, you don't need to specify this."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--base-dir",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help=(
|
|
||||||
"The main directory where your audio recordings and annotation "
|
|
||||||
"files are stored. This helps the program find your data, "
|
|
||||||
"especially if the paths in your dataset configuration file "
|
|
||||||
"are relative."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--preprocess-config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help=(
|
|
||||||
"Path to the preprocessing configuration file. This file tells "
|
|
||||||
"the program how to prepare your audio data before training, such "
|
|
||||||
"as resampling or applying filters."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--preprocess-config-field",
|
|
||||||
type=str,
|
|
||||||
help=(
|
|
||||||
"If the preprocessing settings are inside a nested dictionary "
|
|
||||||
"within the preprocessing configuration file, specify the key "
|
|
||||||
"here to access them. If the preprocessing settings are at the "
|
|
||||||
"top level, you don't need to specify this."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--label-config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help=(
|
|
||||||
"Path to the label generation configuration file. This file "
|
|
||||||
"contains settings for how to create labels from your "
|
|
||||||
"annotations, which the model uses to learn."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--label-config-field",
|
|
||||||
type=str,
|
|
||||||
help=(
|
|
||||||
"If the label generation settings are inside a nested dictionary "
|
|
||||||
"within the label configuration file, specify the key here. If "
|
|
||||||
"the settings are at the top level, leave this blank."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--target-config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help=(
|
|
||||||
"Path to the training target configuration file. This file "
|
|
||||||
"specifies what sounds the model should learn to predict."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--target-config-field",
|
|
||||||
type=str,
|
|
||||||
help=(
|
|
||||||
"If the target settings are inside a nested dictionary "
|
|
||||||
"within the target configuration file, specify the key here. "
|
|
||||||
"If the settings are at the top level, you don't need to specify this."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--force",
|
|
||||||
is_flag=True,
|
|
||||||
help=(
|
|
||||||
"If a preprocessed file already exists, this option tells the "
|
|
||||||
"program to overwrite it with the new preprocessed data. Use "
|
|
||||||
"this if you want to re-do the preprocessing even if the files "
|
|
||||||
"already exist."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--num-workers",
|
|
||||||
type=int,
|
|
||||||
help=(
|
|
||||||
"The maximum number of computer cores to use when processing "
|
|
||||||
"your audio data. Using more cores can speed up the preprocessing, "
|
|
||||||
"but don't use more than your computer has available. By default, "
|
|
||||||
"the program will use all available cores."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
def preprocess(
|
|
||||||
dataset_config: Path,
|
|
||||||
output: Path,
|
|
||||||
target_config: Optional[Path] = None,
|
|
||||||
base_dir: Optional[Path] = None,
|
|
||||||
preprocess_config: Optional[Path] = None,
|
|
||||||
label_config: Optional[Path] = None,
|
|
||||||
force: bool = False,
|
|
||||||
num_workers: Optional[int] = None,
|
|
||||||
target_config_field: Optional[str] = None,
|
|
||||||
preprocess_config_field: Optional[str] = None,
|
|
||||||
label_config_field: Optional[str] = None,
|
|
||||||
dataset_field: Optional[str] = None,
|
|
||||||
):
|
|
||||||
logger.info("Starting preprocessing.")
|
|
||||||
|
|
||||||
output = Path(output)
|
|
||||||
logger.info("Will save outputs to {output}", output=output)
|
|
||||||
|
|
||||||
base_dir = base_dir or Path.cwd()
|
|
||||||
logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
|
|
||||||
|
|
||||||
preprocess = (
|
|
||||||
load_preprocessing_config(
|
|
||||||
preprocess_config,
|
|
||||||
field=preprocess_config_field,
|
|
||||||
)
|
|
||||||
if preprocess_config
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
target = (
|
|
||||||
load_target_config(
|
|
||||||
target_config,
|
|
||||||
field=target_config_field,
|
|
||||||
)
|
|
||||||
if target_config
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
label = (
|
|
||||||
load_label_config(
|
|
||||||
label_config,
|
|
||||||
field=label_config_field,
|
|
||||||
)
|
|
||||||
if label_config
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = load_dataset_from_config(
|
|
||||||
dataset_config,
|
|
||||||
field=dataset_field,
|
|
||||||
base_dir=base_dir,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Loaded {num_examples} annotated clips from the configured dataset",
|
|
||||||
num_examples=len(dataset),
|
|
||||||
)
|
|
||||||
|
|
||||||
targets = build_targets(config=target)
|
|
||||||
preprocessor = build_preprocessor(config=preprocess)
|
|
||||||
labeller = build_clip_labeler(targets, config=label)
|
|
||||||
|
|
||||||
if not output.exists():
|
|
||||||
logger.debug("Creating directory {directory}", directory=output)
|
|
||||||
output.mkdir(parents=True)
|
|
||||||
|
|
||||||
logger.info("Will start preprocessing")
|
|
||||||
preprocess_annotations(
|
|
||||||
dataset,
|
|
||||||
output_dir=output,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
labeller=labeller,
|
|
||||||
replace=force,
|
|
||||||
max_workers=num_workers,
|
|
||||||
)
|
|
@ -5,48 +5,47 @@ import click
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
from batdetect2.evaluate.metrics import (
|
from batdetect2.data import load_dataset_from_config
|
||||||
ClassificationAccuracy,
|
|
||||||
ClassificationMeanAveragePrecision,
|
|
||||||
DetectionAveragePrecision,
|
|
||||||
)
|
|
||||||
from batdetect2.models import build_model
|
|
||||||
from batdetect2.models.backbones import load_backbone_config
|
|
||||||
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
|
|
||||||
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
||||||
from batdetect2.targets import build_targets, load_target_config
|
from batdetect2.targets import build_targets, load_target_config
|
||||||
from batdetect2.train import train
|
from batdetect2.train import load_label_config, preprocess_annotations
|
||||||
from batdetect2.train.callbacks import ValidationMetrics
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.train.config import TrainingConfig, load_train_config
|
|
||||||
from batdetect2.train.dataset import list_preprocessed_files
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = ["train"]
|
||||||
"train_command",
|
|
||||||
]
|
|
||||||
|
|
||||||
DEFAULT_CONFIG_FILE = Path("config.yaml")
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command(name="train")
|
@cli.group()
|
||||||
@click.option(
|
def train(): ...
|
||||||
"--train-examples",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
required=True,
|
@train.command()
|
||||||
)
|
@click.argument(
|
||||||
@click.option("--val-examples", type=click.Path(exists=True))
|
"dataset_config",
|
||||||
@click.option(
|
|
||||||
"--model-path",
|
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.argument(
|
||||||
"--train-config",
|
"output",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(),
|
||||||
default=DEFAULT_CONFIG_FILE,
|
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--train-config-field",
|
"--dataset-field",
|
||||||
type=str,
|
type=str,
|
||||||
default="train",
|
help=(
|
||||||
|
"Specifies the key to access the dataset information within the "
|
||||||
|
"dataset configuration file, if the information is nested inside a "
|
||||||
|
"dictionary. If the dataset information is at the top level of the "
|
||||||
|
"config file, you don't need to specify this."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--base-dir",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help=(
|
||||||
|
"The main directory where your audio recordings and annotation "
|
||||||
|
"files are stored. This helps the program find your data, "
|
||||||
|
"especially if the paths in your dataset configuration file "
|
||||||
|
"are relative."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--preprocess-config",
|
"--preprocess-config",
|
||||||
@ -56,7 +55,6 @@ DEFAULT_CONFIG_FILE = Path("config.yaml")
|
|||||||
"the program how to prepare your audio data before training, such "
|
"the program how to prepare your audio data before training, such "
|
||||||
"as resampling or applying filters."
|
"as resampling or applying filters."
|
||||||
),
|
),
|
||||||
default=DEFAULT_CONFIG_FILE,
|
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--preprocess-config-field",
|
"--preprocess-config-field",
|
||||||
@ -67,7 +65,24 @@ DEFAULT_CONFIG_FILE = Path("config.yaml")
|
|||||||
"here to access them. If the preprocessing settings are at the "
|
"here to access them. If the preprocessing settings are at the "
|
||||||
"top level, you don't need to specify this."
|
"top level, you don't need to specify this."
|
||||||
),
|
),
|
||||||
default="preprocess",
|
)
|
||||||
|
@click.option(
|
||||||
|
"--label-config",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help=(
|
||||||
|
"Path to the label generation configuration file. This file "
|
||||||
|
"contains settings for how to create labels from your "
|
||||||
|
"annotations, which the model uses to learn."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--label-config-field",
|
||||||
|
type=str,
|
||||||
|
help=(
|
||||||
|
"If the label generation settings are inside a nested dictionary "
|
||||||
|
"within the label configuration file, specify the key here. If "
|
||||||
|
"the settings are at the top level, leave this blank."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--target-config",
|
"--target-config",
|
||||||
@ -76,7 +91,6 @@ DEFAULT_CONFIG_FILE = Path("config.yaml")
|
|||||||
"Path to the training target configuration file. This file "
|
"Path to the training target configuration file. This file "
|
||||||
"specifies what sounds the model should learn to predict."
|
"specifies what sounds the model should learn to predict."
|
||||||
),
|
),
|
||||||
default=DEFAULT_CONFIG_FILE,
|
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--target-config-field",
|
"--target-config-field",
|
||||||
@ -86,156 +100,101 @@ DEFAULT_CONFIG_FILE = Path("config.yaml")
|
|||||||
"within the target configuration file, specify the key here. "
|
"within the target configuration file, specify the key here. "
|
||||||
"If the settings are at the top level, you don't need to specify this."
|
"If the settings are at the top level, you don't need to specify this."
|
||||||
),
|
),
|
||||||
default="targets",
|
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--postprocess-config",
|
"--force",
|
||||||
type=click.Path(exists=True),
|
is_flag=True,
|
||||||
default=DEFAULT_CONFIG_FILE,
|
help=(
|
||||||
|
"If a preprocessed file already exists, this option tells the "
|
||||||
|
"program to overwrite it with the new preprocessed data. Use "
|
||||||
|
"this if you want to re-do the preprocessing even if the files "
|
||||||
|
"already exist."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--postprocess-config-field",
|
"--num-workers",
|
||||||
type=str,
|
|
||||||
default="postprocess",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--model-config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
default=DEFAULT_CONFIG_FILE,
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--model-config-field",
|
|
||||||
type=str,
|
|
||||||
default="model",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--train-workers",
|
|
||||||
type=int,
|
type=int,
|
||||||
default=0,
|
help=(
|
||||||
|
"The maximum number of computer cores to use when processing "
|
||||||
|
"your audio data. Using more cores can speed up the preprocessing, "
|
||||||
|
"but don't use more than your computer has available. By default, "
|
||||||
|
"the program will use all available cores."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
def preprocess(
|
||||||
"--val-workers",
|
dataset_config: Path,
|
||||||
type=int,
|
output: Path,
|
||||||
default=0,
|
target_config: Optional[Path] = None,
|
||||||
)
|
base_dir: Optional[Path] = None,
|
||||||
def train_command(
|
preprocess_config: Optional[Path] = None,
|
||||||
train_examples: Path,
|
label_config: Optional[Path] = None,
|
||||||
val_examples: Optional[Path] = None,
|
force: bool = False,
|
||||||
model_path: Optional[Path] = None,
|
num_workers: Optional[int] = None,
|
||||||
train_config: Path = DEFAULT_CONFIG_FILE,
|
target_config_field: Optional[str] = None,
|
||||||
train_config_field: str = "train",
|
preprocess_config_field: Optional[str] = None,
|
||||||
preprocess_config: Path = DEFAULT_CONFIG_FILE,
|
label_config_field: Optional[str] = None,
|
||||||
preprocess_config_field: str = "preprocess",
|
dataset_field: Optional[str] = None,
|
||||||
target_config: Path = DEFAULT_CONFIG_FILE,
|
|
||||||
target_config_field: str = "targets",
|
|
||||||
postprocess_config: Path = DEFAULT_CONFIG_FILE,
|
|
||||||
postprocess_config_field: str = "postprocess",
|
|
||||||
model_config: Path = DEFAULT_CONFIG_FILE,
|
|
||||||
model_config_field: str = "model",
|
|
||||||
train_workers: int = 0,
|
|
||||||
val_workers: int = 0,
|
|
||||||
):
|
):
|
||||||
logger.info("Starting training!")
|
logger.info("Starting preprocessing.")
|
||||||
|
|
||||||
try:
|
output = Path(output)
|
||||||
target_config_loaded = load_target_config(
|
logger.info("Will save outputs to {output}", output=output)
|
||||||
path=target_config,
|
|
||||||
field=target_config_field,
|
|
||||||
)
|
|
||||||
targets = build_targets(config=target_config_loaded)
|
|
||||||
logger.debug(
|
|
||||||
"Loaded targets info from config file {path}", path=target_config
|
|
||||||
)
|
|
||||||
except IOError:
|
|
||||||
logger.debug(
|
|
||||||
"Could not load target info from config file, using default"
|
|
||||||
)
|
|
||||||
targets = build_targets()
|
|
||||||
|
|
||||||
try:
|
base_dir = base_dir or Path.cwd()
|
||||||
preprocess_config_loaded = load_preprocessing_config(
|
logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
|
||||||
path=preprocess_config,
|
|
||||||
|
preprocess = (
|
||||||
|
load_preprocessing_config(
|
||||||
|
preprocess_config,
|
||||||
field=preprocess_config_field,
|
field=preprocess_config_field,
|
||||||
)
|
)
|
||||||
preprocessor = build_preprocessor(preprocess_config_loaded)
|
if preprocess_config
|
||||||
logger.debug(
|
else None
|
||||||
"Loaded preprocessor from config file {path}", path=target_config
|
|
||||||
)
|
|
||||||
|
|
||||||
except IOError:
|
|
||||||
logger.debug(
|
|
||||||
"Could not load preprocessor from config file, using default"
|
|
||||||
)
|
|
||||||
preprocessor = build_preprocessor()
|
|
||||||
|
|
||||||
try:
|
|
||||||
model_config_loaded = load_backbone_config(
|
|
||||||
path=model_config, field=model_config_field
|
|
||||||
)
|
|
||||||
model = build_model(
|
|
||||||
num_classes=len(targets.class_names),
|
|
||||||
config=model_config_loaded,
|
|
||||||
)
|
|
||||||
except IOError:
|
|
||||||
model = build_model(num_classes=len(targets.class_names))
|
|
||||||
|
|
||||||
try:
|
|
||||||
postprocess_config_loaded = load_postprocess_config(
|
|
||||||
path=postprocess_config,
|
|
||||||
field=postprocess_config_field,
|
|
||||||
)
|
|
||||||
postprocessor = build_postprocessor(
|
|
||||||
targets=targets,
|
|
||||||
config=postprocess_config_loaded,
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
"Loaded postprocessor from file {path}",
|
|
||||||
path=train_config,
|
|
||||||
)
|
|
||||||
except IOError:
|
|
||||||
logger.debug(
|
|
||||||
"Could not load postprocessor config from file. Using default"
|
|
||||||
)
|
|
||||||
postprocessor = build_postprocessor(targets=targets)
|
|
||||||
|
|
||||||
try:
|
|
||||||
train_config_loaded = load_train_config(
|
|
||||||
path=train_config, field=train_config_field
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
"Loaded training config from file {path}",
|
|
||||||
path=train_config,
|
|
||||||
)
|
|
||||||
except IOError:
|
|
||||||
train_config_loaded = TrainingConfig()
|
|
||||||
logger.debug("Could not load training config from file. Using default")
|
|
||||||
|
|
||||||
train_files = list_preprocessed_files(train_examples)
|
|
||||||
|
|
||||||
val_files = (
|
|
||||||
None if val_examples is None else list_preprocessed_files(val_examples)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return train(
|
target = (
|
||||||
detector=model,
|
load_target_config(
|
||||||
train_examples=train_files, # type: ignore
|
target_config,
|
||||||
val_examples=val_files, # type: ignore
|
field=target_config_field,
|
||||||
model_path=model_path,
|
)
|
||||||
|
if target_config
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
label = (
|
||||||
|
load_label_config(
|
||||||
|
label_config,
|
||||||
|
field=label_config_field,
|
||||||
|
)
|
||||||
|
if label_config
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = load_dataset_from_config(
|
||||||
|
dataset_config,
|
||||||
|
field=dataset_field,
|
||||||
|
base_dir=base_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Loaded {num_examples} annotated clips from the configured dataset",
|
||||||
|
num_examples=len(dataset),
|
||||||
|
)
|
||||||
|
|
||||||
|
targets = build_targets(config=target)
|
||||||
|
preprocessor = build_preprocessor(config=preprocess)
|
||||||
|
labeller = build_clip_labeler(targets, config=label)
|
||||||
|
|
||||||
|
if not output.exists():
|
||||||
|
logger.debug("Creating directory {directory}", directory=output)
|
||||||
|
output.mkdir(parents=True)
|
||||||
|
|
||||||
|
logger.info("Will start preprocessing")
|
||||||
|
preprocess_annotations(
|
||||||
|
dataset,
|
||||||
|
output_dir=output,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
postprocessor=postprocessor,
|
labeller=labeller,
|
||||||
targets=targets,
|
replace=force,
|
||||||
config=train_config_loaded,
|
max_workers=num_workers,
|
||||||
callbacks=[
|
|
||||||
ValidationMetrics(
|
|
||||||
metrics=[
|
|
||||||
DetectionAveragePrecision(),
|
|
||||||
ClassificationMeanAveragePrecision(
|
|
||||||
class_names=targets.class_names,
|
|
||||||
),
|
|
||||||
ClassificationAccuracy(class_names=targets.class_names),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
],
|
|
||||||
train_workers=train_workers,
|
|
||||||
val_workers=val_workers,
|
|
||||||
)
|
)
|
||||||
|
@ -88,7 +88,7 @@ def get_object_field(obj: dict, current_key: str) -> Any:
|
|||||||
KeyError: 'x'
|
KeyError: 'x'
|
||||||
"""
|
"""
|
||||||
if "." not in current_key:
|
if "." not in current_key:
|
||||||
return obj.get(current_key, {})
|
return obj[current_key]
|
||||||
|
|
||||||
current_key, rest = current_key.split(".", 1)
|
current_key, rest = current_key.split(".", 1)
|
||||||
subobj = obj[current_key]
|
subobj = obj[current_key]
|
||||||
|
@ -5,14 +5,19 @@ import uuid
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
from soundevent.geometry import compute_bounds
|
||||||
|
from soundevent.types import ClassMapper
|
||||||
|
|
||||||
from batdetect2.targets import get_term_from_key
|
from batdetect2 import types
|
||||||
|
|
||||||
PathLike = Union[Path, str, os.PathLike]
|
PathLike = Union[Path, str, os.PathLike]
|
||||||
|
|
||||||
__all__ = []
|
__all__ = [
|
||||||
|
"convert_to_annotation_group",
|
||||||
|
]
|
||||||
|
|
||||||
SPECIES_TAG_KEY = "species"
|
SPECIES_TAG_KEY = "species"
|
||||||
ECHOLOCATION_EVENT = "Echolocation"
|
ECHOLOCATION_EVENT = "Echolocation"
|
||||||
@ -28,6 +33,104 @@ ClassFn = Callable[[data.Recording], int]
|
|||||||
IndividualFn = Callable[[data.SoundEventAnnotation], int]
|
IndividualFn = Callable[[data.SoundEventAnnotation], int]
|
||||||
|
|
||||||
|
|
||||||
|
def get_recording_class_name(recording: data.Recording) -> str:
|
||||||
|
"""Get the class name for a recording."""
|
||||||
|
tag = data.find_tag(recording.tags, SPECIES_TAG_KEY)
|
||||||
|
if tag is None:
|
||||||
|
return UNKNOWN_CLASS
|
||||||
|
return tag.value
|
||||||
|
|
||||||
|
|
||||||
|
def get_annotation_notes(annotation: data.ClipAnnotation) -> str:
|
||||||
|
"""Get the notes for a ClipAnnotation."""
|
||||||
|
all_notes = [
|
||||||
|
*annotation.notes,
|
||||||
|
*annotation.clip.recording.notes,
|
||||||
|
]
|
||||||
|
messages = [note.message for note in all_notes if note.message is not None]
|
||||||
|
return "\n".join(messages)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_annotation_group(
|
||||||
|
annotation: data.ClipAnnotation,
|
||||||
|
class_mapper: ClassMapper,
|
||||||
|
event_fn: EventFn = lambda _: ECHOLOCATION_EVENT,
|
||||||
|
class_fn: ClassFn = lambda _: 0,
|
||||||
|
individual_fn: IndividualFn = lambda _: 0,
|
||||||
|
) -> types.AudioLoaderAnnotationGroup:
|
||||||
|
"""Convert a ClipAnnotation to an AudioLoaderAnnotationGroup."""
|
||||||
|
recording = annotation.clip.recording
|
||||||
|
|
||||||
|
start_times = []
|
||||||
|
end_times = []
|
||||||
|
low_freqs = []
|
||||||
|
high_freqs = []
|
||||||
|
class_ids = []
|
||||||
|
x_inds = []
|
||||||
|
y_inds = []
|
||||||
|
individual_ids = []
|
||||||
|
annotations: List[types.Annotation] = []
|
||||||
|
class_id_file = class_fn(recording)
|
||||||
|
|
||||||
|
for sound_event in annotation.sound_events:
|
||||||
|
geometry = sound_event.sound_event.geometry
|
||||||
|
|
||||||
|
if geometry is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
|
||||||
|
class_id = class_mapper.transform(sound_event) or -1
|
||||||
|
event = event_fn(sound_event) or ""
|
||||||
|
individual_id = individual_fn(sound_event) or -1
|
||||||
|
|
||||||
|
start_times.append(start_time)
|
||||||
|
end_times.append(end_time)
|
||||||
|
low_freqs.append(low_freq)
|
||||||
|
high_freqs.append(high_freq)
|
||||||
|
class_ids.append(class_id)
|
||||||
|
individual_ids.append(individual_id)
|
||||||
|
|
||||||
|
# NOTE: This will be computed later so we just put a placeholder
|
||||||
|
# here for now.
|
||||||
|
x_inds.append(0)
|
||||||
|
y_inds.append(0)
|
||||||
|
|
||||||
|
annotations.append(
|
||||||
|
{
|
||||||
|
"start_time": start_time,
|
||||||
|
"end_time": end_time,
|
||||||
|
"low_freq": low_freq,
|
||||||
|
"high_freq": high_freq,
|
||||||
|
"class_prob": 1.0,
|
||||||
|
"det_prob": 1.0,
|
||||||
|
"individual": "0",
|
||||||
|
"event": event,
|
||||||
|
"class_id": class_id, # type: ignore
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": str(recording.path),
|
||||||
|
"duration": recording.duration,
|
||||||
|
"issues": False,
|
||||||
|
"file_path": str(recording.path),
|
||||||
|
"time_exp": recording.time_expansion,
|
||||||
|
"class_name": get_recording_class_name(recording),
|
||||||
|
"notes": get_annotation_notes(annotation),
|
||||||
|
"annotated": True,
|
||||||
|
"start_times": np.array(start_times),
|
||||||
|
"end_times": np.array(end_times),
|
||||||
|
"low_freqs": np.array(low_freqs),
|
||||||
|
"high_freqs": np.array(high_freqs),
|
||||||
|
"class_ids": np.array(class_ids),
|
||||||
|
"x_inds": np.array(x_inds),
|
||||||
|
"y_inds": np.array(y_inds),
|
||||||
|
"individual_ids": np.array(individual_ids),
|
||||||
|
"annotation": annotations,
|
||||||
|
"class_id_file": class_id_file,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class Annotation(BaseModel):
|
class Annotation(BaseModel):
|
||||||
"""Annotation class to hold batdetect annotations."""
|
"""Annotation class to hold batdetect annotations."""
|
||||||
|
|
||||||
@ -92,15 +195,15 @@ def annotation_to_sound_event(
|
|||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(
|
data.Tag(
|
||||||
term=get_term_from_key(label_key),
|
term=data.term_from_key(label_key),
|
||||||
value=annotation.label,
|
value=annotation.label,
|
||||||
),
|
),
|
||||||
data.Tag(
|
data.Tag(
|
||||||
term=get_term_from_key(event_key),
|
term=data.term_from_key(event_key),
|
||||||
value=annotation.event,
|
value=annotation.event,
|
||||||
),
|
),
|
||||||
data.Tag(
|
data.Tag(
|
||||||
term=get_term_from_key(individual_key),
|
term=data.term_from_key(individual_key),
|
||||||
value=str(annotation.individual),
|
value=str(annotation.individual),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -125,7 +228,7 @@ def file_annotation_to_clip(
|
|||||||
time_expansion=file_annotation.time_exp,
|
time_expansion=file_annotation.time_exp,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(
|
data.Tag(
|
||||||
term=get_term_from_key(label_key),
|
term=data.term_from_key(label_key),
|
||||||
value=file_annotation.label,
|
value=file_annotation.label,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
@ -157,7 +260,7 @@ def file_annotation_to_clip_annotation(
|
|||||||
notes=notes,
|
notes=notes,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(
|
data.Tag(
|
||||||
term=get_term_from_key(label_key), value=file_annotation.label
|
term=data.term_from_key(label_key), value=file_annotation.label
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
sound_events=[
|
sound_events=[
|
||||||
|
@ -1,13 +1,9 @@
|
|||||||
from batdetect2.evaluate.evaluate import (
|
from batdetect2.evaluate.evaluate import (
|
||||||
compute_error_auc,
|
compute_error_auc,
|
||||||
)
|
|
||||||
from batdetect2.evaluate.match import (
|
|
||||||
match_predictions_and_annotations,
|
match_predictions_and_annotations,
|
||||||
match_sound_events_and_raw_predictions,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"compute_error_auc",
|
"compute_error_auc",
|
||||||
"match_predictions_and_annotations",
|
"match_predictions_and_annotations",
|
||||||
"match_sound_events_and_raw_predictions",
|
|
||||||
]
|
]
|
||||||
|
@ -1,6 +1,51 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.metrics import auc, roc_curve
|
from sklearn.metrics import auc, roc_curve
|
||||||
|
from soundevent import data
|
||||||
|
from soundevent.evaluation import match_geometries
|
||||||
|
|
||||||
|
|
||||||
|
def match_predictions_and_annotations(
|
||||||
|
clip_annotation: data.ClipAnnotation,
|
||||||
|
clip_prediction: data.ClipPrediction,
|
||||||
|
) -> List[data.Match]:
|
||||||
|
annotated_sound_events = [
|
||||||
|
sound_event_annotation
|
||||||
|
for sound_event_annotation in clip_annotation.sound_events
|
||||||
|
if sound_event_annotation.sound_event.geometry is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
predicted_sound_events = [
|
||||||
|
sound_event_prediction
|
||||||
|
for sound_event_prediction in clip_prediction.sound_events
|
||||||
|
if sound_event_prediction.sound_event.geometry is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
annotated_geometries: List[data.Geometry] = [
|
||||||
|
sound_event.sound_event.geometry
|
||||||
|
for sound_event in annotated_sound_events
|
||||||
|
if sound_event.sound_event.geometry is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
predicted_geometries: List[data.Geometry] = [
|
||||||
|
sound_event.sound_event.geometry
|
||||||
|
for sound_event in predicted_sound_events
|
||||||
|
if sound_event.sound_event.geometry is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
matches = []
|
||||||
|
for id1, id2, affinity in match_geometries(
|
||||||
|
annotated_geometries,
|
||||||
|
predicted_geometries,
|
||||||
|
):
|
||||||
|
target = annotated_sound_events[id1] if id1 is not None else None
|
||||||
|
source = predicted_sound_events[id2] if id2 is not None else None
|
||||||
|
matches.append(
|
||||||
|
data.Match(source=source, target=target, affinity=affinity)
|
||||||
|
)
|
||||||
|
|
||||||
|
return matches
|
||||||
|
|
||||||
|
|
||||||
def compute_error_auc(op_str, gt, pred, prob):
|
def compute_error_auc(op_str, gt, pred, prob):
|
||||||
|
@ -1,111 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
from soundevent import data
|
|
||||||
from soundevent.evaluation import match_geometries
|
|
||||||
|
|
||||||
from batdetect2.evaluate.types import Match
|
|
||||||
from batdetect2.postprocess.types import RawPrediction
|
|
||||||
from batdetect2.targets.types import TargetProtocol
|
|
||||||
from batdetect2.utils.arrays import iterate_over_array
|
|
||||||
|
|
||||||
|
|
||||||
def match_sound_events_and_raw_predictions(
|
|
||||||
sound_events: List[data.SoundEventAnnotation],
|
|
||||||
raw_predictions: List[RawPrediction],
|
|
||||||
targets: TargetProtocol,
|
|
||||||
) -> List[Match]:
|
|
||||||
target_sound_events = [
|
|
||||||
targets.transform(sound_event_annotation)
|
|
||||||
for sound_event_annotation in sound_events
|
|
||||||
if targets.filter(sound_event_annotation)
|
|
||||||
and sound_event_annotation.sound_event.geometry is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
target_geometries: List[data.Geometry] = [ # type: ignore
|
|
||||||
sound_event_annotation.sound_event.geometry
|
|
||||||
for sound_event_annotation in target_sound_events
|
|
||||||
]
|
|
||||||
|
|
||||||
predicted_geometries = [
|
|
||||||
raw_prediction.geometry for raw_prediction in raw_predictions
|
|
||||||
]
|
|
||||||
|
|
||||||
matches = []
|
|
||||||
for id1, id2, affinity in match_geometries(
|
|
||||||
target_geometries,
|
|
||||||
predicted_geometries,
|
|
||||||
):
|
|
||||||
target = target_sound_events[id1] if id1 is not None else None
|
|
||||||
prediction = raw_predictions[id2] if id2 is not None else None
|
|
||||||
|
|
||||||
gt_uuid = target.uuid if target is not None else None
|
|
||||||
gt_det = target is not None
|
|
||||||
gt_class = targets.encode(target) if target is not None else None
|
|
||||||
|
|
||||||
pred_score = float(prediction.detection_score) if prediction else 0
|
|
||||||
|
|
||||||
class_scores = (
|
|
||||||
{
|
|
||||||
str(class_name): float(score)
|
|
||||||
for class_name, score in iterate_over_array(
|
|
||||||
prediction.class_scores
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if prediction is not None
|
|
||||||
else {}
|
|
||||||
)
|
|
||||||
|
|
||||||
matches.append(
|
|
||||||
Match(
|
|
||||||
gt_uuid=gt_uuid,
|
|
||||||
gt_det=gt_det,
|
|
||||||
gt_class=gt_class,
|
|
||||||
pred_score=pred_score,
|
|
||||||
affinity=affinity,
|
|
||||||
class_scores=class_scores,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return matches
|
|
||||||
|
|
||||||
|
|
||||||
def match_predictions_and_annotations(
|
|
||||||
clip_annotation: data.ClipAnnotation,
|
|
||||||
clip_prediction: data.ClipPrediction,
|
|
||||||
) -> List[data.Match]:
|
|
||||||
annotated_sound_events = [
|
|
||||||
sound_event_annotation
|
|
||||||
for sound_event_annotation in clip_annotation.sound_events
|
|
||||||
if sound_event_annotation.sound_event.geometry is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
predicted_sound_events = [
|
|
||||||
sound_event_prediction
|
|
||||||
for sound_event_prediction in clip_prediction.sound_events
|
|
||||||
if sound_event_prediction.sound_event.geometry is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
annotated_geometries: List[data.Geometry] = [
|
|
||||||
sound_event.sound_event.geometry
|
|
||||||
for sound_event in annotated_sound_events
|
|
||||||
if sound_event.sound_event.geometry is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
predicted_geometries: List[data.Geometry] = [
|
|
||||||
sound_event.sound_event.geometry
|
|
||||||
for sound_event in predicted_sound_events
|
|
||||||
if sound_event.sound_event.geometry is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
matches = []
|
|
||||||
for id1, id2, affinity in match_geometries(
|
|
||||||
annotated_geometries,
|
|
||||||
predicted_geometries,
|
|
||||||
):
|
|
||||||
target = annotated_sound_events[id1] if id1 is not None else None
|
|
||||||
source = predicted_sound_events[id2] if id2 is not None else None
|
|
||||||
matches.append(
|
|
||||||
data.Match(source=source, target=target, affinity=affinity)
|
|
||||||
)
|
|
||||||
|
|
||||||
return matches
|
|
@ -1,97 +0,0 @@
|
|||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
from sklearn import metrics
|
|
||||||
from sklearn.preprocessing import label_binarize
|
|
||||||
|
|
||||||
from batdetect2.evaluate.types import Match, MetricsProtocol
|
|
||||||
|
|
||||||
__all__ = ["DetectionAveragePrecision"]
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionAveragePrecision(MetricsProtocol):
|
|
||||||
def __call__(self, matches: List[Match]) -> Dict[str, float]:
|
|
||||||
y_true, y_score = zip(
|
|
||||||
*[(match.gt_det, match.pred_score) for match in matches]
|
|
||||||
)
|
|
||||||
score = float(metrics.average_precision_score(y_true, y_score))
|
|
||||||
return {"detection_AP": score}
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationMeanAveragePrecision(MetricsProtocol):
|
|
||||||
def __init__(self, class_names: List[str], per_class: bool = True):
|
|
||||||
self.class_names = class_names
|
|
||||||
self.per_class = per_class
|
|
||||||
|
|
||||||
def __call__(self, matches: List[Match]) -> Dict[str, float]:
|
|
||||||
y_true = label_binarize(
|
|
||||||
[
|
|
||||||
match.gt_class if match.gt_class is not None else "__NONE__"
|
|
||||||
for match in matches
|
|
||||||
],
|
|
||||||
classes=self.class_names,
|
|
||||||
)
|
|
||||||
y_pred = pd.DataFrame(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
name: match.class_scores.get(name, 0)
|
|
||||||
for name in self.class_names
|
|
||||||
}
|
|
||||||
for match in matches
|
|
||||||
]
|
|
||||||
).fillna(0)
|
|
||||||
mAP = metrics.average_precision_score(y_true, y_pred[self.class_names])
|
|
||||||
|
|
||||||
ret = {
|
|
||||||
"classification_mAP": float(mAP),
|
|
||||||
}
|
|
||||||
|
|
||||||
if not self.per_class:
|
|
||||||
return ret
|
|
||||||
|
|
||||||
for class_index, class_name in enumerate(self.class_names):
|
|
||||||
y_true_class = y_true[:, class_index]
|
|
||||||
y_pred_class = y_pred[class_name]
|
|
||||||
class_ap = metrics.average_precision_score(
|
|
||||||
y_true_class,
|
|
||||||
y_pred_class,
|
|
||||||
)
|
|
||||||
ret[f"classification_AP/{class_name}"] = float(class_ap)
|
|
||||||
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationAccuracy(MetricsProtocol):
|
|
||||||
def __init__(self, class_names: List[str]):
|
|
||||||
self.class_names = class_names
|
|
||||||
|
|
||||||
def __call__(self, matches: List[Match]) -> Dict[str, float]:
|
|
||||||
y_true = [
|
|
||||||
match.gt_class if match.gt_class is not None else "__NONE__"
|
|
||||||
for match in matches
|
|
||||||
]
|
|
||||||
|
|
||||||
y_pred = pd.DataFrame(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
name: match.class_scores.get(name, 0)
|
|
||||||
for name in self.class_names
|
|
||||||
}
|
|
||||||
for match in matches
|
|
||||||
]
|
|
||||||
).fillna(0)
|
|
||||||
y_pred = y_pred.apply(
|
|
||||||
lambda row: row.idxmax()
|
|
||||||
if row.max() >= (1 - row.sum())
|
|
||||||
else "__NONE__",
|
|
||||||
axis=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
accuracy = metrics.balanced_accuracy_score(
|
|
||||||
y_true,
|
|
||||||
y_pred,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"classification_acc": float(accuracy),
|
|
||||||
}
|
|
@ -1,22 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from typing import Dict, List, Optional, Protocol
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"MetricsProtocol",
|
|
||||||
"Match",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Match:
|
|
||||||
gt_uuid: Optional[UUID]
|
|
||||||
gt_det: bool
|
|
||||||
gt_class: Optional[str]
|
|
||||||
pred_score: float
|
|
||||||
affinity: float
|
|
||||||
class_scores: Dict[str, float]
|
|
||||||
|
|
||||||
|
|
||||||
class MetricsProtocol(Protocol):
|
|
||||||
def __call__(self, matches: List[Match]) -> Dict[str, float]: ...
|
|
@ -170,8 +170,8 @@ def load_postprocess_config(
|
|||||||
def build_postprocessor(
|
def build_postprocessor(
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
config: Optional[PostprocessConfig] = None,
|
config: Optional[PostprocessConfig] = None,
|
||||||
max_freq: float = MAX_FREQ,
|
max_freq: int = MAX_FREQ,
|
||||||
min_freq: float = MIN_FREQ,
|
min_freq: int = MIN_FREQ,
|
||||||
) -> PostprocessorProtocol:
|
) -> PostprocessorProtocol:
|
||||||
"""Factory function to build the standard postprocessor.
|
"""Factory function to build the standard postprocessor.
|
||||||
|
|
||||||
@ -234,9 +234,9 @@ class Postprocessor(PostprocessorProtocol):
|
|||||||
recovery.
|
recovery.
|
||||||
config : PostprocessConfig
|
config : PostprocessConfig
|
||||||
Configuration object holding parameters for NMS, thresholds, etc.
|
Configuration object holding parameters for NMS, thresholds, etc.
|
||||||
min_freq : float
|
min_freq : int
|
||||||
Minimum frequency (Hz) assumed for the model output's frequency axis.
|
Minimum frequency (Hz) assumed for the model output's frequency axis.
|
||||||
max_freq : float
|
max_freq : int
|
||||||
Maximum frequency (Hz) assumed for the model output's frequency axis.
|
Maximum frequency (Hz) assumed for the model output's frequency axis.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -246,8 +246,8 @@ class Postprocessor(PostprocessorProtocol):
|
|||||||
self,
|
self,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
config: PostprocessConfig,
|
config: PostprocessConfig,
|
||||||
min_freq: float = MIN_FREQ,
|
min_freq: int = MIN_FREQ,
|
||||||
max_freq: float = MAX_FREQ,
|
max_freq: int = MAX_FREQ,
|
||||||
):
|
):
|
||||||
"""Initialize the Postprocessor.
|
"""Initialize the Postprocessor.
|
||||||
|
|
||||||
|
@ -32,10 +32,10 @@ from typing import List, Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.postprocess.types import GeometryBuilder, RawPrediction
|
from batdetect2.postprocess.types import GeometryBuilder, RawPrediction
|
||||||
from batdetect2.targets.classes import SoundEventDecoder
|
from batdetect2.targets.classes import SoundEventDecoder
|
||||||
from batdetect2.utils.arrays import iterate_over_array
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"convert_xr_dataset_to_raw_prediction",
|
"convert_xr_dataset_to_raw_prediction",
|
||||||
@ -97,14 +97,18 @@ def convert_xr_dataset_to_raw_prediction(
|
|||||||
det_info = detection_dataset.sel(detection=det_num)
|
det_info = detection_dataset.sel(detection=det_num)
|
||||||
|
|
||||||
geom = geometry_builder(
|
geom = geometry_builder(
|
||||||
(det_info.time, det_info.frequency),
|
(det_info.time, det_info.freq),
|
||||||
det_info.dimensions,
|
det_info.dimensions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
start_time, low_freq, end_time, high_freq = compute_bounds(geom)
|
||||||
detections.append(
|
detections.append(
|
||||||
RawPrediction(
|
RawPrediction(
|
||||||
detection_score=det_info.scores,
|
detection_score=det_info.score,
|
||||||
geometry=geom,
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
low_freq=low_freq,
|
||||||
|
high_freq=high_freq,
|
||||||
class_scores=det_info.classes,
|
class_scores=det_info.classes,
|
||||||
features=det_info.features,
|
features=det_info.features,
|
||||||
)
|
)
|
||||||
@ -240,7 +244,14 @@ def convert_raw_prediction_to_sound_event_prediction(
|
|||||||
"""
|
"""
|
||||||
sound_event = data.SoundEvent(
|
sound_event = data.SoundEvent(
|
||||||
recording=recording,
|
recording=recording,
|
||||||
geometry=raw_prediction.geometry,
|
geometry=data.BoundingBox(
|
||||||
|
coordinates=[
|
||||||
|
raw_prediction.start_time,
|
||||||
|
raw_prediction.low_freq,
|
||||||
|
raw_prediction.end_time,
|
||||||
|
raw_prediction.high_freq,
|
||||||
|
]
|
||||||
|
),
|
||||||
features=get_prediction_features(raw_prediction.features),
|
features=get_prediction_features(raw_prediction.features),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -322,7 +333,7 @@ def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
|
|||||||
),
|
),
|
||||||
value=value,
|
value=value,
|
||||||
)
|
)
|
||||||
for feat_name, value in iterate_over_array(features)
|
for feat_name, value in _iterate_over_array(features)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -383,6 +394,13 @@ def get_class_tags(
|
|||||||
return tags
|
return tags
|
||||||
|
|
||||||
|
|
||||||
|
def _iterate_over_array(array: xr.DataArray):
|
||||||
|
dim_name = array.dims[0]
|
||||||
|
coords = array.coords[dim_name]
|
||||||
|
for value, coord in zip(array.values, coords.values):
|
||||||
|
yield coord, float(value)
|
||||||
|
|
||||||
|
|
||||||
def _iterate_sorted(array: xr.DataArray):
|
def _iterate_sorted(array: xr.DataArray):
|
||||||
dim_name = array.dims[0]
|
dim_name = array.dims[0]
|
||||||
coords = array.coords[dim_name].values
|
coords = array.coords[dim_name].values
|
||||||
|
@ -47,9 +47,14 @@ class RawPrediction(NamedTuple):
|
|||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
geometry: data.Geometry
|
start_time : float
|
||||||
The recovered estimated geometry of the detected sound event.
|
Start time of the recovered bounding box in seconds.
|
||||||
Usually a bounding box.
|
end_time : float
|
||||||
|
End time of the recovered bounding box in seconds.
|
||||||
|
low_freq : float
|
||||||
|
Lowest frequency of the recovered bounding box in Hz.
|
||||||
|
high_freq : float
|
||||||
|
Highest frequency of the recovered bounding box in Hz.
|
||||||
detection_score : float
|
detection_score : float
|
||||||
The confidence score associated with this detection, typically from
|
The confidence score associated with this detection, typically from
|
||||||
the detection heatmap peak.
|
the detection heatmap peak.
|
||||||
@ -62,7 +67,10 @@ class RawPrediction(NamedTuple):
|
|||||||
detection location. Indexed by a 'feature' coordinate.
|
detection location. Indexed by a 'feature' coordinate.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
geometry: data.Geometry
|
start_time: float
|
||||||
|
end_time: float
|
||||||
|
low_freq: float
|
||||||
|
high_freq: float
|
||||||
detection_score: float
|
detection_score: float
|
||||||
class_scores: xr.DataArray
|
class_scores: xr.DataArray
|
||||||
features: xr.DataArray
|
features: xr.DataArray
|
||||||
|
@ -24,7 +24,6 @@ object is via the `build_targets` or `load_targets` functions.
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field
|
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
@ -158,9 +157,7 @@ class TargetConfig(BaseConfig):
|
|||||||
|
|
||||||
filtering: Optional[FilterConfig] = None
|
filtering: Optional[FilterConfig] = None
|
||||||
transforms: Optional[TransformConfig] = None
|
transforms: Optional[TransformConfig] = None
|
||||||
classes: ClassesConfig = Field(
|
classes: ClassesConfig
|
||||||
default_factory=lambda: DEFAULT_CLASSES_CONFIG
|
|
||||||
)
|
|
||||||
roi: Optional[ROIConfig] = None
|
roi: Optional[ROIConfig] = None
|
||||||
|
|
||||||
|
|
||||||
@ -441,84 +438,6 @@ class Targets(TargetProtocol):
|
|||||||
return self._roi_mapper.recover_roi(pos, dims)
|
return self._roi_mapper.recover_roi(pos, dims)
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CLASSES = [
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Myotis mystacinus")],
|
|
||||||
name="myomys",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Myotis alcathoe")],
|
|
||||||
name="myoalc",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Eptesicus serotinus")],
|
|
||||||
name="eptser",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Pipistrellus nathusii")],
|
|
||||||
name="pipnat",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Barbastellus barbastellus")],
|
|
||||||
name="barbar",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Myotis nattereri")],
|
|
||||||
name="myonat",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Myotis daubentonii")],
|
|
||||||
name="myodau",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Myotis brandtii")],
|
|
||||||
name="myobra",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Pipistrellus pipistrellus")],
|
|
||||||
name="pippip",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Myotis bechsteinii")],
|
|
||||||
name="myobec",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Pipistrellus pygmaeus")],
|
|
||||||
name="pippyg",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Rhinolophus hipposideros")],
|
|
||||||
name="rhihip",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Nyctalus leisleri")],
|
|
||||||
name="nyclei",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
|
|
||||||
name="rhifer",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Plecotus auritus")],
|
|
||||||
name="pleaur",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Nyctalus noctula")],
|
|
||||||
name="nycnoc",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Plecotus austriacus")],
|
|
||||||
name="pleaus",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CLASSES_CONFIG: ClassesConfig = ClassesConfig(
|
|
||||||
classes=DEFAULT_CLASSES,
|
|
||||||
generic_class=[TagInfo(value="Bat")],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
||||||
filtering=FilterConfig(
|
filtering=FilterConfig(
|
||||||
rules=[
|
rules=[
|
||||||
@ -536,7 +455,79 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
classes=DEFAULT_CLASSES_CONFIG,
|
classes=ClassesConfig(
|
||||||
|
classes=[
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Myotis mystacinus")],
|
||||||
|
name="myomys",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Myotis alcathoe")],
|
||||||
|
name="myoalc",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Eptesicus serotinus")],
|
||||||
|
name="eptser",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Pipistrellus nathusii")],
|
||||||
|
name="pipnat",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Barbastellus barbastellus")],
|
||||||
|
name="barbar",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Myotis nattereri")],
|
||||||
|
name="myonat",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Myotis daubentonii")],
|
||||||
|
name="myodau",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Myotis brandtii")],
|
||||||
|
name="myobra",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Pipistrellus pipistrellus")],
|
||||||
|
name="pippip",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Myotis bechsteinii")],
|
||||||
|
name="myobec",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Pipistrellus pygmaeus")],
|
||||||
|
name="pippyg",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Rhinolophus hipposideros")],
|
||||||
|
name="rhihip",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Nyctalus leisleri")],
|
||||||
|
name="nyclei",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
|
||||||
|
name="rhifer",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Plecotus auritus")],
|
||||||
|
name="pleaur",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Nyctalus noctula")],
|
||||||
|
name="nycnoc",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Plecotus austriacus")],
|
||||||
|
name="pleaus",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
generic_class=[TagInfo(value="Bat")],
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -106,7 +106,7 @@ def contains_tags(
|
|||||||
False otherwise.
|
False otherwise.
|
||||||
"""
|
"""
|
||||||
sound_event_tags = set(sound_event_annotation.tags)
|
sound_event_tags = set(sound_event_annotation.tags)
|
||||||
return tags <= sound_event_tags
|
return tags < sound_event_tags
|
||||||
|
|
||||||
|
|
||||||
def does_not_have_tags(
|
def does_not_have_tags(
|
||||||
|
@ -20,27 +20,14 @@ scaling factors) is managed by the `ROIConfig`. This module separates the
|
|||||||
handled in `batdetect2.targets.classes`.
|
handled in `batdetect2.targets.classes`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Literal, Optional, Protocol, Tuple
|
from typing import List, Optional, Protocol, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from soundevent import data
|
from soundevent import data, geometry
|
||||||
|
from soundevent.geometry.operations import Positions
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
|
||||||
Positions = Literal[
|
|
||||||
"bottom-left",
|
|
||||||
"bottom-right",
|
|
||||||
"top-left",
|
|
||||||
"top-right",
|
|
||||||
"center-left",
|
|
||||||
"center-right",
|
|
||||||
"top-center",
|
|
||||||
"bottom-center",
|
|
||||||
"center",
|
|
||||||
"centroid",
|
|
||||||
"point_on_surface",
|
|
||||||
]
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ROITargetMapper",
|
"ROITargetMapper",
|
||||||
"ROIConfig",
|
"ROIConfig",
|
||||||
@ -255,8 +242,6 @@ class BBoxEncoder(ROITargetMapper):
|
|||||||
Tuple[float, float]
|
Tuple[float, float]
|
||||||
Reference position (time, frequency).
|
Reference position (time, frequency).
|
||||||
"""
|
"""
|
||||||
from soundevent import geometry
|
|
||||||
|
|
||||||
return geometry.get_geometry_point(geom, position=self.position)
|
return geometry.get_geometry_point(geom, position=self.position)
|
||||||
|
|
||||||
def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
|
def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
|
||||||
@ -275,8 +260,6 @@ class BBoxEncoder(ROITargetMapper):
|
|||||||
np.ndarray
|
np.ndarray
|
||||||
A 1D NumPy array: `[scaled_width, scaled_height]`.
|
A 1D NumPy array: `[scaled_width, scaled_height]`.
|
||||||
"""
|
"""
|
||||||
from soundevent import geometry
|
|
||||||
|
|
||||||
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
|
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
|
||||||
geom
|
geom
|
||||||
)
|
)
|
||||||
@ -325,8 +308,8 @@ class BBoxEncoder(ROITargetMapper):
|
|||||||
width, height = dims
|
width, height = dims
|
||||||
return _build_bounding_box(
|
return _build_bounding_box(
|
||||||
pos,
|
pos,
|
||||||
duration=float(width) / self.time_scale,
|
duration=width / self.time_scale,
|
||||||
bandwidth=float(height) / self.frequency_scale,
|
bandwidth=height / self.frequency_scale,
|
||||||
position=self.position,
|
position=self.position,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -438,16 +421,14 @@ def _build_bounding_box(
|
|||||||
ValueError
|
ValueError
|
||||||
If `position` is not a recognized value or format.
|
If `position` is not a recognized value or format.
|
||||||
"""
|
"""
|
||||||
time, freq = map(float, pos)
|
time, freq = pos
|
||||||
duration = max(0, duration)
|
|
||||||
bandwidth = max(0, bandwidth)
|
|
||||||
if position in ["center", "centroid", "point_on_surface"]:
|
if position in ["center", "centroid", "point_on_surface"]:
|
||||||
return data.BoundingBox(
|
return data.BoundingBox(
|
||||||
coordinates=[
|
coordinates=[
|
||||||
max(time - duration / 2, 0),
|
time - duration / 2,
|
||||||
max(freq - bandwidth / 2, 0),
|
freq - bandwidth / 2,
|
||||||
max(time + duration / 2, 0),
|
time + duration / 2,
|
||||||
max(freq + bandwidth / 2, 0),
|
freq + bandwidth / 2,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -473,9 +454,9 @@ def _build_bounding_box(
|
|||||||
|
|
||||||
return data.BoundingBox(
|
return data.BoundingBox(
|
||||||
coordinates=[
|
coordinates=[
|
||||||
max(0, start_time),
|
start_time,
|
||||||
max(0, low_freq),
|
low_freq,
|
||||||
max(0, start_time + duration),
|
start_time + duration,
|
||||||
max(0, low_freq + bandwidth),
|
low_freq + bandwidth,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -14,47 +14,28 @@ from batdetect2.train.augmentations import (
|
|||||||
warp_spectrogram,
|
warp_spectrogram,
|
||||||
)
|
)
|
||||||
from batdetect2.train.clips import build_clipper, select_subclip
|
from batdetect2.train.clips import build_clipper, select_subclip
|
||||||
from batdetect2.train.config import (
|
from batdetect2.train.config import TrainingConfig, load_train_config
|
||||||
TrainerConfig,
|
|
||||||
TrainingConfig,
|
|
||||||
load_train_config,
|
|
||||||
)
|
|
||||||
from batdetect2.train.dataset import (
|
from batdetect2.train.dataset import (
|
||||||
LabeledDataset,
|
LabeledDataset,
|
||||||
RandomExampleSource,
|
RandomExampleSource,
|
||||||
TrainExample,
|
TrainExample,
|
||||||
list_preprocessed_files,
|
list_preprocessed_files,
|
||||||
)
|
)
|
||||||
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
from batdetect2.train.labels import load_label_config
|
||||||
from batdetect2.train.losses import (
|
from batdetect2.train.losses import LossFunction, build_loss
|
||||||
ClassificationLossConfig,
|
|
||||||
DetectionLossConfig,
|
|
||||||
LossConfig,
|
|
||||||
LossFunction,
|
|
||||||
SizeLossConfig,
|
|
||||||
build_loss,
|
|
||||||
)
|
|
||||||
from batdetect2.train.preprocess import (
|
from batdetect2.train.preprocess import (
|
||||||
generate_train_example,
|
generate_train_example,
|
||||||
preprocess_annotations,
|
preprocess_annotations,
|
||||||
)
|
)
|
||||||
from batdetect2.train.train import (
|
from batdetect2.train.train import TrainerConfig, load_trainer_config, train
|
||||||
build_train_dataset,
|
|
||||||
build_val_dataset,
|
|
||||||
train,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AugmentationsConfig",
|
"AugmentationsConfig",
|
||||||
"ClassificationLossConfig",
|
|
||||||
"DetectionLossConfig",
|
|
||||||
"EchoAugmentationConfig",
|
"EchoAugmentationConfig",
|
||||||
"FrequencyMaskAugmentationConfig",
|
"FrequencyMaskAugmentationConfig",
|
||||||
"LabeledDataset",
|
"LabeledDataset",
|
||||||
"LossConfig",
|
|
||||||
"LossFunction",
|
"LossFunction",
|
||||||
"RandomExampleSource",
|
"RandomExampleSource",
|
||||||
"SizeLossConfig",
|
|
||||||
"TimeMaskAugmentationConfig",
|
"TimeMaskAugmentationConfig",
|
||||||
"TrainExample",
|
"TrainExample",
|
||||||
"TrainerConfig",
|
"TrainerConfig",
|
||||||
@ -63,15 +44,13 @@ __all__ = [
|
|||||||
"WarpAugmentationConfig",
|
"WarpAugmentationConfig",
|
||||||
"add_echo",
|
"add_echo",
|
||||||
"build_augmentations",
|
"build_augmentations",
|
||||||
"build_clip_labeler",
|
|
||||||
"build_clipper",
|
"build_clipper",
|
||||||
"build_loss",
|
"build_loss",
|
||||||
"build_train_dataset",
|
|
||||||
"build_val_dataset",
|
|
||||||
"generate_train_example",
|
"generate_train_example",
|
||||||
"list_preprocessed_files",
|
"list_preprocessed_files",
|
||||||
"load_label_config",
|
"load_label_config",
|
||||||
"load_train_config",
|
"load_train_config",
|
||||||
|
"load_trainer_config",
|
||||||
"mask_frequency",
|
"mask_frequency",
|
||||||
"mask_time",
|
"mask_time",
|
||||||
"mix_examples",
|
"mix_examples",
|
||||||
@ -79,6 +58,5 @@ __all__ = [
|
|||||||
"scale_volume",
|
"scale_volume",
|
||||||
"select_subclip",
|
"select_subclip",
|
||||||
"train",
|
"train",
|
||||||
"train",
|
|
||||||
"warp_spectrogram",
|
"warp_spectrogram",
|
||||||
]
|
]
|
||||||
|
@ -1,52 +1,30 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
from lightning import LightningModule, Trainer
|
from lightning import LightningModule, Trainer
|
||||||
from lightning.pytorch.callbacks import Callback
|
from lightning.pytorch.callbacks import Callback
|
||||||
from soundevent import data
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.evaluate.match import match_sound_events_and_raw_predictions
|
from batdetect2.postprocess import PostprocessorProtocol
|
||||||
from batdetect2.evaluate.types import Match, MetricsProtocol
|
|
||||||
from batdetect2.targets.types import TargetProtocol
|
|
||||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.types import ModelOutput
|
||||||
from batdetect2.train.types import ModelOutput
|
|
||||||
|
|
||||||
|
|
||||||
class ValidationMetrics(Callback):
|
class ValidationMetrics(Callback):
|
||||||
def __init__(self, metrics: List[MetricsProtocol]):
|
def __init__(self, postprocessor: PostprocessorProtocol):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.postprocessor = postprocessor
|
||||||
if len(metrics) == 0:
|
self.predictions = []
|
||||||
raise ValueError("At least one metric needs to be provided")
|
|
||||||
|
|
||||||
self.matches: List[Match] = []
|
|
||||||
self.metrics = metrics
|
|
||||||
|
|
||||||
def on_validation_epoch_end(
|
|
||||||
self,
|
|
||||||
trainer: Trainer,
|
|
||||||
pl_module: LightningModule,
|
|
||||||
) -> None:
|
|
||||||
metrics = {}
|
|
||||||
for metric in self.metrics:
|
|
||||||
metrics.update(metric(self.matches).items())
|
|
||||||
|
|
||||||
pl_module.log_dict(metrics)
|
|
||||||
return super().on_validation_epoch_end(trainer, pl_module)
|
|
||||||
|
|
||||||
def on_validation_epoch_start(
|
def on_validation_epoch_start(
|
||||||
self,
|
self,
|
||||||
trainer: Trainer,
|
trainer: Trainer,
|
||||||
pl_module: LightningModule,
|
pl_module: LightningModule,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.matches = []
|
self.predictions = []
|
||||||
return super().on_validation_epoch_start(trainer, pl_module)
|
return super().on_validation_epoch_start(trainer, pl_module)
|
||||||
|
|
||||||
def on_validation_batch_end( # type: ignore
|
def on_validation_batch_end( # type: ignore
|
||||||
self,
|
self,
|
||||||
trainer: Trainer,
|
trainer: Trainer,
|
||||||
pl_module: TrainingModule,
|
pl_module: LightningModule,
|
||||||
outputs: ModelOutput,
|
outputs: ModelOutput,
|
||||||
batch: TrainExample,
|
batch: TrainExample,
|
||||||
batch_idx: int,
|
batch_idx: int,
|
||||||
@ -54,73 +32,24 @@ class ValidationMetrics(Callback):
|
|||||||
) -> None:
|
) -> None:
|
||||||
dataloaders = trainer.val_dataloaders
|
dataloaders = trainer.val_dataloaders
|
||||||
assert isinstance(dataloaders, DataLoader)
|
assert isinstance(dataloaders, DataLoader)
|
||||||
|
|
||||||
dataset = dataloaders.dataset
|
dataset = dataloaders.dataset
|
||||||
assert isinstance(dataset, LabeledDataset)
|
assert isinstance(dataset, LabeledDataset)
|
||||||
|
clip_annotation = dataset.get_clip_annotation(batch_idx)
|
||||||
|
|
||||||
clip_annotations = [
|
# clip_prediction = postprocess_model_outputs(
|
||||||
_get_subclip(
|
# outputs,
|
||||||
dataset.get_clip_annotation(example_id),
|
# clips=[clip_annotation.clip],
|
||||||
start_time=start_time.item(),
|
# classes=self.class_names,
|
||||||
end_time=end_time.item(),
|
# decoder=self.decoder,
|
||||||
targets=pl_module.targets,
|
# config=self.config.postprocessing,
|
||||||
)
|
# )[0]
|
||||||
for example_id, start_time, end_time in zip(
|
#
|
||||||
batch.idx,
|
# matches = match_predictions_and_annotations(
|
||||||
batch.start_time,
|
# clip_annotation,
|
||||||
batch.end_time,
|
# clip_prediction,
|
||||||
)
|
# )
|
||||||
]
|
#
|
||||||
|
# self.validation_predictions.extend(matches)
|
||||||
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
|
# return super().on_validation_batch_end(
|
||||||
|
# trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
|
||||||
raw_predictions = pl_module.postprocessor.get_raw_predictions(
|
# )
|
||||||
outputs,
|
|
||||||
clips,
|
|
||||||
)
|
|
||||||
|
|
||||||
for clip_annotation, clip_predictions in zip(
|
|
||||||
clip_annotations, raw_predictions
|
|
||||||
):
|
|
||||||
self.matches.extend(
|
|
||||||
match_sound_events_and_raw_predictions(
|
|
||||||
sound_events=clip_annotation.sound_events,
|
|
||||||
raw_predictions=clip_predictions,
|
|
||||||
targets=pl_module.targets,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_in_subclip(
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
targets: TargetProtocol,
|
|
||||||
start_time: float,
|
|
||||||
end_time: float,
|
|
||||||
) -> bool:
|
|
||||||
time, _ = targets.get_position(sound_event_annotation)
|
|
||||||
return start_time <= time <= end_time
|
|
||||||
|
|
||||||
|
|
||||||
def _get_subclip(
|
|
||||||
clip_annotation: data.ClipAnnotation,
|
|
||||||
start_time: float,
|
|
||||||
end_time: float,
|
|
||||||
targets: TargetProtocol,
|
|
||||||
) -> data.ClipAnnotation:
|
|
||||||
return data.ClipAnnotation(
|
|
||||||
clip=data.Clip(
|
|
||||||
recording=clip_annotation.clip.recording,
|
|
||||||
start_time=start_time,
|
|
||||||
end_time=end_time,
|
|
||||||
),
|
|
||||||
sound_events=[
|
|
||||||
sound_event_annotation
|
|
||||||
for sound_event_annotation in clip_annotation.sound_events
|
|
||||||
if _is_in_subclip(
|
|
||||||
sound_event_annotation,
|
|
||||||
targets,
|
|
||||||
start_time=start_time,
|
|
||||||
end_time=end_time,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
@ -69,15 +69,12 @@ class Clipper(ClipperProtocol):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_clipper(
|
def build_clipper(config: Optional[ClipingConfig] = None) -> ClipperProtocol:
|
||||||
config: Optional[ClipingConfig] = None,
|
|
||||||
random: Optional[bool] = None,
|
|
||||||
) -> ClipperProtocol:
|
|
||||||
config = config or ClipingConfig()
|
config = config or ClipingConfig()
|
||||||
return Clipper(
|
return Clipper(
|
||||||
duration=config.duration,
|
duration=config.duration,
|
||||||
max_empty=config.max_empty,
|
max_empty=config.max_empty,
|
||||||
random=config.random if random else False,
|
random=config.random,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from typing import Optional, Union
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
@ -23,29 +23,8 @@ class OptimizerConfig(BaseConfig):
|
|||||||
t_max: int = 100
|
t_max: int = 100
|
||||||
|
|
||||||
|
|
||||||
class TrainerConfig(BaseConfig):
|
|
||||||
accelerator: str = "auto"
|
|
||||||
accumulate_grad_batches: int = 1
|
|
||||||
deterministic: bool = True
|
|
||||||
check_val_every_n_epoch: int = 1
|
|
||||||
devices: Union[str, int] = "auto"
|
|
||||||
enable_checkpointing: bool = True
|
|
||||||
gradient_clip_val: Optional[float] = None
|
|
||||||
limit_train_batches: Optional[Union[int, float]] = None
|
|
||||||
limit_test_batches: Optional[Union[int, float]] = None
|
|
||||||
limit_val_batches: Optional[Union[int, float]] = None
|
|
||||||
log_every_n_steps: Optional[int] = None
|
|
||||||
max_epochs: Optional[int] = 200
|
|
||||||
min_epochs: Optional[int] = None
|
|
||||||
max_steps: Optional[int] = None
|
|
||||||
min_steps: Optional[int] = None
|
|
||||||
max_time: Optional[str] = None
|
|
||||||
precision: Optional[str] = None
|
|
||||||
val_check_interval: Optional[Union[int, float]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class TrainingConfig(BaseConfig):
|
class TrainingConfig(BaseConfig):
|
||||||
batch_size: int = 8
|
batch_size: int = 32
|
||||||
|
|
||||||
loss: LossConfig = Field(default_factory=LossConfig)
|
loss: LossConfig = Field(default_factory=LossConfig)
|
||||||
|
|
||||||
@ -57,11 +36,9 @@ class TrainingConfig(BaseConfig):
|
|||||||
|
|
||||||
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
||||||
|
|
||||||
trainer: TrainerConfig = Field(default_factory=TrainerConfig)
|
|
||||||
|
|
||||||
|
|
||||||
def load_train_config(
|
def load_train_config(
|
||||||
path: data.PathLike,
|
path: PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
) -> TrainingConfig:
|
) -> TrainingConfig:
|
||||||
return load_config(path, schema=TrainingConfig, field=field)
|
return load_config(path, schema=TrainingConfig, field=field)
|
||||||
|
@ -42,8 +42,8 @@ class LabeledDataset(Dataset):
|
|||||||
class_heatmap=self.to_tensor(dataset["class"]),
|
class_heatmap=self.to_tensor(dataset["class"]),
|
||||||
size_heatmap=self.to_tensor(dataset["size"]),
|
size_heatmap=self.to_tensor(dataset["size"]),
|
||||||
idx=torch.tensor(idx),
|
idx=torch.tensor(idx),
|
||||||
start_time=torch.tensor(start_time),
|
start_time=start_time,
|
||||||
end_time=torch.tensor(end_time),
|
end_time=end_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -81,24 +81,17 @@ class LabeledDataset(Dataset):
|
|||||||
array: xr.DataArray,
|
array: xr.DataArray,
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return torch.nan_to_num(
|
return torch.tensor(array.values.astype(dtype))
|
||||||
torch.tensor(array.values.astype(dtype)),
|
|
||||||
nan=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def list_preprocessed_files(
|
def list_preprocessed_files(
|
||||||
directory: data.PathLike, extension: str = ".nc"
|
directory: data.PathLike, extension: str = ".nc"
|
||||||
) -> List[Path]:
|
) -> Sequence[Path]:
|
||||||
return list(Path(directory).glob(f"*{extension}"))
|
return list(Path(directory).glob(f"*{extension}"))
|
||||||
|
|
||||||
|
|
||||||
class RandomExampleSource:
|
class RandomExampleSource:
|
||||||
def __init__(
|
def __init__(self, filenames: List[str], clipper: ClipperProtocol):
|
||||||
self,
|
|
||||||
filenames: List[data.PathLike],
|
|
||||||
clipper: ClipperProtocol,
|
|
||||||
):
|
|
||||||
self.filenames = filenames
|
self.filenames = filenames
|
||||||
self.clipper = clipper
|
self.clipper = clipper
|
||||||
|
|
||||||
|
@ -23,13 +23,13 @@ parameter specific to this module is the Gaussian smoothing sigma (`sigma`)
|
|||||||
defined in `LabelConfig`.
|
defined in `LabelConfig`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from loguru import logger
|
|
||||||
from scipy.ndimage import gaussian_filter
|
from scipy.ndimage import gaussian_filter
|
||||||
from soundevent import arrays, data
|
from soundevent import arrays, data
|
||||||
|
|
||||||
@ -52,6 +52,8 @@ __all__ = [
|
|||||||
SIZE_DIMENSION = "dimension"
|
SIZE_DIMENSION = "dimension"
|
||||||
"""Dimension name for the size heatmap."""
|
"""Dimension name for the size heatmap."""
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LabelConfig(BaseConfig):
|
class LabelConfig(BaseConfig):
|
||||||
"""Configuration parameters for heatmap generation.
|
"""Configuration parameters for heatmap generation.
|
||||||
@ -135,27 +137,12 @@ def generate_clip_label(
|
|||||||
A NamedTuple containing the generated 'detection', 'classes', and 'size'
|
A NamedTuple containing the generated 'detection', 'classes', and 'size'
|
||||||
heatmaps for this clip.
|
heatmaps for this clip.
|
||||||
"""
|
"""
|
||||||
logger.debug(
|
|
||||||
"Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events",
|
|
||||||
uuid=clip_annotation.uuid,
|
|
||||||
num=len(clip_annotation.sound_events)
|
|
||||||
)
|
|
||||||
|
|
||||||
sound_events = []
|
|
||||||
|
|
||||||
for sound_event_annotation in clip_annotation.sound_events:
|
|
||||||
if not targets.filter(sound_event_annotation):
|
|
||||||
logger.debug(
|
|
||||||
"Sound event {sound_event} did not pass the filter. Tags: {tags}",
|
|
||||||
sound_event=sound_event_annotation,
|
|
||||||
tags=sound_event_annotation.tags,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
sound_events.append(targets.transform(sound_event_annotation))
|
|
||||||
|
|
||||||
return generate_heatmaps(
|
return generate_heatmaps(
|
||||||
sound_events,
|
(
|
||||||
|
targets.transform(sound_event_annotation)
|
||||||
|
for sound_event_annotation in clip_annotation.sound_events
|
||||||
|
if targets.filter(sound_event_annotation)
|
||||||
|
),
|
||||||
spec=spec,
|
spec=spec,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
target_sigma=config.sigma,
|
target_sigma=config.sigma,
|
||||||
|
@ -40,9 +40,7 @@ class TrainingModule(L.LightningModule):
|
|||||||
self.learning_rate = learning_rate
|
self.learning_rate = learning_rate
|
||||||
self.t_max = t_max
|
self.t_max = t_max
|
||||||
|
|
||||||
# NOTE: Ignore detector and loss from hyperparameter saving
|
self.save_hyperparameters()
|
||||||
# as they are nn.Module and should be saved regardless.
|
|
||||||
self.save_hyperparameters(ignore=["detector", "loss"])
|
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||||
return self.detector(spec)
|
return self.detector(spec)
|
||||||
@ -51,25 +49,21 @@ class TrainingModule(L.LightningModule):
|
|||||||
outputs = self.forward(batch.spec)
|
outputs = self.forward(batch.spec)
|
||||||
losses = self.loss(outputs, batch)
|
losses = self.loss(outputs, batch)
|
||||||
|
|
||||||
self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
|
self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
|
||||||
self.log("detection_loss/train", losses.total, logger=True)
|
self.log("train/loss/detection", losses.total, logger=True)
|
||||||
self.log("size_loss/train", losses.total, logger=True)
|
self.log("train/loss/size", losses.total, logger=True)
|
||||||
self.log("classification_loss/train", losses.total, logger=True)
|
self.log("train/loss/classification", losses.total, logger=True)
|
||||||
|
|
||||||
return losses.total
|
return losses.total
|
||||||
|
|
||||||
def validation_step( # type: ignore
|
def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
|
||||||
self, batch: TrainExample, batch_idx: int
|
|
||||||
) -> ModelOutput:
|
|
||||||
outputs = self.forward(batch.spec)
|
outputs = self.forward(batch.spec)
|
||||||
losses = self.loss(outputs, batch)
|
losses = self.loss(outputs, batch)
|
||||||
|
|
||||||
self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
|
self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
|
||||||
self.log("detection_loss/val", losses.total, logger=True)
|
self.log("val/loss/detection", losses.total, logger=True)
|
||||||
self.log("size_loss/val", losses.total, logger=True)
|
self.log("val/loss/size", losses.total, logger=True)
|
||||||
self.log("classification_loss/val", losses.total, logger=True)
|
self.log("val/loss/classification", losses.total, logger=True)
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = Adam(self.parameters(), lr=self.learning_rate)
|
optimizer = Adam(self.parameters(), lr=self.learning_rate)
|
||||||
|
@ -1,147 +1,68 @@
|
|||||||
from typing import List, Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
from lightning import Trainer
|
from lightning import LightningModule
|
||||||
from lightning.pytorch.callbacks import Callback
|
from lightning.pytorch import Trainer
|
||||||
from soundevent import data
|
from soundevent.data import PathLike
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.models.types import DetectionModel
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.postprocess import build_postprocessor
|
from batdetect2.train.dataset import LabeledDataset
|
||||||
from batdetect2.postprocess.types import PostprocessorProtocol
|
|
||||||
from batdetect2.preprocess import build_preprocessor
|
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
|
||||||
from batdetect2.targets import build_targets
|
|
||||||
from batdetect2.targets.types import TargetProtocol
|
|
||||||
from batdetect2.train.augmentations import (
|
|
||||||
build_augmentations,
|
|
||||||
)
|
|
||||||
from batdetect2.train.clips import build_clipper
|
|
||||||
from batdetect2.train.config import TrainingConfig
|
|
||||||
from batdetect2.train.dataset import LabeledDataset, RandomExampleSource
|
|
||||||
from batdetect2.train.lightning import TrainingModule
|
|
||||||
from batdetect2.train.losses import build_loss
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"train",
|
"train",
|
||||||
"build_val_dataset",
|
"TrainerConfig",
|
||||||
"build_train_dataset",
|
"load_trainer_config",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TrainerConfig(BaseConfig):
|
||||||
|
accelerator: str = "auto"
|
||||||
|
accumulate_grad_batches: int = 1
|
||||||
|
deterministic: bool = True
|
||||||
|
check_val_every_n_epoch: int = 1
|
||||||
|
devices: Union[str, int] = "auto"
|
||||||
|
enable_checkpointing: bool = True
|
||||||
|
gradient_clip_val: Optional[float] = None
|
||||||
|
limit_train_batches: Optional[Union[int, float]] = None
|
||||||
|
limit_test_batches: Optional[Union[int, float]] = None
|
||||||
|
limit_val_batches: Optional[Union[int, float]] = None
|
||||||
|
log_every_n_steps: Optional[int] = None
|
||||||
|
max_epochs: Optional[int] = None
|
||||||
|
min_epochs: Optional[int] = 100
|
||||||
|
max_steps: Optional[int] = None
|
||||||
|
min_steps: Optional[int] = None
|
||||||
|
max_time: Optional[str] = None
|
||||||
|
precision: Optional[str] = None
|
||||||
|
reload_dataloaders_every_n_epochs: Optional[int] = None
|
||||||
|
val_check_interval: Optional[Union[int, float]] = None
|
||||||
|
|
||||||
|
|
||||||
|
def load_trainer_config(path: PathLike, field: Optional[str] = None):
|
||||||
|
return load_config(path, schema=TrainerConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
detector: DetectionModel,
|
module: LightningModule,
|
||||||
train_examples: List[data.PathLike],
|
train_dataset: LabeledDataset,
|
||||||
targets: Optional[TargetProtocol] = None,
|
trainer_config: Optional[TrainerConfig] = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
dev_run: bool = False,
|
||||||
postprocessor: Optional[PostprocessorProtocol] = None,
|
overfit_batches: bool = False,
|
||||||
val_examples: Optional[List[data.PathLike]] = None,
|
profiler: Optional[str] = None,
|
||||||
config: Optional[TrainingConfig] = None,
|
):
|
||||||
callbacks: Optional[List[Callback]] = None,
|
trainer_config = trainer_config or TrainerConfig()
|
||||||
model_path: Optional[data.PathLike] = None,
|
|
||||||
train_workers: int = 0,
|
|
||||||
val_workers: int = 0,
|
|
||||||
**trainer_kwargs,
|
|
||||||
) -> None:
|
|
||||||
config = config or TrainingConfig()
|
|
||||||
if model_path is None:
|
|
||||||
if preprocessor is None:
|
|
||||||
preprocessor = build_preprocessor()
|
|
||||||
|
|
||||||
if targets is None:
|
|
||||||
targets = build_targets()
|
|
||||||
|
|
||||||
if postprocessor is None:
|
|
||||||
postprocessor = build_postprocessor(
|
|
||||||
targets,
|
|
||||||
min_freq=preprocessor.min_freq,
|
|
||||||
max_freq=preprocessor.max_freq,
|
|
||||||
)
|
|
||||||
|
|
||||||
loss = build_loss(config.loss)
|
|
||||||
|
|
||||||
module = TrainingModule(
|
|
||||||
detector=detector,
|
|
||||||
loss=loss,
|
|
||||||
targets=targets,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
postprocessor=postprocessor,
|
|
||||||
learning_rate=config.optimizer.learning_rate,
|
|
||||||
t_max=config.optimizer.t_max,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
|
||||||
|
|
||||||
train_dataset = build_train_dataset(
|
|
||||||
train_examples,
|
|
||||||
preprocessor=module.preprocessor,
|
|
||||||
config=config,
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
**config.trainer.model_dump(exclude_none=True),
|
**trainer_config.model_dump(
|
||||||
callbacks=callbacks,
|
exclude_unset=True,
|
||||||
**trainer_kwargs,
|
exclude_none=True,
|
||||||
|
),
|
||||||
|
fast_dev_run=dev_run,
|
||||||
|
overfit_batches=overfit_batches,
|
||||||
|
profiler=profiler,
|
||||||
)
|
)
|
||||||
|
train_loader = DataLoader(
|
||||||
train_dataloader = DataLoader(
|
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=config.batch_size,
|
batch_size=module.config.train.batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=train_workers,
|
num_workers=7,
|
||||||
)
|
)
|
||||||
|
trainer.fit(module, train_dataloaders=train_loader)
|
||||||
val_dataloader = None
|
|
||||||
if val_examples:
|
|
||||||
val_dataset = build_val_dataset(
|
|
||||||
val_examples,
|
|
||||||
config=config,
|
|
||||||
)
|
|
||||||
val_dataloader = DataLoader(
|
|
||||||
val_dataset,
|
|
||||||
batch_size=config.batch_size,
|
|
||||||
shuffle=False,
|
|
||||||
num_workers=val_workers,
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer.fit(
|
|
||||||
module,
|
|
||||||
train_dataloaders=train_dataloader,
|
|
||||||
val_dataloaders=val_dataloader,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_train_dataset(
|
|
||||||
examples: List[data.PathLike],
|
|
||||||
preprocessor: PreprocessorProtocol,
|
|
||||||
config: Optional[TrainingConfig] = None,
|
|
||||||
) -> LabeledDataset:
|
|
||||||
config = config or TrainingConfig()
|
|
||||||
|
|
||||||
clipper = build_clipper(config.cliping, random=True)
|
|
||||||
|
|
||||||
random_example_source = RandomExampleSource(
|
|
||||||
examples,
|
|
||||||
clipper=clipper,
|
|
||||||
)
|
|
||||||
|
|
||||||
augmentations = build_augmentations(
|
|
||||||
preprocessor,
|
|
||||||
config=config.augmentations,
|
|
||||||
example_source=random_example_source,
|
|
||||||
)
|
|
||||||
|
|
||||||
return LabeledDataset(
|
|
||||||
examples,
|
|
||||||
clipper=clipper,
|
|
||||||
augmentation=augmentations,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_val_dataset(
|
|
||||||
examples: List[data.PathLike],
|
|
||||||
config: Optional[TrainingConfig] = None,
|
|
||||||
train: bool = True,
|
|
||||||
) -> LabeledDataset:
|
|
||||||
config = config or TrainingConfig()
|
|
||||||
clipper = build_clipper(config.cliping, random=train)
|
|
||||||
return LabeledDataset(examples, clipper=clipper)
|
|
||||||
|
@ -57,8 +57,8 @@ class TrainExample(NamedTuple):
|
|||||||
class_heatmap: torch.Tensor
|
class_heatmap: torch.Tensor
|
||||||
size_heatmap: torch.Tensor
|
size_heatmap: torch.Tensor
|
||||||
idx: torch.Tensor
|
idx: torch.Tensor
|
||||||
start_time: torch.Tensor
|
start_time: float
|
||||||
end_time: torch.Tensor
|
end_time: float
|
||||||
|
|
||||||
|
|
||||||
class Losses(NamedTuple):
|
class Losses(NamedTuple):
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import xarray as xr
|
|
||||||
|
|
||||||
|
|
||||||
def extend_width(
|
def extend_width(
|
||||||
@ -60,10 +59,3 @@ def adjust_width(
|
|||||||
for index in range(dims)
|
for index in range(dims)
|
||||||
]
|
]
|
||||||
return array[tuple(slices)]
|
return array[tuple(slices)]
|
||||||
|
|
||||||
|
|
||||||
def iterate_over_array(array: xr.DataArray):
|
|
||||||
dim_name = array.dims[0]
|
|
||||||
coords = array.coords[dim_name]
|
|
||||||
for value, coord in zip(array.values, coords.values):
|
|
||||||
yield coord, float(value)
|
|
||||||
|
@ -7,8 +7,6 @@ import pytest
|
|||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
from soundevent import data, terms
|
from soundevent import data, terms
|
||||||
|
|
||||||
from batdetect2.data import DatasetConfig, load_dataset
|
|
||||||
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.targets import (
|
from batdetect2.targets import (
|
||||||
@ -385,27 +383,3 @@ def sample_labeller(
|
|||||||
sample_targets: TargetProtocol,
|
sample_targets: TargetProtocol,
|
||||||
) -> ClipLabeller:
|
) -> ClipLabeller:
|
||||||
return build_clip_labeler(sample_targets)
|
return build_clip_labeler(sample_targets)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def example_dataset(example_data_dir: Path) -> DatasetConfig:
|
|
||||||
return DatasetConfig(
|
|
||||||
name="test dataset",
|
|
||||||
description="test dataset",
|
|
||||||
sources=[
|
|
||||||
BatDetect2FilesAnnotations(
|
|
||||||
name="example annotations",
|
|
||||||
audio_dir=example_data_dir / "audio",
|
|
||||||
annotations_dir=example_data_dir / "anns",
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def example_annotations(
|
|
||||||
example_dataset: DatasetConfig,
|
|
||||||
) -> List[data.ClipAnnotation]:
|
|
||||||
annotations = load_dataset(example_dataset)
|
|
||||||
assert len(annotations) == 3
|
|
||||||
return annotations
|
|
||||||
|
@ -254,11 +254,11 @@ class TestLoadBatDetect2Files:
|
|||||||
assert clip_ann.clip.recording.duration == 5.0
|
assert clip_ann.clip.recording.duration == 5.0
|
||||||
assert len(clip_ann.sound_events) == 1
|
assert len(clip_ann.sound_events) == 1
|
||||||
assert clip_ann.notes[0].message == "Standard notes."
|
assert clip_ann.notes[0].message == "Standard notes."
|
||||||
clip_tag = data.find_tag(clip_ann.tags, "Class")
|
clip_tag = data.find_tag(clip_ann.tags, "class")
|
||||||
assert clip_tag is not None
|
assert clip_tag is not None
|
||||||
assert clip_tag.value == "Myotis"
|
assert clip_tag.value == "Myotis"
|
||||||
|
|
||||||
recording_tag = data.find_tag(clip_ann.clip.recording.tags, "Class")
|
recording_tag = data.find_tag(clip_ann.clip.recording.tags, "class")
|
||||||
assert recording_tag is not None
|
assert recording_tag is not None
|
||||||
assert recording_tag.value == "Myotis"
|
assert recording_tag.value == "Myotis"
|
||||||
|
|
||||||
@ -271,15 +271,15 @@ class TestLoadBatDetect2Files:
|
|||||||
40000,
|
40000,
|
||||||
]
|
]
|
||||||
|
|
||||||
se_class_tag = data.find_tag(se_ann.tags, "Class")
|
se_class_tag = data.find_tag(se_ann.tags, "class")
|
||||||
assert se_class_tag is not None
|
assert se_class_tag is not None
|
||||||
assert se_class_tag.value == "Myotis"
|
assert se_class_tag.value == "Myotis"
|
||||||
|
|
||||||
se_event_tag = data.find_tag(se_ann.tags, "Call Type")
|
se_event_tag = data.find_tag(se_ann.tags, "event")
|
||||||
assert se_event_tag is not None
|
assert se_event_tag is not None
|
||||||
assert se_event_tag.value == "Echolocation"
|
assert se_event_tag.value == "Echolocation"
|
||||||
|
|
||||||
se_individual_tag = data.find_tag(se_ann.tags, "Individual")
|
se_individual_tag = data.find_tag(se_ann.tags, "individual")
|
||||||
assert se_individual_tag is not None
|
assert se_individual_tag is not None
|
||||||
assert se_individual_tag.value == "0"
|
assert se_individual_tag.value == "0"
|
||||||
|
|
||||||
@ -439,7 +439,7 @@ class TestLoadBatDetect2Merged:
|
|||||||
assert clip_ann.clip.recording.duration == 5.0
|
assert clip_ann.clip.recording.duration == 5.0
|
||||||
assert len(clip_ann.sound_events) == 1
|
assert len(clip_ann.sound_events) == 1
|
||||||
|
|
||||||
clip_class_tag = data.find_tag(clip_ann.tags, "Class")
|
clip_class_tag = data.find_tag(clip_ann.tags, "class")
|
||||||
assert clip_class_tag is not None
|
assert clip_class_tag is not None
|
||||||
assert clip_class_tag.value == "Myotis"
|
assert clip_class_tag.value == "Myotis"
|
||||||
|
|
||||||
|
@ -98,7 +98,7 @@ def sample_detection_dataset() -> xr.Dataset:
|
|||||||
expected_freqs = np.array([300, 200])
|
expected_freqs = np.array([300, 200])
|
||||||
detection_coords = {
|
detection_coords = {
|
||||||
"time": ("detection", expected_times),
|
"time": ("detection", expected_times),
|
||||||
"frequency": ("detection", expected_freqs),
|
"freq": ("detection", expected_freqs),
|
||||||
}
|
}
|
||||||
|
|
||||||
scores_data = np.array([0.9, 0.8], dtype=np.float64)
|
scores_data = np.array([0.9, 0.8], dtype=np.float64)
|
||||||
@ -106,7 +106,7 @@ def sample_detection_dataset() -> xr.Dataset:
|
|||||||
scores_data,
|
scores_data,
|
||||||
coords=detection_coords,
|
coords=detection_coords,
|
||||||
dims=["detection"],
|
dims=["detection"],
|
||||||
name="score",
|
name="scores",
|
||||||
)
|
)
|
||||||
|
|
||||||
dimensions_data = np.array([[7.0, 16.0], [3.0, 12.0]], dtype=np.float32)
|
dimensions_data = np.array([[7.0, 16.0], [3.0, 12.0]], dtype=np.float32)
|
||||||
@ -183,7 +183,7 @@ def empty_detection_dataset() -> xr.Dataset:
|
|||||||
)
|
)
|
||||||
return xr.Dataset(
|
return xr.Dataset(
|
||||||
{
|
{
|
||||||
"score": scores,
|
"scores": scores,
|
||||||
"dimensions": dimensions,
|
"dimensions": dimensions,
|
||||||
"classes": classes,
|
"classes": classes,
|
||||||
"features": features,
|
"features": features,
|
||||||
@ -206,14 +206,10 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
|||||||
)
|
)
|
||||||
pred1 = RawPrediction(
|
pred1 = RawPrediction(
|
||||||
detection_score=0.9,
|
detection_score=0.9,
|
||||||
geometry=data.BoundingBox(
|
start_time=20 - 7 / 2,
|
||||||
coordinates=[
|
end_time=20 + 7 / 2,
|
||||||
20 - 7 / 2,
|
low_freq=300 - 16 / 2,
|
||||||
300 - 16 / 2,
|
high_freq=300 + 16 / 2,
|
||||||
20 + 7 / 2,
|
|
||||||
300 + 16 / 2,
|
|
||||||
]
|
|
||||||
),
|
|
||||||
class_scores=pred1_classes,
|
class_scores=pred1_classes,
|
||||||
features=pred1_features,
|
features=pred1_features,
|
||||||
)
|
)
|
||||||
@ -228,14 +224,10 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
|||||||
)
|
)
|
||||||
pred2 = RawPrediction(
|
pred2 = RawPrediction(
|
||||||
detection_score=0.8,
|
detection_score=0.8,
|
||||||
geometry=data.BoundingBox(
|
start_time=10 - 3 / 2,
|
||||||
coordinates=[
|
end_time=10 + 3 / 2,
|
||||||
10 - 3 / 2,
|
low_freq=200 - 12 / 2,
|
||||||
200 - 12 / 2,
|
high_freq=200 + 12 / 2,
|
||||||
10 + 3 / 2,
|
|
||||||
200 + 12 / 2,
|
|
||||||
]
|
|
||||||
),
|
|
||||||
class_scores=pred2_classes,
|
class_scores=pred2_classes,
|
||||||
features=pred2_features,
|
features=pred2_features,
|
||||||
)
|
)
|
||||||
@ -250,14 +242,10 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
|||||||
)
|
)
|
||||||
pred3 = RawPrediction(
|
pred3 = RawPrediction(
|
||||||
detection_score=0.15,
|
detection_score=0.15,
|
||||||
geometry=data.BoundingBox(
|
start_time=5.0,
|
||||||
coordinates=[
|
end_time=6.0,
|
||||||
5.0,
|
low_freq=50.0,
|
||||||
50.0,
|
high_freq=60.0,
|
||||||
6.0,
|
|
||||||
60.0,
|
|
||||||
]
|
|
||||||
),
|
|
||||||
class_scores=pred3_classes,
|
class_scores=pred3_classes,
|
||||||
features=pred3_features,
|
features=pred3_features,
|
||||||
)
|
)
|
||||||
@ -279,12 +267,10 @@ def test_convert_xr_dataset_basic(
|
|||||||
assert isinstance(pred1, RawPrediction)
|
assert isinstance(pred1, RawPrediction)
|
||||||
assert pred1.detection_score == 0.9
|
assert pred1.detection_score == 0.9
|
||||||
|
|
||||||
assert pred1.geometry.coordinates == [
|
assert pred1.start_time == 20 - 7 / 2
|
||||||
20 - 7 / 2,
|
assert pred1.end_time == 20 + 7 / 2
|
||||||
300 - 16 / 2,
|
assert pred1.low_freq == 300 - 16 / 2
|
||||||
20 + 7 / 2,
|
assert pred1.high_freq == 300 + 16 / 2
|
||||||
300 + 16 / 2,
|
|
||||||
]
|
|
||||||
xr.testing.assert_allclose(
|
xr.testing.assert_allclose(
|
||||||
pred1.class_scores,
|
pred1.class_scores,
|
||||||
sample_detection_dataset["classes"].sel(detection=0),
|
sample_detection_dataset["classes"].sel(detection=0),
|
||||||
@ -297,12 +283,10 @@ def test_convert_xr_dataset_basic(
|
|||||||
assert isinstance(pred2, RawPrediction)
|
assert isinstance(pred2, RawPrediction)
|
||||||
assert pred2.detection_score == 0.8
|
assert pred2.detection_score == 0.8
|
||||||
|
|
||||||
assert pred2.geometry.coordinates == [
|
assert pred2.start_time == 10 - 3 / 2
|
||||||
10 - 3 / 2,
|
assert pred2.end_time == 10 + 3 / 2
|
||||||
200 - 12 / 2,
|
assert pred2.low_freq == 200 - 12 / 2
|
||||||
10 + 3 / 2,
|
assert pred2.high_freq == 200 + 12 / 2
|
||||||
200 + 12 / 2,
|
|
||||||
]
|
|
||||||
xr.testing.assert_allclose(
|
xr.testing.assert_allclose(
|
||||||
pred2.class_scores,
|
pred2.class_scores,
|
||||||
sample_detection_dataset["classes"].sel(detection=1),
|
sample_detection_dataset["classes"].sel(detection=1),
|
||||||
@ -347,7 +331,15 @@ def test_convert_raw_to_sound_event_basic(
|
|||||||
assert isinstance(se, data.SoundEvent)
|
assert isinstance(se, data.SoundEvent)
|
||||||
assert se.recording == sample_recording
|
assert se.recording == sample_recording
|
||||||
assert isinstance(se.geometry, data.BoundingBox)
|
assert isinstance(se.geometry, data.BoundingBox)
|
||||||
assert se.geometry == raw_pred.geometry
|
np.testing.assert_allclose(
|
||||||
|
se.geometry.coordinates,
|
||||||
|
[
|
||||||
|
raw_pred.start_time,
|
||||||
|
raw_pred.low_freq,
|
||||||
|
raw_pred.end_time,
|
||||||
|
raw_pred.high_freq,
|
||||||
|
],
|
||||||
|
)
|
||||||
assert len(se.features) == len(raw_pred.features)
|
assert len(se.features) == len(raw_pred.features)
|
||||||
|
|
||||||
feat_dict = {f.term.name: f.value for f in se.features}
|
feat_dict = {f.term.name: f.value for f in se.features}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import math
|
import math
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Union
|
from typing import Callable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@ -307,7 +307,7 @@ def test_remove_spectral_mean_constant(constant_wave_xr: xr.DataArray):
|
|||||||
def test_resize_spectrogram(
|
def test_resize_spectrogram(
|
||||||
sample_spec: xr.DataArray,
|
sample_spec: xr.DataArray,
|
||||||
height: int,
|
height: int,
|
||||||
resize_factor: Union[float, None],
|
resize_factor: float | None,
|
||||||
expected_freq_size: int,
|
expected_freq_size: int,
|
||||||
expected_time_factor: float,
|
expected_time_factor: float,
|
||||||
):
|
):
|
||||||
|
@ -4,7 +4,6 @@ from typing import Callable, List, Set
|
|||||||
import pytest
|
import pytest
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.targets import build_targets
|
|
||||||
from batdetect2.targets.filtering import (
|
from batdetect2.targets.filtering import (
|
||||||
FilterConfig,
|
FilterConfig,
|
||||||
FilterRule,
|
FilterRule,
|
||||||
@ -177,34 +176,3 @@ rules:
|
|||||||
filter_result = load_filter_from_config(test_config_path)
|
filter_result = load_filter_from_config(test_config_path)
|
||||||
annotation = create_annotation(["tag1", "tag3"])
|
annotation = create_annotation(["tag1", "tag3"])
|
||||||
assert filter_result(annotation) is False
|
assert filter_result(annotation) is False
|
||||||
|
|
||||||
|
|
||||||
def test_default_filtering_over_example_dataset(
|
|
||||||
example_annotations: List[data.ClipAnnotation],
|
|
||||||
):
|
|
||||||
targets = build_targets()
|
|
||||||
|
|
||||||
clip1 = example_annotations[0]
|
|
||||||
clip2 = example_annotations[1]
|
|
||||||
clip3 = example_annotations[2]
|
|
||||||
|
|
||||||
assert (
|
|
||||||
sum(
|
|
||||||
[targets.filter(sound_event) for sound_event in clip1.sound_events]
|
|
||||||
)
|
|
||||||
== 9
|
|
||||||
)
|
|
||||||
|
|
||||||
assert (
|
|
||||||
sum(
|
|
||||||
[targets.filter(sound_event) for sound_event in clip2.sound_events]
|
|
||||||
)
|
|
||||||
== 15
|
|
||||||
)
|
|
||||||
|
|
||||||
assert (
|
|
||||||
sum(
|
|
||||||
[targets.filter(sound_event) for sound_event in clip3.sound_events]
|
|
||||||
)
|
|
||||||
== 20
|
|
||||||
)
|
|
||||||
|
@ -9,8 +9,8 @@ from batdetect2.preprocess.types import PreprocessorProtocol
|
|||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
add_echo,
|
add_echo,
|
||||||
mix_examples,
|
mix_examples,
|
||||||
|
select_subclip,
|
||||||
)
|
)
|
||||||
from batdetect2.train.clips import select_subclip
|
|
||||||
from batdetect2.train.preprocess import generate_train_example
|
from batdetect2.train.preprocess import generate_train_example
|
||||||
from batdetect2.train.types import ClipLabeller
|
from batdetect2.train.types import ClipLabeller
|
||||||
|
|
||||||
@ -121,7 +121,7 @@ def test_selected_random_subclip_has_the_correct_width(
|
|||||||
preprocessor=sample_preprocessor,
|
preprocessor=sample_preprocessor,
|
||||||
labeller=sample_labeller,
|
labeller=sample_labeller,
|
||||||
)
|
)
|
||||||
subclip = select_subclip(original, start=0, span=100)
|
subclip = select_subclip(original, width=100)
|
||||||
|
|
||||||
assert subclip["spectrogram"].shape[1] == 100
|
assert subclip["spectrogram"].shape[1] == 100
|
||||||
|
|
||||||
@ -142,7 +142,7 @@ def test_add_echo_after_subclip(
|
|||||||
|
|
||||||
assert original.sizes["time"] > 512
|
assert original.sizes["time"] > 512
|
||||||
|
|
||||||
subclip = select_subclip(original, start=0, span=512)
|
subclip = select_subclip(original, width=512)
|
||||||
with_echo = add_echo(subclip, preprocessor=sample_preprocessor)
|
with_echo = add_echo(subclip, preprocessor=sample_preprocessor)
|
||||||
|
|
||||||
assert with_echo.sizes["time"] == 512
|
assert with_echo.sizes["time"] == 512
|
||||||
|
Loading…
Reference in New Issue
Block a user