Compare commits

7 Commits

Author      SHA1        Message                                       Date
mbsantiago  fdab0860fd  Add num workers to cli                        2025-04-30 23:40:26 +01:00
mbsantiago  3b5623ddca  create train cli command                      2025-04-30 23:27:38 +01:00
mbsantiago  2913fa59a4  Fix tests                                     2025-04-30 22:51:49 +01:00
mbsantiago  9c8b8fb200  Create metrics                                2025-04-30 22:51:33 +01:00
mbsantiago  bc86c94f8e  Adding evaluation callback                    2025-04-25 17:12:57 +01:00
mbsantiago  9106b9f408  Removing stale functions from legacy module   2025-04-25 17:12:39 +01:00
mbsantiago  899f74efd5  Wrote the main train function                 2025-04-24 10:00:18 +01:00
32 changed files with 1199 additions and 576 deletions

View File

@@ -0,0 +1,196 @@
from pathlib import Path
from typing import Optional

import click
from loguru import logger

from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config
from batdetect2.train import load_label_config, preprocess_annotations
from batdetect2.train.labels import build_clip_labeler

__all__ = ["preprocess"]


@cli.command()
@click.argument(
    "dataset_config",
    type=click.Path(exists=True),
)
@click.argument(
    "output",
    type=click.Path(),
)
@click.option(
    "--dataset-field",
    type=str,
    help=(
        "Specifies the key to access the dataset information within the "
        "dataset configuration file, if the information is nested inside a "
        "dictionary. If the dataset information is at the top level of the "
        "config file, you don't need to specify this."
    ),
)
@click.option(
    "--base-dir",
    type=click.Path(exists=True),
    help=(
        "The main directory where your audio recordings and annotation "
        "files are stored. This helps the program find your data, "
        "especially if the paths in your dataset configuration file "
        "are relative."
    ),
)
@click.option(
    "--preprocess-config",
    type=click.Path(exists=True),
    help=(
        "Path to the preprocessing configuration file. This file tells "
        "the program how to prepare your audio data before training, such "
        "as resampling or applying filters."
    ),
)
@click.option(
    "--preprocess-config-field",
    type=str,
    help=(
        "If the preprocessing settings are inside a nested dictionary "
        "within the preprocessing configuration file, specify the key "
        "here to access them. If the preprocessing settings are at the "
        "top level, you don't need to specify this."
    ),
)
@click.option(
    "--label-config",
    type=click.Path(exists=True),
    help=(
        "Path to the label generation configuration file. This file "
        "contains settings for how to create labels from your "
        "annotations, which the model uses to learn."
    ),
)
@click.option(
    "--label-config-field",
    type=str,
    help=(
        "If the label generation settings are inside a nested dictionary "
        "within the label configuration file, specify the key here. If "
        "the settings are at the top level, leave this blank."
    ),
)
@click.option(
    "--target-config",
    type=click.Path(exists=True),
    help=(
        "Path to the training target configuration file. This file "
        "specifies what sounds the model should learn to predict."
    ),
)
@click.option(
    "--target-config-field",
    type=str,
    help=(
        "If the target settings are inside a nested dictionary "
        "within the target configuration file, specify the key here. "
        "If the settings are at the top level, you don't need to specify this."
    ),
)
@click.option(
    "--force",
    is_flag=True,
    help=(
        "If a preprocessed file already exists, this option tells the "
        "program to overwrite it with the new preprocessed data. Use "
        "this if you want to re-do the preprocessing even if the files "
        "already exist."
    ),
)
@click.option(
    "--num-workers",
    type=int,
    help=(
        "The maximum number of computer cores to use when processing "
        "your audio data. Using more cores can speed up the preprocessing, "
        "but don't use more than your computer has available. By default, "
        "the program will use all available cores."
    ),
)
def preprocess(
    dataset_config: Path,
    output: Path,
    target_config: Optional[Path] = None,
    base_dir: Optional[Path] = None,
    preprocess_config: Optional[Path] = None,
    label_config: Optional[Path] = None,
    force: bool = False,
    num_workers: Optional[int] = None,
    target_config_field: Optional[str] = None,
    preprocess_config_field: Optional[str] = None,
    label_config_field: Optional[str] = None,
    dataset_field: Optional[str] = None,
):
    logger.info("Starting preprocessing.")

    output = Path(output)
    logger.info("Will save outputs to {output}", output=output)

    base_dir = base_dir or Path.cwd()
    logger.debug("Current working directory: {base_dir}", base_dir=base_dir)

    preprocess = (
        load_preprocessing_config(
            preprocess_config,
            field=preprocess_config_field,
        )
        if preprocess_config
        else None
    )

    target = (
        load_target_config(
            target_config,
            field=target_config_field,
        )
        if target_config
        else None
    )

    label = (
        load_label_config(
            label_config,
            field=label_config_field,
        )
        if label_config
        else None
    )

    dataset = load_dataset_from_config(
        dataset_config,
        field=dataset_field,
        base_dir=base_dir,
    )

    logger.info(
        "Loaded {num_examples} annotated clips from the configured dataset",
        num_examples=len(dataset),
    )

    targets = build_targets(config=target)
    preprocessor = build_preprocessor(config=preprocess)
    labeller = build_clip_labeler(targets, config=label)

    if not output.exists():
        logger.debug("Creating directory {directory}", directory=output)
        output.mkdir(parents=True)

    logger.info("Will start preprocessing")

    preprocess_annotations(
        dataset,
        output_dir=output,
        preprocessor=preprocessor,
        labeller=labeller,
        replace=force,
        max_workers=num_workers,
    )
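
For orientation, a minimal invocation sketch using click's test runner so it stays in Python; the file names below are placeholders, not files from this changeset:

from click.testing import CliRunner

from batdetect2.cli.base import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "preprocess",
        "dataset.yaml",    # dataset_config argument (placeholder path)
        "preprocessed/",   # output directory (placeholder path)
        "--num-workers", "4",
    ],
)
print(result.exit_code, result.output)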

View File

@@ -5,47 +5,48 @@ import click
 from loguru import logger
 from batdetect2.cli.base import cli
-from batdetect2.data import load_dataset_from_config
+from batdetect2.evaluate.metrics import (
+    ClassificationAccuracy,
+    ClassificationMeanAveragePrecision,
+    DetectionAveragePrecision,
+)
+from batdetect2.models import build_model
+from batdetect2.models.backbones import load_backbone_config
+from batdetect2.postprocess import build_postprocessor, load_postprocess_config
 from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
 from batdetect2.targets import build_targets, load_target_config
-from batdetect2.train import load_label_config, preprocess_annotations
+from batdetect2.train import train
-from batdetect2.train.labels import build_clip_labeler
+from batdetect2.train.callbacks import ValidationMetrics
+from batdetect2.train.config import TrainingConfig, load_train_config
+from batdetect2.train.dataset import list_preprocessed_files

-__all__ = ["train"]
+__all__ = [
+    "train_command",
+]

+DEFAULT_CONFIG_FILE = Path("config.yaml")

-@cli.group()
+@cli.command(name="train")
-def train(): ...
+@click.option(
+    "--train-examples",
+    type=click.Path(exists=True),
-@train.command()
+    required=True,
-@click.argument(
+)
-    "dataset_config",
+@click.option("--val-examples", type=click.Path(exists=True))
+@click.option(
+    "--model-path",
     type=click.Path(exists=True),
 )
-@click.argument(
+@click.option(
-    "output",
+    "--train-config",
-    type=click.Path(),
+    type=click.Path(exists=True),
+    default=DEFAULT_CONFIG_FILE,
 )
 @click.option(
-    "--dataset-field",
+    "--train-config-field",
     type=str,
-    help=(
+    default="train",
-        "Specifies the key to access the dataset information within the "
-        "dataset configuration file, if the information is nested inside a "
-        "dictionary. If the dataset information is at the top level of the "
-        "config file, you don't need to specify this."
-    ),
-)
-@click.option(
-    "--base-dir",
-    type=click.Path(exists=True),
-    help=(
-        "The main directory where your audio recordings and annotation "
-        "files are stored. This helps the program find your data, "
-        "especially if the paths in your dataset configuration file "
-        "are relative."
-    ),
 )
 @click.option(
     "--preprocess-config",
@@ -55,6 +56,7 @@ def train(): ...
         "the program how to prepare your audio data before training, such "
         "as resampling or applying filters."
     ),
+    default=DEFAULT_CONFIG_FILE,
 )
 @click.option(
     "--preprocess-config-field",
@@ -65,24 +67,7 @@ def train(): ...
         "here to access them. If the preprocessing settings are at the "
         "top level, you don't need to specify this."
     ),
-)
+    default="preprocess",
-@click.option(
-    "--label-config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the label generation configuration file. This file "
-        "contains settings for how to create labels from your "
-        "annotations, which the model uses to learn."
-    ),
-)
-@click.option(
-    "--label-config-field",
-    type=str,
-    help=(
-        "If the label generation settings are inside a nested dictionary "
-        "within the label configuration file, specify the key here. If "
-        "the settings are at the top level, leave this blank."
-    ),
 )
 @click.option(
     "--target-config",
@@ -91,6 +76,7 @@ def train(): ...
         "Path to the training target configuration file. This file "
         "specifies what sounds the model should learn to predict."
     ),
+    default=DEFAULT_CONFIG_FILE,
 )
 @click.option(
     "--target-config-field",
@@ -100,101 +86,156 @@ def train(): ...
         "within the target configuration file, specify the key here. "
         "If the settings are at the top level, you don't need to specify this."
     ),
+    default="targets",
 )
 @click.option(
-    "--force",
+    "--postprocess-config",
-    is_flag=True,
+    type=click.Path(exists=True),
-    help=(
+    default=DEFAULT_CONFIG_FILE,
-        "If a preprocessed file already exists, this option tells the "
-        "program to overwrite it with the new preprocessed data. Use "
-        "this if you want to re-do the preprocessing even if the files "
-        "already exist."
-    ),
 )
 @click.option(
-    "--num-workers",
+    "--postprocess-config-field",
+    type=str,
+    default="postprocess",
+)
+@click.option(
+    "--model-config",
+    type=click.Path(exists=True),
+    default=DEFAULT_CONFIG_FILE,
+)
+@click.option(
+    "--model-config-field",
+    type=str,
+    default="model",
+)
+@click.option(
+    "--train-workers",
     type=int,
-    help=(
+    default=0,
-        "The maximum number of computer cores to use when processing "
-        "your audio data. Using more cores can speed up the preprocessing, "
-        "but don't use more than your computer has available. By default, "
-        "the program will use all available cores."
-    ),
 )
-def preprocess(
+@click.option(
-    dataset_config: Path,
+    "--val-workers",
-    output: Path,
+    type=int,
-    target_config: Optional[Path] = None,
+    default=0,
-    base_dir: Optional[Path] = None,
+)
-    preprocess_config: Optional[Path] = None,
+def train_command(
-    label_config: Optional[Path] = None,
+    train_examples: Path,
-    force: bool = False,
+    val_examples: Optional[Path] = None,
-    num_workers: Optional[int] = None,
+    model_path: Optional[Path] = None,
-    target_config_field: Optional[str] = None,
+    train_config: Path = DEFAULT_CONFIG_FILE,
-    preprocess_config_field: Optional[str] = None,
+    train_config_field: str = "train",
-    label_config_field: Optional[str] = None,
+    preprocess_config: Path = DEFAULT_CONFIG_FILE,
-    dataset_field: Optional[str] = None,
+    preprocess_config_field: str = "preprocess",
+    target_config: Path = DEFAULT_CONFIG_FILE,
+    target_config_field: str = "targets",
+    postprocess_config: Path = DEFAULT_CONFIG_FILE,
+    postprocess_config_field: str = "postprocess",
+    model_config: Path = DEFAULT_CONFIG_FILE,
+    model_config_field: str = "model",
+    train_workers: int = 0,
+    val_workers: int = 0,
 ):
-    logger.info("Starting preprocessing.")
+    logger.info("Starting training!")
-    output = Path(output)
+    try:
-    logger.info("Will save outputs to {output}", output=output)
+        target_config_loaded = load_target_config(
+            path=target_config,
-    base_dir = base_dir or Path.cwd()
-    logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
-    preprocess = (
-        load_preprocessing_config(
-            preprocess_config,
-            field=preprocess_config_field,
-        )
-        if preprocess_config
-        else None
-    )
-    target = (
-        load_target_config(
-            target_config,
             field=target_config_field,
         )
-        if target_config
+        targets = build_targets(config=target_config_loaded)
-        else None
+        logger.debug(
-    )
+            "Loaded targets info from config file {path}", path=target_config
-    label = (
-        load_label_config(
-            label_config,
-            field=label_config_field,
         )
-        if label_config
+    except IOError:
-        else None
+        logger.debug(
+            "Could not load target info from config file, using default"
+        )
+        targets = build_targets()
+    try:
+        preprocess_config_loaded = load_preprocessing_config(
+            path=preprocess_config,
+            field=preprocess_config_field,
+        )
+        preprocessor = build_preprocessor(preprocess_config_loaded)
+        logger.debug(
+            "Loaded preprocessor from config file {path}", path=target_config
+        )
+    except IOError:
+        logger.debug(
+            "Could not load preprocessor from config file, using default"
+        )
+        preprocessor = build_preprocessor()
+    try:
+        model_config_loaded = load_backbone_config(
+            path=model_config, field=model_config_field
+        )
+        model = build_model(
+            num_classes=len(targets.class_names),
+            config=model_config_loaded,
+        )
+    except IOError:
+        model = build_model(num_classes=len(targets.class_names))
+    try:
+        postprocess_config_loaded = load_postprocess_config(
+            path=postprocess_config,
+            field=postprocess_config_field,
+        )
+        postprocessor = build_postprocessor(
+            targets=targets,
+            config=postprocess_config_loaded,
+        )
+        logger.debug(
+            "Loaded postprocessor from file {path}",
+            path=train_config,
+        )
+    except IOError:
+        logger.debug(
+            "Could not load postprocessor config from file. Using default"
+        )
+        postprocessor = build_postprocessor(targets=targets)
+    try:
+        train_config_loaded = load_train_config(
+            path=train_config, field=train_config_field
+        )
+        logger.debug(
+            "Loaded training config from file {path}",
+            path=train_config,
+        )
+    except IOError:
+        train_config_loaded = TrainingConfig()
+        logger.debug("Could not load training config from file. Using default")
+    train_files = list_preprocessed_files(train_examples)
+    val_files = (
+        None if val_examples is None else list_preprocessed_files(val_examples)
     )
-    dataset = load_dataset_from_config(
+    return train(
-        dataset_config,
+        detector=model,
-        field=dataset_field,
+        train_examples=train_files,  # type: ignore
-        base_dir=base_dir,
+        val_examples=val_files,  # type: ignore
-    )
+        model_path=model_path,
-    logger.info(
-        "Loaded {num_examples} annotated clips from the configured dataset",
-        num_examples=len(dataset),
-    )
-    targets = build_targets(config=target)
-    preprocessor = build_preprocessor(config=preprocess)
-    labeller = build_clip_labeler(targets, config=label)
-    if not output.exists():
-        logger.debug("Creating directory {directory}", directory=output)
-        output.mkdir(parents=True)
-    logger.info("Will start preprocessing")
-    preprocess_annotations(
-        dataset,
-        output_dir=output,
         preprocessor=preprocessor,
-        labeller=labeller,
+        postprocessor=postprocessor,
-        replace=force,
+        targets=targets,
-        max_workers=num_workers,
+        config=train_config_loaded,
+        callbacks=[
+            ValidationMetrics(
+                metrics=[
+                    DetectionAveragePrecision(),
+                    ClassificationMeanAveragePrecision(
+                        class_names=targets.class_names,
+                    ),
+                    ClassificationAccuracy(class_names=targets.class_names),
+                ]
+            )
+        ],
+        train_workers=train_workers,
+        val_workers=val_workers,
     )

View File

@@ -88,7 +88,7 @@ def get_object_field(obj: dict, current_key: str) -> Any:
     KeyError: 'x'
     """
     if "." not in current_key:
-        return obj[current_key]
+        return obj.get(current_key, {})
     current_key, rest = current_key.split(".", 1)
     subobj = obj[current_key]
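
The behavior change in a two-line sketch (hypothetical config dict): a missing final key now falls back to an empty dict instead of raising, while nested lookups still raise when an intermediate key is absent.

config = {"train": {"batch_size": 8}}
# before: get_object_field(config, "missing")        -> KeyError: 'missing'
# after:  get_object_field(config, "missing")        -> {}
# still:  get_object_field(config, "missing.field")  -> KeyError: 'missing'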

View File

@@ -5,19 +5,14 @@ import uuid
 from pathlib import Path
 from typing import Callable, List, Optional, Union

-import numpy as np
 from pydantic import BaseModel, Field
 from soundevent import data
-from soundevent.geometry import compute_bounds
-from soundevent.types import ClassMapper

-from batdetect2 import types
+from batdetect2.targets import get_term_from_key

 PathLike = Union[Path, str, os.PathLike]

-__all__ = [
+__all__ = []
-    "convert_to_annotation_group",
-]

 SPECIES_TAG_KEY = "species"
 ECHOLOCATION_EVENT = "Echolocation"
@@ -33,104 +28,6 @@ ClassFn = Callable[[data.Recording], int]
 IndividualFn = Callable[[data.SoundEventAnnotation], int]

-def get_recording_class_name(recording: data.Recording) -> str:
-    """Get the class name for a recording."""
-    tag = data.find_tag(recording.tags, SPECIES_TAG_KEY)
-    if tag is None:
-        return UNKNOWN_CLASS
-    return tag.value

-def get_annotation_notes(annotation: data.ClipAnnotation) -> str:
-    """Get the notes for a ClipAnnotation."""
-    all_notes = [
-        *annotation.notes,
-        *annotation.clip.recording.notes,
-    ]
-    messages = [note.message for note in all_notes if note.message is not None]
-    return "\n".join(messages)

-def convert_to_annotation_group(
-    annotation: data.ClipAnnotation,
-    class_mapper: ClassMapper,
-    event_fn: EventFn = lambda _: ECHOLOCATION_EVENT,
-    class_fn: ClassFn = lambda _: 0,
-    individual_fn: IndividualFn = lambda _: 0,
-) -> types.AudioLoaderAnnotationGroup:
-    """Convert a ClipAnnotation to an AudioLoaderAnnotationGroup."""
-    recording = annotation.clip.recording
-    start_times = []
-    end_times = []
-    low_freqs = []
-    high_freqs = []
-    class_ids = []
-    x_inds = []
-    y_inds = []
-    individual_ids = []
-    annotations: List[types.Annotation] = []
-    class_id_file = class_fn(recording)
-    for sound_event in annotation.sound_events:
-        geometry = sound_event.sound_event.geometry
-        if geometry is None:
-            continue
-        start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
-        class_id = class_mapper.transform(sound_event) or -1
-        event = event_fn(sound_event) or ""
-        individual_id = individual_fn(sound_event) or -1
-        start_times.append(start_time)
-        end_times.append(end_time)
-        low_freqs.append(low_freq)
-        high_freqs.append(high_freq)
-        class_ids.append(class_id)
-        individual_ids.append(individual_id)
-        # NOTE: This will be computed later so we just put a placeholder
-        # here for now.
-        x_inds.append(0)
-        y_inds.append(0)
-        annotations.append(
-            {
-                "start_time": start_time,
-                "end_time": end_time,
-                "low_freq": low_freq,
-                "high_freq": high_freq,
-                "class_prob": 1.0,
-                "det_prob": 1.0,
-                "individual": "0",
-                "event": event,
-                "class_id": class_id,  # type: ignore
-            }
-        )
-    return {
-        "id": str(recording.path),
-        "duration": recording.duration,
-        "issues": False,
-        "file_path": str(recording.path),
-        "time_exp": recording.time_expansion,
-        "class_name": get_recording_class_name(recording),
-        "notes": get_annotation_notes(annotation),
-        "annotated": True,
-        "start_times": np.array(start_times),
-        "end_times": np.array(end_times),
-        "low_freqs": np.array(low_freqs),
-        "high_freqs": np.array(high_freqs),
-        "class_ids": np.array(class_ids),
-        "x_inds": np.array(x_inds),
-        "y_inds": np.array(y_inds),
-        "individual_ids": np.array(individual_ids),
-        "annotation": annotations,
-        "class_id_file": class_id_file,
-    }

 class Annotation(BaseModel):
     """Annotation class to hold batdetect annotations."""
@@ -195,15 +92,15 @@ def annotation_to_sound_event(
         sound_event=sound_event,
         tags=[
             data.Tag(
-                term=data.term_from_key(label_key),
+                term=get_term_from_key(label_key),
                 value=annotation.label,
             ),
             data.Tag(
-                term=data.term_from_key(event_key),
+                term=get_term_from_key(event_key),
                 value=annotation.event,
             ),
             data.Tag(
-                term=data.term_from_key(individual_key),
+                term=get_term_from_key(individual_key),
                 value=str(annotation.individual),
             ),
         ],
@@ -228,7 +125,7 @@ def file_annotation_to_clip(
         time_expansion=file_annotation.time_exp,
         tags=[
             data.Tag(
-                term=data.term_from_key(label_key),
+                term=get_term_from_key(label_key),
                 value=file_annotation.label,
             )
         ],
@@ -260,7 +157,7 @@ def file_annotation_to_clip_annotation(
         notes=notes,
         tags=[
             data.Tag(
-                term=data.term_from_key(label_key), value=file_annotation.label
+                term=get_term_from_key(label_key), value=file_annotation.label
             )
         ],
         sound_events=[

View File

@@ -1,9 +1,13 @@
 from batdetect2.evaluate.evaluate import (
     compute_error_auc,
+)
+from batdetect2.evaluate.match import (
     match_predictions_and_annotations,
+    match_sound_events_and_raw_predictions,
 )

 __all__ = [
     "compute_error_auc",
     "match_predictions_and_annotations",
+    "match_sound_events_and_raw_predictions",
 ]

View File

@@ -1,51 +1,6 @@
-from typing import List

 import numpy as np
 from sklearn.metrics import auc, roc_curve
-from soundevent import data
-from soundevent.evaluation import match_geometries

-def match_predictions_and_annotations(
-    clip_annotation: data.ClipAnnotation,
-    clip_prediction: data.ClipPrediction,
-) -> List[data.Match]:
-    annotated_sound_events = [
-        sound_event_annotation
-        for sound_event_annotation in clip_annotation.sound_events
-        if sound_event_annotation.sound_event.geometry is not None
-    ]
-    predicted_sound_events = [
-        sound_event_prediction
-        for sound_event_prediction in clip_prediction.sound_events
-        if sound_event_prediction.sound_event.geometry is not None
-    ]
-    annotated_geometries: List[data.Geometry] = [
-        sound_event.sound_event.geometry
-        for sound_event in annotated_sound_events
-        if sound_event.sound_event.geometry is not None
-    ]
-    predicted_geometries: List[data.Geometry] = [
-        sound_event.sound_event.geometry
-        for sound_event in predicted_sound_events
-        if sound_event.sound_event.geometry is not None
-    ]
-    matches = []
-    for id1, id2, affinity in match_geometries(
-        annotated_geometries,
-        predicted_geometries,
-    ):
-        target = annotated_sound_events[id1] if id1 is not None else None
-        source = predicted_sound_events[id2] if id2 is not None else None
-        matches.append(
-            data.Match(source=source, target=target, affinity=affinity)
-        )
-    return matches

 def compute_error_auc(op_str, gt, pred, prob):

View File

@@ -0,0 +1,111 @@
from typing import List

from soundevent import data
from soundevent.evaluation import match_geometries

from batdetect2.evaluate.types import Match
from batdetect2.postprocess.types import RawPrediction
from batdetect2.targets.types import TargetProtocol
from batdetect2.utils.arrays import iterate_over_array


def match_sound_events_and_raw_predictions(
    sound_events: List[data.SoundEventAnnotation],
    raw_predictions: List[RawPrediction],
    targets: TargetProtocol,
) -> List[Match]:
    target_sound_events = [
        targets.transform(sound_event_annotation)
        for sound_event_annotation in sound_events
        if targets.filter(sound_event_annotation)
        and sound_event_annotation.sound_event.geometry is not None
    ]

    target_geometries: List[data.Geometry] = [  # type: ignore
        sound_event_annotation.sound_event.geometry
        for sound_event_annotation in target_sound_events
    ]

    predicted_geometries = [
        raw_prediction.geometry for raw_prediction in raw_predictions
    ]

    matches = []
    for id1, id2, affinity in match_geometries(
        target_geometries,
        predicted_geometries,
    ):
        target = target_sound_events[id1] if id1 is not None else None
        prediction = raw_predictions[id2] if id2 is not None else None

        gt_uuid = target.uuid if target is not None else None
        gt_det = target is not None
        gt_class = targets.encode(target) if target is not None else None

        pred_score = float(prediction.detection_score) if prediction else 0

        class_scores = (
            {
                str(class_name): float(score)
                for class_name, score in iterate_over_array(
                    prediction.class_scores
                )
            }
            if prediction is not None
            else {}
        )

        matches.append(
            Match(
                gt_uuid=gt_uuid,
                gt_det=gt_det,
                gt_class=gt_class,
                pred_score=pred_score,
                affinity=affinity,
                class_scores=class_scores,
            )
        )

    return matches


def match_predictions_and_annotations(
    clip_annotation: data.ClipAnnotation,
    clip_prediction: data.ClipPrediction,
) -> List[data.Match]:
    annotated_sound_events = [
        sound_event_annotation
        for sound_event_annotation in clip_annotation.sound_events
        if sound_event_annotation.sound_event.geometry is not None
    ]

    predicted_sound_events = [
        sound_event_prediction
        for sound_event_prediction in clip_prediction.sound_events
        if sound_event_prediction.sound_event.geometry is not None
    ]

    annotated_geometries: List[data.Geometry] = [
        sound_event.sound_event.geometry
        for sound_event in annotated_sound_events
        if sound_event.sound_event.geometry is not None
    ]

    predicted_geometries: List[data.Geometry] = [
        sound_event.sound_event.geometry
        for sound_event in predicted_sound_events
        if sound_event.sound_event.geometry is not None
    ]

    matches = []
    for id1, id2, affinity in match_geometries(
        annotated_geometries,
        predicted_geometries,
    ):
        target = annotated_sound_events[id1] if id1 is not None else None
        source = predicted_sound_events[id2] if id2 is not None else None
        matches.append(
            data.Match(source=source, target=target, affinity=affinity)
        )

    return matches

View File

@@ -0,0 +1,97 @@
from typing import Dict, List

import pandas as pd
from sklearn import metrics
from sklearn.preprocessing import label_binarize

from batdetect2.evaluate.types import Match, MetricsProtocol

__all__ = ["DetectionAveragePrecision"]


class DetectionAveragePrecision(MetricsProtocol):
    def __call__(self, matches: List[Match]) -> Dict[str, float]:
        y_true, y_score = zip(
            *[(match.gt_det, match.pred_score) for match in matches]
        )
        score = float(metrics.average_precision_score(y_true, y_score))
        return {"detection_AP": score}


class ClassificationMeanAveragePrecision(MetricsProtocol):
    def __init__(self, class_names: List[str], per_class: bool = True):
        self.class_names = class_names
        self.per_class = per_class

    def __call__(self, matches: List[Match]) -> Dict[str, float]:
        y_true = label_binarize(
            [
                match.gt_class if match.gt_class is not None else "__NONE__"
                for match in matches
            ],
            classes=self.class_names,
        )
        y_pred = pd.DataFrame(
            [
                {
                    name: match.class_scores.get(name, 0)
                    for name in self.class_names
                }
                for match in matches
            ]
        ).fillna(0)
        mAP = metrics.average_precision_score(y_true, y_pred[self.class_names])

        ret = {
            "classification_mAP": float(mAP),
        }

        if not self.per_class:
            return ret

        for class_index, class_name in enumerate(self.class_names):
            y_true_class = y_true[:, class_index]
            y_pred_class = y_pred[class_name]
            class_ap = metrics.average_precision_score(
                y_true_class,
                y_pred_class,
            )
            ret[f"classification_AP/{class_name}"] = float(class_ap)

        return ret


class ClassificationAccuracy(MetricsProtocol):
    def __init__(self, class_names: List[str]):
        self.class_names = class_names

    def __call__(self, matches: List[Match]) -> Dict[str, float]:
        y_true = [
            match.gt_class if match.gt_class is not None else "__NONE__"
            for match in matches
        ]

        y_pred = pd.DataFrame(
            [
                {
                    name: match.class_scores.get(name, 0)
                    for name in self.class_names
                }
                for match in matches
            ]
        ).fillna(0)
        y_pred = y_pred.apply(
            lambda row: row.idxmax()
            if row.max() >= (1 - row.sum())
            else "__NONE__",
            axis=1,
        )

        accuracy = metrics.balanced_accuracy_score(
            y_true,
            y_pred,
        )

        return {
            "classification_acc": float(accuracy),
        }
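
For orientation, a tiny sketch of how these metrics consume matches, with hand-made Match values and purely illustrative scores:

from uuid import uuid4

from batdetect2.evaluate.metrics import DetectionAveragePrecision
from batdetect2.evaluate.types import Match

matches = [
    # A true detection scored confidently...
    Match(gt_uuid=uuid4(), gt_det=True, gt_class="pippip",
          pred_score=0.9, affinity=0.7, class_scores={"pippip": 0.8}),
    # ...and a false positive with a lower score.
    Match(gt_uuid=None, gt_det=False, gt_class=None,
          pred_score=0.4, affinity=0.0, class_scores={"pippip": 0.3}),
]

print(DetectionAveragePrecision()(matches))  # {'detection_AP': 1.0}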

View File

@@ -0,0 +1,22 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Protocol
from uuid import UUID

__all__ = [
    "MetricsProtocol",
    "Match",
]


@dataclass
class Match:
    gt_uuid: Optional[UUID]
    gt_det: bool
    gt_class: Optional[str]
    pred_score: float
    affinity: float
    class_scores: Dict[str, float]


class MetricsProtocol(Protocol):
    def __call__(self, matches: List[Match]) -> Dict[str, float]: ...
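
Because MetricsProtocol is structural, any callable with this signature plugs into ValidationMetrics. A hypothetical extra metric as a sketch (MeanAffinity is not part of this changeset):

from typing import Dict, List

from batdetect2.evaluate.types import Match, MetricsProtocol


class MeanAffinity(MetricsProtocol):
    """Hypothetical metric: mean geometric affinity over matched detections."""

    def __call__(self, matches: List[Match]) -> Dict[str, float]:
        affinities = [m.affinity for m in matches if m.gt_det]
        return {"mean_affinity": sum(affinities) / max(len(affinities), 1)}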

View File

@@ -170,8 +170,8 @@ def load_postprocess_config(
 def build_postprocessor(
     targets: TargetProtocol,
     config: Optional[PostprocessConfig] = None,
-    max_freq: int = MAX_FREQ,
+    max_freq: float = MAX_FREQ,
-    min_freq: int = MIN_FREQ,
+    min_freq: float = MIN_FREQ,
 ) -> PostprocessorProtocol:
     """Factory function to build the standard postprocessor.
@@ -234,9 +234,9 @@ class Postprocessor(PostprocessorProtocol):
         recovery.
     config : PostprocessConfig
         Configuration object holding parameters for NMS, thresholds, etc.
-    min_freq : int
+    min_freq : float
         Minimum frequency (Hz) assumed for the model output's frequency axis.
-    max_freq : int
+    max_freq : float
         Maximum frequency (Hz) assumed for the model output's frequency axis.
     """
@@ -246,8 +246,8 @@ class Postprocessor(PostprocessorProtocol):
         self,
         targets: TargetProtocol,
         config: PostprocessConfig,
-        min_freq: int = MIN_FREQ,
+        min_freq: float = MIN_FREQ,
-        max_freq: int = MAX_FREQ,
+        max_freq: float = MAX_FREQ,
     ):
         """Initialize the Postprocessor.

View File

@@ -32,10 +32,10 @@ from typing import List, Optional
 import numpy as np
 import xarray as xr
 from soundevent import data
-from soundevent.geometry import compute_bounds

 from batdetect2.postprocess.types import GeometryBuilder, RawPrediction
 from batdetect2.targets.classes import SoundEventDecoder
+from batdetect2.utils.arrays import iterate_over_array

 __all__ = [
     "convert_xr_dataset_to_raw_prediction",
@@ -97,18 +97,14 @@ def convert_xr_dataset_to_raw_prediction(
         det_info = detection_dataset.sel(detection=det_num)

         geom = geometry_builder(
-            (det_info.time, det_info.freq),
+            (det_info.time, det_info.frequency),
             det_info.dimensions,
         )
-        start_time, low_freq, end_time, high_freq = compute_bounds(geom)

         detections.append(
             RawPrediction(
-                detection_score=det_info.score,
+                detection_score=det_info.scores,
-                start_time=start_time,
+                geometry=geom,
-                end_time=end_time,
-                low_freq=low_freq,
-                high_freq=high_freq,
                 class_scores=det_info.classes,
                 features=det_info.features,
             )
@@ -244,14 +240,7 @@ def convert_raw_prediction_to_sound_event_prediction(
     """
     sound_event = data.SoundEvent(
         recording=recording,
-        geometry=data.BoundingBox(
+        geometry=raw_prediction.geometry,
-            coordinates=[
-                raw_prediction.start_time,
-                raw_prediction.low_freq,
-                raw_prediction.end_time,
-                raw_prediction.high_freq,
-            ]
-        ),
         features=get_prediction_features(raw_prediction.features),
     )
@@ -333,7 +322,7 @@ def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
             ),
             value=value,
         )
-        for feat_name, value in _iterate_over_array(features)
+        for feat_name, value in iterate_over_array(features)
     ]
@@ -394,13 +383,6 @@ def get_class_tags(
     return tags

-def _iterate_over_array(array: xr.DataArray):
-    dim_name = array.dims[0]
-    coords = array.coords[dim_name]
-    for value, coord in zip(array.values, coords.values):
-        yield coord, float(value)

 def _iterate_sorted(array: xr.DataArray):
     dim_name = array.dims[0]
     coords = array.coords[dim_name].values

View File

@@ -47,14 +47,9 @@ class RawPrediction(NamedTuple):
     Attributes
     ----------
-    start_time : float
+    geometry : data.Geometry
-        Start time of the recovered bounding box in seconds.
+        The recovered estimated geometry of the detected sound event.
-    end_time : float
+        Usually a bounding box.
-        End time of the recovered bounding box in seconds.
-    low_freq : float
-        Lowest frequency of the recovered bounding box in Hz.
-    high_freq : float
-        Highest frequency of the recovered bounding box in Hz.
     detection_score : float
         The confidence score associated with this detection, typically from
         the detection heatmap peak.
@@ -67,10 +62,7 @@ class RawPrediction(NamedTuple):
         detection location. Indexed by a 'feature' coordinate.
     """

-    start_time: float
+    geometry: data.Geometry
-    end_time: float
-    low_freq: float
-    high_freq: float
     detection_score: float
     class_scores: xr.DataArray
     features: xr.DataArray

View File

@@ -24,6 +24,7 @@ object is via the `build_targets` or `load_targets` functions.
 from typing import List, Optional

 import numpy as np
+from pydantic import Field
 from soundevent import data

 from batdetect2.configs import BaseConfig, load_config
@@ -157,7 +158,9 @@ class TargetConfig(BaseConfig):
     filtering: Optional[FilterConfig] = None
     transforms: Optional[TransformConfig] = None
-    classes: ClassesConfig
+    classes: ClassesConfig = Field(
+        default_factory=lambda: DEFAULT_CLASSES_CONFIG
+    )
     roi: Optional[ROIConfig] = None
@@ -438,6 +441,84 @@ class Targets(TargetProtocol):
         return self._roi_mapper.recover_roi(pos, dims)

+DEFAULT_CLASSES = [
+    TargetClass(
+        tags=[TagInfo(value="Myotis mystacinus")],
+        name="myomys",
+    ),
+    TargetClass(
+        tags=[TagInfo(value="Myotis alcathoe")],
+        name="myoalc",
+    ),
+    TargetClass(
+        tags=[TagInfo(value="Eptesicus serotinus")],
+        name="eptser",
+    ),
+    TargetClass(
+        tags=[TagInfo(value="Pipistrellus nathusii")],
+        name="pipnat",
+    ),
+    TargetClass(
+        tags=[TagInfo(value="Barbastellus barbastellus")],
+        name="barbar",
+    ),
+    TargetClass(
+        tags=[TagInfo(value="Myotis nattereri")],
+        name="myonat",
+    ),
+    TargetClass(
+        tags=[TagInfo(value="Myotis daubentonii")],
+        name="myodau",
+    ),
+    TargetClass(
+        tags=[TagInfo(value="Myotis brandtii")],
+        name="myobra",
+    ),
+    TargetClass(
+        tags=[TagInfo(value="Pipistrellus pipistrellus")],
+        name="pippip",
+    ),
+    TargetClass(
+        tags=[TagInfo(value="Myotis bechsteinii")],
+        name="myobec",
+    ),
+    TargetClass(
+        tags=[TagInfo(value="Pipistrellus pygmaeus")],
+        name="pippyg",
+    ),
+    TargetClass(
+        tags=[TagInfo(value="Rhinolophus hipposideros")],
+        name="rhihip",
+    ),
+    TargetClass(
+        tags=[TagInfo(value="Nyctalus leisleri")],
+        name="nyclei",
+    ),
+    TargetClass(
+        tags=[TagInfo(value="Rhinolophus ferrumequinum")],
+        name="rhifer",
+    ),
+    TargetClass(
+        tags=[TagInfo(value="Plecotus auritus")],
+        name="pleaur",
+    ),
+    TargetClass(
+        tags=[TagInfo(value="Nyctalus noctula")],
+        name="nycnoc",
+    ),
+    TargetClass(
+        tags=[TagInfo(value="Plecotus austriacus")],
+        name="pleaus",
+    ),
+]

+DEFAULT_CLASSES_CONFIG: ClassesConfig = ClassesConfig(
+    classes=DEFAULT_CLASSES,
+    generic_class=[TagInfo(value="Bat")],
+)

 DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
     filtering=FilterConfig(
         rules=[
@@ -455,79 +536,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
             ),
         ]
     ),
-    classes=ClassesConfig(
+    classes=DEFAULT_CLASSES_CONFIG,
-        classes=[
-            TargetClass(
-                tags=[TagInfo(value="Myotis mystacinus")],
-                name="myomys",
-            ),
-            TargetClass(
-                tags=[TagInfo(value="Myotis alcathoe")],
-                name="myoalc",
-            ),
-            TargetClass(
-                tags=[TagInfo(value="Eptesicus serotinus")],
-                name="eptser",
-            ),
-            TargetClass(
-                tags=[TagInfo(value="Pipistrellus nathusii")],
-                name="pipnat",
-            ),
-            TargetClass(
-                tags=[TagInfo(value="Barbastellus barbastellus")],
-                name="barbar",
-            ),
-            TargetClass(
-                tags=[TagInfo(value="Myotis nattereri")],
-                name="myonat",
-            ),
-            TargetClass(
-                tags=[TagInfo(value="Myotis daubentonii")],
-                name="myodau",
-            ),
-            TargetClass(
-                tags=[TagInfo(value="Myotis brandtii")],
-                name="myobra",
-            ),
-            TargetClass(
-                tags=[TagInfo(value="Pipistrellus pipistrellus")],
-                name="pippip",
-            ),
-            TargetClass(
-                tags=[TagInfo(value="Myotis bechsteinii")],
-                name="myobec",
-            ),
-            TargetClass(
-                tags=[TagInfo(value="Pipistrellus pygmaeus")],
-                name="pippyg",
-            ),
-            TargetClass(
-                tags=[TagInfo(value="Rhinolophus hipposideros")],
-                name="rhihip",
-            ),
-            TargetClass(
-                tags=[TagInfo(value="Nyctalus leisleri")],
-                name="nyclei",
-            ),
-            TargetClass(
-                tags=[TagInfo(value="Rhinolophus ferrumequinum")],
-                name="rhifer",
-            ),
-            TargetClass(
-                tags=[TagInfo(value="Plecotus auritus")],
-                name="pleaur",
-            ),
-            TargetClass(
-                tags=[TagInfo(value="Nyctalus noctula")],
-                name="nycnoc",
-            ),
-            TargetClass(
-                tags=[TagInfo(value="Plecotus austriacus")],
-                name="pleaus",
-            ),
-        ],
-        generic_class=[TagInfo(value="Bat")],
-    ),
 )

View File

@@ -106,7 +106,7 @@ def contains_tags(
         False otherwise.
     """
     sound_event_tags = set(sound_event_annotation.tags)
-    return tags < sound_event_tags
+    return tags <= sound_event_tags

 def does_not_have_tags(
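
The operator swap matters because `<` on Python sets is a proper-subset test, so a rule whose tags exactly equal the annotation's tags would silently fail to match. A short illustration with hypothetical tag tuples:

rule_tags = {("species", "Myotis daubentonii")}
annotation_tags = {("species", "Myotis daubentonii")}
assert not (rule_tags < annotation_tags)  # proper subset: False for equal sets
assert rule_tags <= annotation_tags       # subset-or-equal: True, as intended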

View File

@@ -20,14 +20,27 @@ scaling factors) is managed by the `ROIConfig`. This module separates the
 handled in `batdetect2.targets.classes`.
 """

-from typing import List, Optional, Protocol, Tuple
+from typing import List, Literal, Optional, Protocol, Tuple

 import numpy as np
-from soundevent import data, geometry
+from soundevent import data
-from soundevent.geometry.operations import Positions

 from batdetect2.configs import BaseConfig, load_config

+Positions = Literal[
+    "bottom-left",
+    "bottom-right",
+    "top-left",
+    "top-right",
+    "center-left",
+    "center-right",
+    "top-center",
+    "bottom-center",
+    "center",
+    "centroid",
+    "point_on_surface",
+]

 __all__ = [
     "ROITargetMapper",
     "ROIConfig",
@@ -242,6 +255,8 @@ class BBoxEncoder(ROITargetMapper):
         Tuple[float, float]
             Reference position (time, frequency).
         """
+        from soundevent import geometry

         return geometry.get_geometry_point(geom, position=self.position)

     def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
@@ -260,6 +275,8 @@ class BBoxEncoder(ROITargetMapper):
         np.ndarray
             A 1D NumPy array: `[scaled_width, scaled_height]`.
         """
+        from soundevent import geometry

         start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
             geom
         )
@@ -308,8 +325,8 @@ class BBoxEncoder(ROITargetMapper):
         width, height = dims
         return _build_bounding_box(
             pos,
-            duration=width / self.time_scale,
+            duration=float(width) / self.time_scale,
-            bandwidth=height / self.frequency_scale,
+            bandwidth=float(height) / self.frequency_scale,
             position=self.position,
         )
@@ -421,14 +438,16 @@ def _build_bounding_box(
     ValueError
         If `position` is not a recognized value or format.
     """
-    time, freq = pos
+    time, freq = map(float, pos)
+    duration = max(0, duration)
+    bandwidth = max(0, bandwidth)
     if position in ["center", "centroid", "point_on_surface"]:
         return data.BoundingBox(
             coordinates=[
-                time - duration / 2,
+                max(time - duration / 2, 0),
-                freq - bandwidth / 2,
+                max(freq - bandwidth / 2, 0),
-                time + duration / 2,
+                max(time + duration / 2, 0),
-                freq + bandwidth / 2,
+                max(freq + bandwidth / 2, 0),
             ]
         )
@@ -454,9 +473,9 @@ def _build_bounding_box(
     return data.BoundingBox(
         coordinates=[
-            start_time,
+            max(0, start_time),
-            low_freq,
+            max(0, low_freq),
-            start_time + duration,
+            max(0, start_time + duration),
-            low_freq + bandwidth,
+            max(0, low_freq + bandwidth),
         ]
     )
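
The effect of the new clamping in illustrative numbers: a recovered box for a call centred 5 ms into the clip with a 20 ms duration used to start at a negative time.

time, duration = 0.005, 0.020
old_start = time - duration / 2          # -0.005, an invalid negative time
new_start = max(time - duration / 2, 0)  # 0.0, clamped to the clip start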

View File

@@ -14,28 +14,47 @@ from batdetect2.train.augmentations import (
     warp_spectrogram,
 )
 from batdetect2.train.clips import build_clipper, select_subclip
-from batdetect2.train.config import TrainingConfig, load_train_config
+from batdetect2.train.config import (
+    TrainerConfig,
+    TrainingConfig,
+    load_train_config,
+)
 from batdetect2.train.dataset import (
     LabeledDataset,
     RandomExampleSource,
     TrainExample,
     list_preprocessed_files,
 )
-from batdetect2.train.labels import load_label_config
+from batdetect2.train.labels import build_clip_labeler, load_label_config
-from batdetect2.train.losses import LossFunction, build_loss
+from batdetect2.train.losses import (
+    ClassificationLossConfig,
+    DetectionLossConfig,
+    LossConfig,
+    LossFunction,
+    SizeLossConfig,
+    build_loss,
+)
 from batdetect2.train.preprocess import (
     generate_train_example,
     preprocess_annotations,
 )
-from batdetect2.train.train import TrainerConfig, load_trainer_config, train
+from batdetect2.train.train import (
+    build_train_dataset,
+    build_val_dataset,
+    train,
+)

 __all__ = [
     "AugmentationsConfig",
+    "ClassificationLossConfig",
+    "DetectionLossConfig",
     "EchoAugmentationConfig",
     "FrequencyMaskAugmentationConfig",
     "LabeledDataset",
+    "LossConfig",
     "LossFunction",
     "RandomExampleSource",
+    "SizeLossConfig",
     "TimeMaskAugmentationConfig",
     "TrainExample",
     "TrainerConfig",
@@ -44,13 +63,15 @@ __all__ = [
     "WarpAugmentationConfig",
     "add_echo",
     "build_augmentations",
+    "build_clip_labeler",
     "build_clipper",
     "build_loss",
+    "build_train_dataset",
+    "build_val_dataset",
     "generate_train_example",
     "list_preprocessed_files",
     "load_label_config",
     "load_train_config",
-    "load_trainer_config",
     "mask_frequency",
     "mask_time",
     "mix_examples",
@@ -58,5 +79,6 @@ __all__ = [
     "scale_volume",
     "select_subclip",
     "train",
+    "train",
     "warp_spectrogram",
 ]

View File

@@ -1,30 +1,52 @@
+from typing import List

 from lightning import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
+from soundevent import data
 from torch.utils.data import DataLoader

-from batdetect2.postprocess import PostprocessorProtocol
+from batdetect2.evaluate.match import match_sound_events_and_raw_predictions
+from batdetect2.evaluate.types import Match, MetricsProtocol
+from batdetect2.targets.types import TargetProtocol
 from batdetect2.train.dataset import LabeledDataset, TrainExample
-from batdetect2.types import ModelOutput
+from batdetect2.train.lightning import TrainingModule
+from batdetect2.train.types import ModelOutput

 class ValidationMetrics(Callback):
-    def __init__(self, postprocessor: PostprocessorProtocol):
+    def __init__(self, metrics: List[MetricsProtocol]):
         super().__init__()
-        self.postprocessor = postprocessor
-        self.predictions = []
+        if len(metrics) == 0:
+            raise ValueError("At least one metric needs to be provided")
+        self.matches: List[Match] = []
+        self.metrics = metrics

+    def on_validation_epoch_end(
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,
+    ) -> None:
+        metrics = {}
+        for metric in self.metrics:
+            metrics.update(metric(self.matches).items())
+        pl_module.log_dict(metrics)
+        return super().on_validation_epoch_end(trainer, pl_module)

     def on_validation_epoch_start(
         self,
         trainer: Trainer,
         pl_module: LightningModule,
     ) -> None:
-        self.predictions = []
+        self.matches = []
         return super().on_validation_epoch_start(trainer, pl_module)

     def on_validation_batch_end(  # type: ignore
         self,
         trainer: Trainer,
-        pl_module: LightningModule,
+        pl_module: TrainingModule,
         outputs: ModelOutput,
         batch: TrainExample,
         batch_idx: int,
@@ -32,24 +54,73 @@ class ValidationMetrics(Callback):
     ) -> None:
         dataloaders = trainer.val_dataloaders
         assert isinstance(dataloaders, DataLoader)
         dataset = dataloaders.dataset
         assert isinstance(dataset, LabeledDataset)
-        clip_annotation = dataset.get_clip_annotation(batch_idx)
-        # clip_prediction = postprocess_model_outputs(
+        clip_annotations = [
-        #     outputs,
+            _get_subclip(
-        #     clips=[clip_annotation.clip],
+                dataset.get_clip_annotation(example_id),
-        #     classes=self.class_names,
+                start_time=start_time.item(),
-        #     decoder=self.decoder,
+                end_time=end_time.item(),
-        #     config=self.config.postprocessing,
+                targets=pl_module.targets,
-        # )[0]
+            )
-        #
+            for example_id, start_time, end_time in zip(
-        # matches = match_predictions_and_annotations(
+                batch.idx,
-        #     clip_annotation,
+                batch.start_time,
-        #     clip_prediction,
+                batch.end_time,
-        # )
+            )
-        #
+        ]
-        # self.validation_predictions.extend(matches)
-        # return super().on_validation_batch_end(
+        clips = [clip_annotation.clip for clip_annotation in clip_annotations]
-        #     trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
-        # )
+        raw_predictions = pl_module.postprocessor.get_raw_predictions(
+            outputs,
+            clips,
+        )
+        for clip_annotation, clip_predictions in zip(
+            clip_annotations, raw_predictions
+        ):
+            self.matches.extend(
+                match_sound_events_and_raw_predictions(
+                    sound_events=clip_annotation.sound_events,
+                    raw_predictions=clip_predictions,
+                    targets=pl_module.targets,
+                )
+            )

+def _is_in_subclip(
+    sound_event_annotation: data.SoundEventAnnotation,
+    targets: TargetProtocol,
+    start_time: float,
+    end_time: float,
+) -> bool:
+    time, _ = targets.get_position(sound_event_annotation)
+    return start_time <= time <= end_time

+def _get_subclip(
+    clip_annotation: data.ClipAnnotation,
+    start_time: float,
+    end_time: float,
+    targets: TargetProtocol,
+) -> data.ClipAnnotation:
+    return data.ClipAnnotation(
+        clip=data.Clip(
+            recording=clip_annotation.clip.recording,
+            start_time=start_time,
+            end_time=end_time,
+        ),
+        sound_events=[
+            sound_event_annotation
+            for sound_event_annotation in clip_annotation.sound_events
+            if _is_in_subclip(
+                sound_event_annotation,
+                targets,
+                start_time=start_time,
+                end_time=end_time,
+            )
+        ],
+    )

View File

@@ -69,12 +69,15 @@ class Clipper(ClipperProtocol):
     )

-def build_clipper(config: Optional[ClipingConfig] = None) -> ClipperProtocol:
+def build_clipper(
+    config: Optional[ClipingConfig] = None,
+    random: Optional[bool] = None,
+) -> ClipperProtocol:
     config = config or ClipingConfig()
     return Clipper(
         duration=config.duration,
         max_empty=config.max_empty,
-        random=config.random,
+        random=config.random if random else False,
     )

View File

@@ -1,7 +1,7 @@
-from typing import Optional
+from typing import Optional, Union

 from pydantic import Field
-from soundevent.data import PathLike
+from soundevent import data

 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.train.augmentations import (
@@ -23,8 +23,29 @@ class OptimizerConfig(BaseConfig):
     t_max: int = 100

+class TrainerConfig(BaseConfig):
+    accelerator: str = "auto"
+    accumulate_grad_batches: int = 1
+    deterministic: bool = True
+    check_val_every_n_epoch: int = 1
+    devices: Union[str, int] = "auto"
+    enable_checkpointing: bool = True
+    gradient_clip_val: Optional[float] = None
+    limit_train_batches: Optional[Union[int, float]] = None
+    limit_test_batches: Optional[Union[int, float]] = None
+    limit_val_batches: Optional[Union[int, float]] = None
+    log_every_n_steps: Optional[int] = None
+    max_epochs: Optional[int] = 200
+    min_epochs: Optional[int] = None
+    max_steps: Optional[int] = None
+    min_steps: Optional[int] = None
+    max_time: Optional[str] = None
+    precision: Optional[str] = None
+    val_check_interval: Optional[Union[int, float]] = None

 class TrainingConfig(BaseConfig):
-    batch_size: int = 32
+    batch_size: int = 8
     loss: LossConfig = Field(default_factory=LossConfig)
@@ -36,9 +57,11 @@ class TrainingConfig(BaseConfig):
     cliping: ClipingConfig = Field(default_factory=ClipingConfig)
+    trainer: TrainerConfig = Field(default_factory=TrainerConfig)

 def load_train_config(
-    path: PathLike,
+    path: data.PathLike,
     field: Optional[str] = None,
 ) -> TrainingConfig:
     return load_config(path, schema=TrainingConfig, field=field)
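
A minimal usage sketch of the now-nested trainer settings (field values are assumptions for illustration): trainer options live inside the training config, and `train` forwards them to lightning's Trainer via model_dump.

from batdetect2.train.config import TrainerConfig, TrainingConfig

config = TrainingConfig(
    batch_size=8,
    trainer=TrainerConfig(max_epochs=200, accelerator="auto"),
)

# exclude_none=True mirrors how `train` unpacks these values into Trainer(...)
print(config.trainer.model_dump(exclude_none=True))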

View File

@@ -42,8 +42,8 @@ class LabeledDataset(Dataset):
             class_heatmap=self.to_tensor(dataset["class"]),
             size_heatmap=self.to_tensor(dataset["size"]),
             idx=torch.tensor(idx),
-            start_time=start_time,
+            start_time=torch.tensor(start_time),
-            end_time=end_time,
+            end_time=torch.tensor(end_time),
         )

     @classmethod
@@ -81,17 +81,24 @@ class LabeledDataset(Dataset):
         array: xr.DataArray,
         dtype=np.float32,
     ) -> torch.Tensor:
-        return torch.tensor(array.values.astype(dtype))
+        return torch.nan_to_num(
+            torch.tensor(array.values.astype(dtype)),
+            nan=0,
+        )

 def list_preprocessed_files(
     directory: data.PathLike, extension: str = ".nc"
-) -> Sequence[Path]:
+) -> List[Path]:
     return list(Path(directory).glob(f"*{extension}"))

 class RandomExampleSource:
-    def __init__(self, filenames: List[str], clipper: ClipperProtocol):
+    def __init__(
+        self,
+        filenames: List[data.PathLike],
+        clipper: ClipperProtocol,
+    ):
         self.filenames = filenames
         self.clipper = clipper
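
Why the nan_to_num wrapper matters, in a two-line sketch: a NaN anywhere in a spectrogram tensor would propagate through the loss, so NaNs are zeroed at tensor-conversion time.

import torch

spec = torch.tensor([0.1, float("nan"), 0.3])
print(torch.nan_to_num(spec, nan=0))  # the NaN entry becomes 0.0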

View File

@@ -23,13 +23,13 @@ parameter specific to this module is the Gaussian smoothing sigma (`sigma`)
 defined in `LabelConfig`.
 """

-import logging
 from collections.abc import Iterable
 from functools import partial
 from typing import Optional

 import numpy as np
 import xarray as xr
+from loguru import logger
 from scipy.ndimage import gaussian_filter
 from soundevent import arrays, data
@@ -52,8 +52,6 @@ __all__ = [
 SIZE_DIMENSION = "dimension"
 """Dimension name for the size heatmap."""

-logger = logging.getLogger(__name__)

 class LabelConfig(BaseConfig):
     """Configuration parameters for heatmap generation.
@@ -137,12 +135,27 @@ def generate_clip_label(
         A NamedTuple containing the generated 'detection', 'classes', and 'size'
         heatmaps for this clip.
     """
+    logger.debug(
+        "Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events",
+        uuid=clip_annotation.uuid,
+        num=len(clip_annotation.sound_events)
+    )
+    sound_events = []
+    for sound_event_annotation in clip_annotation.sound_events:
+        if not targets.filter(sound_event_annotation):
+            logger.debug(
+                "Sound event {sound_event} did not pass the filter. Tags: {tags}",
+                sound_event=sound_event_annotation,
+                tags=sound_event_annotation.tags,
+            )
+            continue
+        sound_events.append(targets.transform(sound_event_annotation))
     return generate_heatmaps(
-        (
+        sound_events,
-            targets.transform(sound_event_annotation)
-            for sound_event_annotation in clip_annotation.sound_events
-            if targets.filter(sound_event_annotation)
-        ),
         spec=spec,
         targets=targets,
         target_sigma=config.sigma,

View File

@@ -40,7 +40,9 @@ class TrainingModule(L.LightningModule):
         self.learning_rate = learning_rate
         self.t_max = t_max

-        self.save_hyperparameters()
+        # NOTE: Ignore detector and loss from hyperparameter saving
+        # as they are nn.Module and should be saved regardless.
+        self.save_hyperparameters(ignore=["detector", "loss"])

     def forward(self, spec: torch.Tensor) -> ModelOutput:
         return self.detector(spec)
@@ -49,21 +51,25 @@ class TrainingModule(L.LightningModule):
         outputs = self.forward(batch.spec)
         losses = self.loss(outputs, batch)

-        self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
+        self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
-        self.log("train/loss/detection", losses.total, logger=True)
+        self.log("detection_loss/train", losses.total, logger=True)
-        self.log("train/loss/size", losses.total, logger=True)
+        self.log("size_loss/train", losses.total, logger=True)
-        self.log("train/loss/classification", losses.total, logger=True)
+        self.log("classification_loss/train", losses.total, logger=True)

         return losses.total

-    def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
+    def validation_step(  # type: ignore
+        self, batch: TrainExample, batch_idx: int
+    ) -> ModelOutput:
         outputs = self.forward(batch.spec)
         losses = self.loss(outputs, batch)

-        self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
+        self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
-        self.log("val/loss/detection", losses.total, logger=True)
+        self.log("detection_loss/val", losses.total, logger=True)
-        self.log("val/loss/size", losses.total, logger=True)
+        self.log("size_loss/val", losses.total, logger=True)
-        self.log("val/loss/classification", losses.total, logger=True)
+        self.log("classification_loss/val", losses.total, logger=True)
+        return outputs

     def configure_optimizers(self):
         optimizer = Adam(self.parameters(), lr=self.learning_rate)

View File

@@ -1,68 +1,147 @@
-from typing import Optional, Union
+from typing import List, Optional

-from lightning import LightningModule
-from lightning.pytorch import Trainer
-from soundevent.data import PathLike
+from lightning import Trainer
+from lightning.pytorch.callbacks import Callback
+from soundevent import data
 from torch.utils.data import DataLoader

-from batdetect2.configs import BaseConfig, load_config
-from batdetect2.train.dataset import LabeledDataset
+from batdetect2.models.types import DetectionModel
+from batdetect2.postprocess import build_postprocessor
+from batdetect2.postprocess.types import PostprocessorProtocol
+from batdetect2.preprocess import build_preprocessor
+from batdetect2.preprocess.types import PreprocessorProtocol
+from batdetect2.targets import build_targets
+from batdetect2.targets.types import TargetProtocol
+from batdetect2.train.augmentations import (
+    build_augmentations,
+)
+from batdetect2.train.clips import build_clipper
+from batdetect2.train.config import TrainingConfig
+from batdetect2.train.dataset import LabeledDataset, RandomExampleSource
+from batdetect2.train.lightning import TrainingModule
+from batdetect2.train.losses import build_loss

 __all__ = [
     "train",
-    "TrainerConfig",
-    "load_trainer_config",
+    "build_val_dataset",
+    "build_train_dataset",
 ]


-class TrainerConfig(BaseConfig):
-    accelerator: str = "auto"
-    accumulate_grad_batches: int = 1
-    deterministic: bool = True
-    check_val_every_n_epoch: int = 1
-    devices: Union[str, int] = "auto"
-    enable_checkpointing: bool = True
-    gradient_clip_val: Optional[float] = None
-    limit_train_batches: Optional[Union[int, float]] = None
-    limit_test_batches: Optional[Union[int, float]] = None
-    limit_val_batches: Optional[Union[int, float]] = None
-    log_every_n_steps: Optional[int] = None
-    max_epochs: Optional[int] = None
-    min_epochs: Optional[int] = 100
-    max_steps: Optional[int] = None
-    min_steps: Optional[int] = None
-    max_time: Optional[str] = None
-    precision: Optional[str] = None
-    reload_dataloaders_every_n_epochs: Optional[int] = None
-    val_check_interval: Optional[Union[int, float]] = None
-
-
-def load_trainer_config(path: PathLike, field: Optional[str] = None):
-    return load_config(path, schema=TrainerConfig, field=field)
-
-
 def train(
-    module: LightningModule,
-    train_dataset: LabeledDataset,
-    trainer_config: Optional[TrainerConfig] = None,
-    dev_run: bool = False,
-    overfit_batches: bool = False,
-    profiler: Optional[str] = None,
-):
-    trainer_config = trainer_config or TrainerConfig()
+    detector: DetectionModel,
+    train_examples: List[data.PathLike],
+    targets: Optional[TargetProtocol] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    postprocessor: Optional[PostprocessorProtocol] = None,
+    val_examples: Optional[List[data.PathLike]] = None,
+    config: Optional[TrainingConfig] = None,
+    callbacks: Optional[List[Callback]] = None,
+    model_path: Optional[data.PathLike] = None,
+    train_workers: int = 0,
+    val_workers: int = 0,
+    **trainer_kwargs,
+) -> None:
+    config = config or TrainingConfig()
+
+    if model_path is None:
+        if preprocessor is None:
+            preprocessor = build_preprocessor()
+
+        if targets is None:
+            targets = build_targets()
+
+        if postprocessor is None:
+            postprocessor = build_postprocessor(
+                targets,
+                min_freq=preprocessor.min_freq,
+                max_freq=preprocessor.max_freq,
+            )
+
+        loss = build_loss(config.loss)
+
+        module = TrainingModule(
+            detector=detector,
+            loss=loss,
+            targets=targets,
+            preprocessor=preprocessor,
+            postprocessor=postprocessor,
+            learning_rate=config.optimizer.learning_rate,
+            t_max=config.optimizer.t_max,
+        )
+    else:
+        module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
+
+    train_dataset = build_train_dataset(
+        train_examples,
+        preprocessor=module.preprocessor,
+        config=config,
+    )

     trainer = Trainer(
-        **trainer_config.model_dump(
-            exclude_unset=True,
-            exclude_none=True,
-        ),
-        fast_dev_run=dev_run,
-        overfit_batches=overfit_batches,
-        profiler=profiler,
+        **config.trainer.model_dump(exclude_none=True),
+        callbacks=callbacks,
+        **trainer_kwargs,
     )

-    train_loader = DataLoader(
+    train_dataloader = DataLoader(
         train_dataset,
-        batch_size=module.config.train.batch_size,
+        batch_size=config.batch_size,
         shuffle=True,
-        num_workers=7,
+        num_workers=train_workers,
     )

-    trainer.fit(module, train_dataloaders=train_loader)
+    val_dataloader = None
+    if val_examples:
+        val_dataset = build_val_dataset(
+            val_examples,
+            config=config,
+        )
+        val_dataloader = DataLoader(
+            val_dataset,
+            batch_size=config.batch_size,
+            shuffle=False,
+            num_workers=val_workers,
+        )
+
+    trainer.fit(
+        module,
+        train_dataloaders=train_dataloader,
+        val_dataloaders=val_dataloader,
+    )
+
+
+def build_train_dataset(
+    examples: List[data.PathLike],
+    preprocessor: PreprocessorProtocol,
+    config: Optional[TrainingConfig] = None,
+) -> LabeledDataset:
+    config = config or TrainingConfig()
+
+    clipper = build_clipper(config.cliping, random=True)
+
+    random_example_source = RandomExampleSource(
+        examples,
+        clipper=clipper,
+    )
+
+    augmentations = build_augmentations(
+        preprocessor,
+        config=config.augmentations,
+        example_source=random_example_source,
+    )
+
+    return LabeledDataset(
+        examples,
+        clipper=clipper,
+        augmentation=augmentations,
+    )
+
+
+def build_val_dataset(
+    examples: List[data.PathLike],
+    config: Optional[TrainingConfig] = None,
+    train: bool = True,
+) -> LabeledDataset:
+    config = config or TrainingConfig()
+    clipper = build_clipper(config.cliping, random=train)
+    return LabeledDataset(examples, clipper=clipper)
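With this rewrite, train() is self-sufficient: given only a detector and example paths it builds default targets, preprocessor, postprocessor, loss, and datasets, and any extra keyword arguments fall through to the Lightning Trainer. A sketch of a minimal call, assuming the function is importable from batdetect2.train and that the detector and file paths exist (both are assumptions, not part of this diff):

    from pathlib import Path

    from batdetect2.train import train

    # Hypothetical inputs: preprocessed training examples written out by the
    # `preprocess` CLI command.
    train_examples = sorted(Path("outputs/train").glob("*.nc"))
    val_examples = sorted(Path("outputs/val").glob("*.nc"))

    train(
        detector=my_detector,  # assumed DetectionModel instance built elsewhere
        train_examples=train_examples,
        val_examples=val_examples,
        train_workers=4,
        val_workers=2,
        max_epochs=50,  # forwarded to the Lightning Trainer via **trainer_kwargs
    )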

View File

@@ -57,8 +57,8 @@ class TrainExample(NamedTuple):
     class_heatmap: torch.Tensor
     size_heatmap: torch.Tensor
     idx: torch.Tensor
-    start_time: float
-    end_time: float
+    start_time: torch.Tensor
+    end_time: torch.Tensor


 class Losses(NamedTuple):
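Typing start_time and end_time as tensors matches what actually arrives in a batched TrainExample: PyTorch's default collate function turns per-example Python floats into a single stacked tensor. A quick demonstration of that behaviour:

    from torch.utils.data import default_collate

    # Three examples, each carrying a scalar start time.
    batch = default_collate([0.0, 1.5, 3.0])
    print(batch)        # tensor([0.0000, 1.5000, 3.0000], dtype=torch.float64)
    print(type(batch))  # <class 'torch.Tensor'>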

View File

@@ -1,4 +1,5 @@
 import numpy as np
+import xarray as xr


 def extend_width(
@@ -59,3 +60,10 @@ def adjust_width(
         for index in range(dims)
     ]
     return array[tuple(slices)]
+
+
+def iterate_over_array(array: xr.DataArray):
+    dim_name = array.dims[0]
+    coords = array.coords[dim_name]
+    for value, coord in zip(array.values, coords.values):
+        yield coord, float(value)
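The new helper walks a one-dimensional DataArray and yields (coordinate, value) pairs, with each value cast to a plain float. For instance (the import path is an assumption based on the file being edited):

    import numpy as np
    import xarray as xr

    from batdetect2.utils.arrays import iterate_over_array  # assumed path

    scores = xr.DataArray(
        np.array([0.9, 0.1, 0.7]),
        coords={"time": [0.0, 0.5, 1.0]},
        dims=["time"],
    )

    for time, score in iterate_over_array(scores):
        print(time, score)
    # 0.0 0.9
    # 0.5 0.1
    # 1.0 0.7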

View File

@@ -7,6 +7,8 @@ import pytest
 import soundfile as sf
 from soundevent import data, terms

+from batdetect2.data import DatasetConfig, load_dataset
+from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.preprocess.types import PreprocessorProtocol
 from batdetect2.targets import (
@@ -383,3 +385,27 @@ def sample_labeller(
     sample_targets: TargetProtocol,
 ) -> ClipLabeller:
     return build_clip_labeler(sample_targets)
+
+
+@pytest.fixture
+def example_dataset(example_data_dir: Path) -> DatasetConfig:
+    return DatasetConfig(
+        name="test dataset",
+        description="test dataset",
+        sources=[
+            BatDetect2FilesAnnotations(
+                name="example annotations",
+                audio_dir=example_data_dir / "audio",
+                annotations_dir=example_data_dir / "anns",
+            )
+        ],
+    )
+
+
+@pytest.fixture
+def example_annotations(
+    example_dataset: DatasetConfig,
+) -> List[data.ClipAnnotation]:
+    annotations = load_dataset(example_dataset)
+    assert len(annotations) == 3
+    return annotations

View File

@@ -254,11 +254,11 @@ class TestLoadBatDetect2Files:
         assert clip_ann.clip.recording.duration == 5.0
         assert len(clip_ann.sound_events) == 1
         assert clip_ann.notes[0].message == "Standard notes."
-        clip_tag = data.find_tag(clip_ann.tags, "class")
+        clip_tag = data.find_tag(clip_ann.tags, "Class")
         assert clip_tag is not None
         assert clip_tag.value == "Myotis"

-        recording_tag = data.find_tag(clip_ann.clip.recording.tags, "class")
+        recording_tag = data.find_tag(clip_ann.clip.recording.tags, "Class")
         assert recording_tag is not None
         assert recording_tag.value == "Myotis"
@@ -271,15 +271,15 @@ class TestLoadBatDetect2Files:
             40000,
         ]

-        se_class_tag = data.find_tag(se_ann.tags, "class")
+        se_class_tag = data.find_tag(se_ann.tags, "Class")
         assert se_class_tag is not None
         assert se_class_tag.value == "Myotis"

-        se_event_tag = data.find_tag(se_ann.tags, "event")
+        se_event_tag = data.find_tag(se_ann.tags, "Call Type")
         assert se_event_tag is not None
         assert se_event_tag.value == "Echolocation"

-        se_individual_tag = data.find_tag(se_ann.tags, "individual")
+        se_individual_tag = data.find_tag(se_ann.tags, "Individual")
         assert se_individual_tag is not None
         assert se_individual_tag.value == "0"
@@ -439,7 +439,7 @@ class TestLoadBatDetect2Merged:
         assert clip_ann.clip.recording.duration == 5.0
         assert len(clip_ann.sound_events) == 1

-        clip_class_tag = data.find_tag(clip_ann.tags, "class")
+        clip_class_tag = data.find_tag(clip_ann.tags, "Class")
         assert clip_class_tag is not None
         assert clip_class_tag.value == "Myotis"

View File

@@ -98,7 +98,7 @@ def sample_detection_dataset() -> xr.Dataset:
     expected_freqs = np.array([300, 200])
     detection_coords = {
         "time": ("detection", expected_times),
-        "freq": ("detection", expected_freqs),
+        "frequency": ("detection", expected_freqs),
     }

     scores_data = np.array([0.9, 0.8], dtype=np.float64)
@@ -106,7 +106,7 @@ def sample_detection_dataset() -> xr.Dataset:
         scores_data,
         coords=detection_coords,
         dims=["detection"],
-        name="scores",
+        name="score",
     )

     dimensions_data = np.array([[7.0, 16.0], [3.0, 12.0]], dtype=np.float32)
@@ -183,7 +183,7 @@ def empty_detection_dataset() -> xr.Dataset:
     )
     return xr.Dataset(
         {
-            "scores": scores,
+            "score": scores,
             "dimensions": dimensions,
             "classes": classes,
             "features": features,
@@ -206,10 +206,14 @@ def sample_raw_predictions() -> List[RawPrediction]:
     )
     pred1 = RawPrediction(
         detection_score=0.9,
-        start_time=20 - 7 / 2,
-        end_time=20 + 7 / 2,
-        low_freq=300 - 16 / 2,
-        high_freq=300 + 16 / 2,
+        geometry=data.BoundingBox(
+            coordinates=[
+                20 - 7 / 2,
+                300 - 16 / 2,
+                20 + 7 / 2,
+                300 + 16 / 2,
+            ]
+        ),
         class_scores=pred1_classes,
         features=pred1_features,
     )
@@ -224,10 +228,14 @@ def sample_raw_predictions() -> List[RawPrediction]:
     )
     pred2 = RawPrediction(
         detection_score=0.8,
-        start_time=10 - 3 / 2,
-        end_time=10 + 3 / 2,
-        low_freq=200 - 12 / 2,
-        high_freq=200 + 12 / 2,
+        geometry=data.BoundingBox(
+            coordinates=[
+                10 - 3 / 2,
+                200 - 12 / 2,
+                10 + 3 / 2,
+                200 + 12 / 2,
+            ]
+        ),
         class_scores=pred2_classes,
         features=pred2_features,
     )
@@ -242,10 +250,14 @@ def sample_raw_predictions() -> List[RawPrediction]:
     )
     pred3 = RawPrediction(
         detection_score=0.15,
-        start_time=5.0,
-        end_time=6.0,
-        low_freq=50.0,
-        high_freq=60.0,
+        geometry=data.BoundingBox(
+            coordinates=[
+                5.0,
+                50.0,
+                6.0,
+                60.0,
+            ]
+        ),
         class_scores=pred3_classes,
         features=pred3_features,
     )
@@ -267,10 +279,12 @@ def test_convert_xr_dataset_basic(
     assert isinstance(pred1, RawPrediction)
     assert pred1.detection_score == 0.9
-    assert pred1.start_time == 20 - 7 / 2
-    assert pred1.end_time == 20 + 7 / 2
-    assert pred1.low_freq == 300 - 16 / 2
-    assert pred1.high_freq == 300 + 16 / 2
+    assert pred1.geometry.coordinates == [
+        20 - 7 / 2,
+        300 - 16 / 2,
+        20 + 7 / 2,
+        300 + 16 / 2,
+    ]
     xr.testing.assert_allclose(
         pred1.class_scores,
         sample_detection_dataset["classes"].sel(detection=0),
@@ -283,10 +297,12 @@ def test_convert_xr_dataset_basic(
     assert isinstance(pred2, RawPrediction)
     assert pred2.detection_score == 0.8
-    assert pred2.start_time == 10 - 3 / 2
-    assert pred2.end_time == 10 + 3 / 2
-    assert pred2.low_freq == 200 - 12 / 2
-    assert pred2.high_freq == 200 + 12 / 2
+    assert pred2.geometry.coordinates == [
+        10 - 3 / 2,
+        200 - 12 / 2,
+        10 + 3 / 2,
+        200 + 12 / 2,
+    ]
     xr.testing.assert_allclose(
         pred2.class_scores,
         sample_detection_dataset["classes"].sel(detection=1),
@@ -331,15 +347,7 @@ def test_convert_raw_to_sound_event_basic(
     assert isinstance(se, data.SoundEvent)
     assert se.recording == sample_recording
     assert isinstance(se.geometry, data.BoundingBox)
-    np.testing.assert_allclose(
-        se.geometry.coordinates,
-        [
-            raw_pred.start_time,
-            raw_pred.low_freq,
-            raw_pred.end_time,
-            raw_pred.high_freq,
-        ],
-    )
+    assert se.geometry == raw_pred.geometry

     assert len(se.features) == len(raw_pred.features)
     feat_dict = {f.term.name: f.value for f in se.features}
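RawPrediction now stores its extent as a soundevent BoundingBox instead of four separate scalars, with coordinates ordered [start_time, low_freq, end_time, high_freq], as the updated fixtures show. A small sketch of the equivalence:

    from soundevent import data

    # Same detection expressed both ways; the BoundingBox packs the four
    # scalars in [start_time, low_freq, end_time, high_freq] order.
    start_time, end_time = 20 - 7 / 2, 20 + 7 / 2
    low_freq, high_freq = 300 - 16 / 2, 300 + 16 / 2

    geometry = data.BoundingBox(
        coordinates=[start_time, low_freq, end_time, high_freq]
    )
    assert geometry.coordinates == [16.5, 292.0, 23.5, 308.0]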

View File

@@ -1,6 +1,6 @@
 import math
 from pathlib import Path
-from typing import Callable
+from typing import Callable, Union

 import numpy as np
 import pytest
@@ -307,7 +307,7 @@ def test_remove_spectral_mean_constant(constant_wave_xr: xr.DataArray):
 def test_resize_spectrogram(
     sample_spec: xr.DataArray,
     height: int,
-    resize_factor: float | None,
+    resize_factor: Union[float, None],
     expected_freq_size: int,
     expected_time_factor: float,
 ):

View File

@@ -4,6 +4,7 @@ from typing import Callable, List, Set
 import pytest
 from soundevent import data

+from batdetect2.targets import build_targets
 from batdetect2.targets.filtering import (
     FilterConfig,
     FilterRule,
@@ -176,3 +177,34 @@ rules:
     filter_result = load_filter_from_config(test_config_path)
     annotation = create_annotation(["tag1", "tag3"])
     assert filter_result(annotation) is False
+
+
+def test_default_filtering_over_example_dataset(
+    example_annotations: List[data.ClipAnnotation],
+):
+    targets = build_targets()
+
+    clip1 = example_annotations[0]
+    clip2 = example_annotations[1]
+    clip3 = example_annotations[2]
+
+    assert (
+        sum(
+            [targets.filter(sound_event) for sound_event in clip1.sound_events]
+        )
+        == 9
+    )
+    assert (
+        sum(
+            [targets.filter(sound_event) for sound_event in clip2.sound_events]
+        )
+        == 15
+    )
+    assert (
+        sum(
+            [targets.filter(sound_event) for sound_event in clip3.sound_events]
+        )
+        == 20
+    )

View File

@@ -9,8 +9,8 @@ from batdetect2.preprocess.types import PreprocessorProtocol
 from batdetect2.train.augmentations import (
     add_echo,
     mix_examples,
-    select_subclip,
 )
+from batdetect2.train.clips import select_subclip
 from batdetect2.train.preprocess import generate_train_example
 from batdetect2.train.types import ClipLabeller
@@ -121,7 +121,7 @@ def test_selected_random_subclip_has_the_correct_width(
         preprocessor=sample_preprocessor,
         labeller=sample_labeller,
     )

-    subclip = select_subclip(original, width=100)
+    subclip = select_subclip(original, start=0, span=100)

     assert subclip["spectrogram"].shape[1] == 100
@@ -142,7 +142,7 @@ def test_add_echo_after_subclip(
     assert original.sizes["time"] > 512

-    subclip = select_subclip(original, width=512)
+    subclip = select_subclip(original, start=0, span=512)

     with_echo = add_echo(subclip, preprocessor=sample_preprocessor)

     assert with_echo.sizes["time"] == 512
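select_subclip has moved from the augmentations module to batdetect2.train.clips, and its signature changed from a single width to an explicit start plus span along the time axis, as the updated calls show. A hedged sketch of the assumed semantics on a stand-in example:

    import numpy as np
    import xarray as xr

    from batdetect2.train.clips import select_subclip

    # Stand-in training example: 128 frequency bins by 512 time bins. A real
    # example would also carry the heatmaps produced by generate_train_example.
    example = xr.Dataset(
        {"spectrogram": (("frequency", "time"), np.zeros((128, 512)))},
        coords={"frequency": np.arange(128), "time": np.arange(512)},
    )

    # Assumed semantics, inferred from the tests: keep `span` time bins
    # starting at position `start`.
    subclip = select_subclip(example, start=0, span=256)
    assert subclip["spectrogram"].shape[1] == 256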