mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Restructuring
This commit is contained in:
parent
60e922d565
commit
7d6cba5465
@ -1,10 +1,15 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from batdetect2 import api
|
|
||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
|
|
||||||
from batdetect2.types import ProcessingConfiguration
|
DEFAULT_MODEL_PATH = os.path.join(
|
||||||
from batdetect2.utils.detector_utils import save_results_to_file
|
os.path.dirname(os.path.dirname(__file__)),
|
||||||
|
"models",
|
||||||
|
"checkpoints",
|
||||||
|
"Net2DFast_UK_same.pth.tar",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@ -74,6 +79,9 @@ def detect(
|
|||||||
|
|
||||||
Input files should be short in duration e.g. < 30 seconds.
|
Input files should be short in duration e.g. < 30 seconds.
|
||||||
"""
|
"""
|
||||||
|
from batdetect2 import api
|
||||||
|
from batdetect2.utils.detector_utils import save_results_to_file
|
||||||
|
|
||||||
click.echo(f"Loading model: {args['model_path']}")
|
click.echo(f"Loading model: {args['model_path']}")
|
||||||
model, params = api.load_model(args["model_path"])
|
model, params = api.load_model(args["model_path"])
|
||||||
|
|
||||||
@ -123,7 +131,7 @@ def detect(
|
|||||||
click.echo(f" {err}")
|
click.echo(f" {err}")
|
||||||
|
|
||||||
|
|
||||||
def print_config(config: ProcessingConfiguration):
|
def print_config(config):
|
||||||
"""Print the processing configuration."""
|
"""Print the processing configuration."""
|
||||||
click.echo("\nProcessing Configuration:")
|
click.echo("\nProcessing Configuration:")
|
||||||
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
|
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
|
||||||
|
|||||||
@ -4,7 +4,6 @@ from typing import Optional
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
from batdetect2.data import load_dataset_from_config
|
|
||||||
|
|
||||||
__all__ = ["data"]
|
__all__ = ["data"]
|
||||||
|
|
||||||
@ -33,6 +32,8 @@ def summary(
|
|||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
base_dir: Optional[Path] = None,
|
base_dir: Optional[Path] = None,
|
||||||
):
|
):
|
||||||
|
from batdetect2.data import load_dataset_from_config
|
||||||
|
|
||||||
base_dir = base_dir or Path.cwd()
|
base_dir = base_dir or Path.cwd()
|
||||||
dataset = load_dataset_from_config(
|
dataset = load_dataset_from_config(
|
||||||
dataset_config,
|
dataset_config,
|
||||||
|
|||||||
@ -6,9 +6,6 @@ import click
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
from batdetect2.data import load_dataset_from_config
|
|
||||||
from batdetect2.evaluate.evaluate import evaluate
|
|
||||||
from batdetect2.train.lightning import load_model_from_checkpoint
|
|
||||||
|
|
||||||
__all__ = ["evaluate_command"]
|
__all__ = ["evaluate_command"]
|
||||||
|
|
||||||
@ -31,6 +28,10 @@ def evaluate_command(
|
|||||||
workers: Optional[int] = None,
|
workers: Optional[int] = None,
|
||||||
verbose: int = 0,
|
verbose: int = 0,
|
||||||
):
|
):
|
||||||
|
from batdetect2.data import load_dataset_from_config
|
||||||
|
from batdetect2.evaluate.evaluate import evaluate
|
||||||
|
from batdetect2.train.lightning import load_model_from_checkpoint
|
||||||
|
|
||||||
logger.remove()
|
logger.remove()
|
||||||
if verbose == 0:
|
if verbose == 0:
|
||||||
log_level = "WARNING"
|
log_level = "WARNING"
|
||||||
|
|||||||
@ -6,13 +6,6 @@ import click
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
from batdetect2.data import load_dataset_from_config
|
|
||||||
from batdetect2.targets import load_target_config
|
|
||||||
from batdetect2.train import (
|
|
||||||
FullTrainingConfig,
|
|
||||||
load_full_training_config,
|
|
||||||
train,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = ["train_command"]
|
__all__ = ["train_command"]
|
||||||
|
|
||||||
@ -53,6 +46,14 @@ def train_command(
|
|||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
verbose: int = 0,
|
verbose: int = 0,
|
||||||
):
|
):
|
||||||
|
from batdetect2.data import load_dataset_from_config
|
||||||
|
from batdetect2.targets import load_target_config
|
||||||
|
from batdetect2.train import (
|
||||||
|
FullTrainingConfig,
|
||||||
|
load_full_training_config,
|
||||||
|
train,
|
||||||
|
)
|
||||||
|
|
||||||
logger.remove()
|
logger.remove()
|
||||||
if verbose == 0:
|
if verbose == 0:
|
||||||
log_level = "WARNING"
|
log_level = "WARNING"
|
||||||
|
|||||||
@ -11,7 +11,6 @@ from soundevent import data
|
|||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
from soundevent.types import ClassMapper
|
from soundevent.types import ClassMapper
|
||||||
|
|
||||||
from batdetect2.targets.terms import get_term_from_key
|
|
||||||
from batdetect2.types import (
|
from batdetect2.types import (
|
||||||
Annotation,
|
Annotation,
|
||||||
AudioLoaderAnnotationGroup,
|
AudioLoaderAnnotationGroup,
|
||||||
@ -173,18 +172,9 @@ def annotation_to_sound_event_annotation(
|
|||||||
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
|
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
|
||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(
|
data.Tag(key=label_key, value=annotation["class"]),
|
||||||
term=get_term_from_key(label_key),
|
data.Tag(key=event_key, value=annotation["event"]),
|
||||||
value=annotation["class"],
|
data.Tag(key=individual_key, value=str(annotation["individual"])),
|
||||||
),
|
|
||||||
data.Tag(
|
|
||||||
term=get_term_from_key(event_key),
|
|
||||||
value=annotation["event"],
|
|
||||||
),
|
|
||||||
data.Tag(
|
|
||||||
term=get_term_from_key(individual_key),
|
|
||||||
value=str(annotation["individual"]),
|
|
||||||
),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -219,17 +209,11 @@ def annotation_to_sound_event_prediction(
|
|||||||
tags=[
|
tags=[
|
||||||
data.PredictedTag(
|
data.PredictedTag(
|
||||||
score=annotation["class_prob"],
|
score=annotation["class_prob"],
|
||||||
tag=data.Tag(
|
tag=data.Tag(key=label_key, value=annotation["class"]),
|
||||||
term=get_term_from_key(label_key),
|
|
||||||
value=annotation["class"],
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
data.PredictedTag(
|
data.PredictedTag(
|
||||||
score=annotation["det_prob"],
|
score=annotation["det_prob"],
|
||||||
tag=data.Tag(
|
tag=data.Tag(key=event_key, value=annotation["event"]),
|
||||||
term=get_term_from_key(event_key),
|
|
||||||
value=annotation["event"],
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
16
src/batdetect2/config.py
Normal file
16
src/batdetect2/config.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from batdetect2.core import BaseConfig
|
||||||
|
from batdetect2.evaluate.config import EvaluationConfig
|
||||||
|
from batdetect2.models.backbones import BackboneConfig
|
||||||
|
from batdetect2.preprocess import PreprocessingConfig
|
||||||
|
from batdetect2.train.config import TrainingConfig
|
||||||
|
|
||||||
|
|
||||||
|
class BatDetect2Config(BaseConfig):
|
||||||
|
config_version: Literal["v1"] = "v1"
|
||||||
|
|
||||||
|
train: TrainingConfig
|
||||||
|
evaluation: EvaluationConfig
|
||||||
|
model: BackboneConfig
|
||||||
|
preprocess: PreprocessingConfig
|
||||||
8
src/batdetect2/core/__init__.py
Normal file
8
src/batdetect2/core/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.core.registries import Registry
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseConfig",
|
||||||
|
"load_config",
|
||||||
|
"Registry",
|
||||||
|
]
|
||||||
@ -1,7 +1,12 @@
|
|||||||
|
import sys
|
||||||
from typing import Generic, Protocol, Type, TypeVar
|
from typing import Generic, Protocol, Type, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing_extensions import ParamSpec
|
|
||||||
|
if sys.version_info >= (3, 10):
|
||||||
|
from typing import ParamSpec
|
||||||
|
else:
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Registry",
|
"Registry",
|
||||||
@ -18,7 +18,7 @@ from uuid import uuid5
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data, io
|
from soundevent import data, io
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.data.annotations.types import AnnotatedDataset
|
from batdetect2.data.annotations.types import AnnotatedDataset
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@ -33,7 +33,7 @@ from loguru import logger
|
|||||||
from pydantic import Field, ValidationError
|
from pydantic import Field, ValidationError
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.data.annotations.legacy import (
|
from batdetect2.data.annotations.legacy import (
|
||||||
FileAnnotation,
|
FileAnnotation,
|
||||||
file_annotation_to_clip,
|
file_annotation_to_clip,
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AnnotatedDataset",
|
"AnnotatedDataset",
|
||||||
|
|||||||
@ -5,8 +5,8 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.data._core import Registry
|
from batdetect2.core.registries import Registry
|
||||||
|
|
||||||
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
|
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
|
||||||
|
|
||||||
|
|||||||
@ -25,7 +25,7 @@ from loguru import logger
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data, io
|
from soundevent import data, io
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.data.annotations import (
|
from batdetect2.data.annotations import (
|
||||||
AnnotatedDataset,
|
AnnotatedDataset,
|
||||||
AnnotationFormats,
|
AnnotationFormats,
|
||||||
|
|||||||
@ -4,8 +4,8 @@ from typing import Annotated, Dict, List, Literal, Optional, Union
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.data._core import Registry
|
from batdetect2.core.registries import Registry
|
||||||
from batdetect2.data.conditions import (
|
from batdetect2.data.conditions import (
|
||||||
SoundEventCondition,
|
SoundEventCondition,
|
||||||
SoundEventConditionConfig,
|
SoundEventConditionConfig,
|
||||||
|
|||||||
@ -4,8 +4,8 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.evaluation import compute_affinity
|
from soundevent.evaluation import compute_affinity
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.data._core import Registry
|
from batdetect2.core.registries import Registry
|
||||||
from batdetect2.typing.evaluate import AffinityFunction
|
from batdetect2.typing.evaluate import AffinityFunction
|
||||||
|
|
||||||
affinity_functions: Registry[AffinityFunction, []] = Registry(
|
affinity_functions: Registry[AffinityFunction, []] = Registry(
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from typing import List, Optional
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
|
from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
|
||||||
from batdetect2.evaluate.metrics import (
|
from batdetect2.evaluate.metrics import (
|
||||||
ClassificationAPConfig,
|
ClassificationAPConfig,
|
||||||
|
|||||||
@ -8,8 +8,8 @@ from soundevent.evaluation import compute_affinity
|
|||||||
from soundevent.evaluation import match_geometries as optimal_match
|
from soundevent.evaluation import match_geometries as optimal_match
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.data._core import Registry
|
from batdetect2.core.registries import Registry
|
||||||
from batdetect2.evaluate.affinity import (
|
from batdetect2.evaluate.affinity import (
|
||||||
AffinityConfig,
|
AffinityConfig,
|
||||||
GeometricIOUConfig,
|
GeometricIOUConfig,
|
||||||
|
|||||||
@ -15,8 +15,8 @@ from pydantic import Field
|
|||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
from sklearn.preprocessing import label_binarize
|
from sklearn.preprocessing import label_binarize
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.data._core import Registry
|
from batdetect2.core.registries import Registry
|
||||||
from batdetect2.typing import MetricsProtocol
|
from batdetect2.typing import MetricsProtocol
|
||||||
from batdetect2.typing.evaluate import ClipEvaluation
|
from batdetect2.typing.evaluate import ClipEvaluation
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ AveragePrecisionImplementation = Literal["sklearn", "pascal_voc"]
|
|||||||
|
|
||||||
class DetectionAPConfig(BaseConfig):
|
class DetectionAPConfig(BaseConfig):
|
||||||
name: Literal["detection_ap"] = "detection_ap"
|
name: Literal["detection_ap"] = "detection_ap"
|
||||||
implementation: AveragePrecisionImplementation = "pascal_voc"
|
ap_implementation: AveragePrecisionImplementation = "pascal_voc"
|
||||||
|
|
||||||
|
|
||||||
def pascal_voc_average_precision(y_true, y_score) -> float:
|
def pascal_voc_average_precision(y_true, y_score) -> float:
|
||||||
@ -96,7 +96,7 @@ class DetectionAP(MetricsProtocol):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
|
def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
|
||||||
return cls(implementation=config.implementation)
|
return cls(implementation=config.ap_implementation)
|
||||||
|
|
||||||
|
|
||||||
metrics_registry.register(DetectionAPConfig, DetectionAP)
|
metrics_registry.register(DetectionAPConfig, DetectionAP)
|
||||||
@ -104,6 +104,7 @@ metrics_registry.register(DetectionAPConfig, DetectionAP)
|
|||||||
|
|
||||||
class ClassificationAPConfig(BaseConfig):
|
class ClassificationAPConfig(BaseConfig):
|
||||||
name: Literal["classification_ap"] = "classification_ap"
|
name: Literal["classification_ap"] = "classification_ap"
|
||||||
|
ap_implementation: AveragePrecisionImplementation = "pascal_voc"
|
||||||
include: Optional[List[str]] = None
|
include: Optional[List[str]] = None
|
||||||
exclude: Optional[List[str]] = None
|
exclude: Optional[List[str]] = None
|
||||||
|
|
||||||
@ -193,6 +194,7 @@ class ClassificationAP(MetricsProtocol):
|
|||||||
):
|
):
|
||||||
return cls(
|
return cls(
|
||||||
class_names,
|
class_names,
|
||||||
|
implementation=config.ap_implementation,
|
||||||
include=config.include,
|
include=config.include,
|
||||||
exclude=config.exclude,
|
exclude=config.exclude,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -7,8 +7,8 @@ import matplotlib.pyplot as plt
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.data._core import Registry
|
from batdetect2.core.registries import Registry
|
||||||
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
|
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
|
||||||
from batdetect2.plotting.gallery import plot_match_gallery
|
from batdetect2.plotting.gallery import plot_match_gallery
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
|
|||||||
@ -32,7 +32,7 @@ import torch
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.models.backbones import (
|
from batdetect2.models.backbones import (
|
||||||
Backbone,
|
Backbone,
|
||||||
BackboneConfig,
|
BackboneConfig,
|
||||||
|
|||||||
@ -25,7 +25,7 @@ import torch.nn.functional as F
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.models.bottleneck import (
|
from batdetect2.models.bottleneck import (
|
||||||
DEFAULT_BOTTLENECK_CONFIG,
|
DEFAULT_BOTTLENECK_CONFIG,
|
||||||
BottleneckConfig,
|
BottleneckConfig,
|
||||||
|
|||||||
@ -34,7 +34,7 @@ import torch.nn.functional as F
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ConvBlock",
|
"ConvBlock",
|
||||||
|
|||||||
@ -20,7 +20,7 @@ import torch
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.models.blocks import (
|
from batdetect2.models.blocks import (
|
||||||
SelfAttentionConfig,
|
SelfAttentionConfig,
|
||||||
VerticalConv,
|
VerticalConv,
|
||||||
|
|||||||
@ -24,7 +24,7 @@ import torch
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.models.blocks import (
|
from batdetect2.models.blocks import (
|
||||||
ConvConfig,
|
ConvConfig,
|
||||||
FreqCoordConvUpConfig,
|
FreqCoordConvUpConfig,
|
||||||
|
|||||||
@ -26,7 +26,7 @@ import torch
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.models.blocks import (
|
from batdetect2.models.blocks import (
|
||||||
ConvConfig,
|
ConvConfig,
|
||||||
FreqCoordConvDownConfig,
|
FreqCoordConvDownConfig,
|
||||||
|
|||||||
@ -8,8 +8,7 @@ from soundevent.plot.tags import TagColorMapper
|
|||||||
|
|
||||||
from batdetect2.plotting.clip_predictions import plot_prediction
|
from batdetect2.plotting.clip_predictions import plot_prediction
|
||||||
from batdetect2.plotting.clips import AudioLoader, plot_clip
|
from batdetect2.plotting.clips import AudioLoader, plot_clip
|
||||||
from batdetect2.preprocess import PreprocessorProtocol
|
from batdetect2.typing import MatchEvaluation, PreprocessorProtocol
|
||||||
from batdetect2.typing.evaluate import MatchEvaluation
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"plot_matches",
|
"plot_matches",
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from loguru import logger
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.postprocess.decoding import (
|
from batdetect2.postprocess.decoding import (
|
||||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
convert_raw_prediction_to_sound_event_prediction,
|
convert_raw_prediction_to_sound_event_prediction,
|
||||||
|
|||||||
@ -1,176 +1,21 @@
|
|||||||
"""Main entry point for the BatDetect2 Preprocessing subsystem.
|
"""Main entry point for the BatDetect2 preprocessing subsystem."""
|
||||||
|
|
||||||
This package (`batdetect2.preprocessing`) defines and orchestrates the pipeline
|
from batdetect2.preprocess.audio import build_audio_loader
|
||||||
for converting raw audio input (from files or data objects) into processed
|
from batdetect2.preprocess.config import (
|
||||||
spectrograms suitable for input to BatDetect2 models. This ensures consistent
|
|
||||||
data handling between model training and inference.
|
|
||||||
|
|
||||||
The preprocessing pipeline consists of two main stages, configured via nested
|
|
||||||
data structures:
|
|
||||||
1. **Audio Processing (`.audio`)**: Loads audio waveforms and applies initial
|
|
||||||
processing like resampling, duration adjustment, centering, and scaling.
|
|
||||||
Configured via `AudioConfig`.
|
|
||||||
2. **Spectrogram Generation (`.spectrogram`)**: Computes the spectrogram from
|
|
||||||
the processed waveform using STFT, followed by frequency cropping, optional
|
|
||||||
PCEN, amplitude scaling (dB, power, linear), optional denoising, optional
|
|
||||||
resizing, and optional peak normalization. Configured via
|
|
||||||
`SpectrogramConfig`.
|
|
||||||
|
|
||||||
This module provides the primary interface:
|
|
||||||
|
|
||||||
- `PreprocessingConfig`: A unified configuration object holding `AudioConfig`
|
|
||||||
and `SpectrogramConfig`.
|
|
||||||
- `load_preprocessing_config`: Function to load the unified configuration.
|
|
||||||
- `Preprocessor`: A protocol defining the interface for the end-to-end pipeline.
|
|
||||||
- `StandardPreprocessor`: The default implementation of the `Preprocessor`.
|
|
||||||
- `build_preprocessor`: A factory function to create a `StandardPreprocessor`
|
|
||||||
instance from a `PreprocessingConfig`.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from loguru import logger
|
|
||||||
from pydantic import Field
|
|
||||||
from soundevent.data import PathLike
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
|
||||||
from batdetect2.preprocess.audio import (
|
|
||||||
DEFAULT_DURATION,
|
|
||||||
SCALE_RAW_AUDIO,
|
|
||||||
TARGET_SAMPLERATE_HZ,
|
|
||||||
AudioConfig,
|
|
||||||
ResampleConfig,
|
|
||||||
build_audio_loader,
|
|
||||||
build_audio_pipeline,
|
|
||||||
)
|
|
||||||
from batdetect2.preprocess.spectrogram import (
|
|
||||||
MAX_FREQ,
|
MAX_FREQ,
|
||||||
MIN_FREQ,
|
MIN_FREQ,
|
||||||
FrequencyConfig,
|
TARGET_SAMPLERATE_HZ,
|
||||||
PcenConfig,
|
PreprocessingConfig,
|
||||||
SpectrogramConfig,
|
load_preprocessing_config,
|
||||||
SpectrogramPipeline,
|
|
||||||
STFTConfig,
|
|
||||||
_spec_params_from_config,
|
|
||||||
build_spectrogram_builder,
|
|
||||||
build_spectrogram_pipeline,
|
|
||||||
)
|
)
|
||||||
from batdetect2.typing import PreprocessorProtocol
|
from batdetect2.preprocess.preprocessor import build_preprocessor
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AudioConfig",
|
|
||||||
"DEFAULT_DURATION",
|
|
||||||
"FrequencyConfig",
|
|
||||||
"MAX_FREQ",
|
|
||||||
"MIN_FREQ",
|
"MIN_FREQ",
|
||||||
"PcenConfig",
|
"MAX_FREQ",
|
||||||
"PreprocessingConfig",
|
|
||||||
"ResampleConfig",
|
|
||||||
"SCALE_RAW_AUDIO",
|
|
||||||
"STFTConfig",
|
|
||||||
"SpectrogramConfig",
|
|
||||||
"TARGET_SAMPLERATE_HZ",
|
"TARGET_SAMPLERATE_HZ",
|
||||||
"build_audio_loader",
|
"PreprocessingConfig",
|
||||||
"build_spectrogram_builder",
|
|
||||||
"load_preprocessing_config",
|
"load_preprocessing_config",
|
||||||
|
"build_preprocessor",
|
||||||
|
"build_audio_loader",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class PreprocessingConfig(BaseConfig):
|
|
||||||
"""Unified configuration for the audio preprocessing pipeline.
|
|
||||||
|
|
||||||
Aggregates the configuration for both the initial audio processing stage
|
|
||||||
and the subsequent spectrogram generation stage.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
audio : AudioConfig
|
|
||||||
Configuration settings for the audio loading and initial waveform
|
|
||||||
processing steps (e.g., resampling, duration adjustment, scaling).
|
|
||||||
Defaults to default `AudioConfig` settings if omitted.
|
|
||||||
spectrogram : SpectrogramConfig
|
|
||||||
Configuration settings for the spectrogram generation process
|
|
||||||
(e.g., STFT parameters, frequency cropping, scaling, denoising,
|
|
||||||
resizing). Defaults to default `SpectrogramConfig` settings if omitted.
|
|
||||||
"""
|
|
||||||
|
|
||||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
|
||||||
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
|
|
||||||
|
|
||||||
|
|
||||||
def load_preprocessing_config(
|
|
||||||
path: PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
) -> PreprocessingConfig:
|
|
||||||
return load_config(path, schema=PreprocessingConfig, field=field)
|
|
||||||
|
|
||||||
|
|
||||||
class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
|
|
||||||
"""Standard implementation of the `Preprocessor` protocol."""
|
|
||||||
|
|
||||||
input_samplerate: int
|
|
||||||
output_samplerate: float
|
|
||||||
|
|
||||||
max_freq: float
|
|
||||||
min_freq: float
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
audio_pipeline: torch.nn.Module,
|
|
||||||
spectrogram_pipeline: SpectrogramPipeline,
|
|
||||||
input_samplerate: int,
|
|
||||||
output_samplerate: float,
|
|
||||||
max_freq: float,
|
|
||||||
min_freq: float,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.audio_pipeline = audio_pipeline
|
|
||||||
self.spectrogram_pipeline = spectrogram_pipeline
|
|
||||||
|
|
||||||
self.max_freq = max_freq
|
|
||||||
self.min_freq = min_freq
|
|
||||||
|
|
||||||
self.input_samplerate = input_samplerate
|
|
||||||
self.output_samplerate = output_samplerate
|
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
|
||||||
wav = self.audio_pipeline(wav)
|
|
||||||
return self.spectrogram_pipeline(wav)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_output_samplerate(config: PreprocessingConfig) -> float:
|
|
||||||
samplerate = config.audio.samplerate
|
|
||||||
_, hop_size = _spec_params_from_config(samplerate, config.spectrogram.stft)
|
|
||||||
factor = config.spectrogram.size.resize_factor
|
|
||||||
return samplerate * factor / hop_size
|
|
||||||
|
|
||||||
|
|
||||||
def build_preprocessor(
|
|
||||||
config: Optional[PreprocessingConfig] = None,
|
|
||||||
) -> PreprocessorProtocol:
|
|
||||||
"""Factory function to build the standard preprocessor from configuration."""
|
|
||||||
config = config or PreprocessingConfig()
|
|
||||||
logger.opt(lazy=True).debug(
|
|
||||||
"Building preprocessor with config: \n{}",
|
|
||||||
lambda: config.to_yaml_string(),
|
|
||||||
)
|
|
||||||
|
|
||||||
samplerate = config.audio.samplerate
|
|
||||||
|
|
||||||
min_freq = config.spectrogram.frequencies.min_freq
|
|
||||||
max_freq = config.spectrogram.frequencies.max_freq
|
|
||||||
|
|
||||||
output_samplerate = compute_output_samplerate(config)
|
|
||||||
|
|
||||||
return StandardPreprocessor(
|
|
||||||
audio_pipeline=build_audio_pipeline(config.audio),
|
|
||||||
spectrogram_pipeline=build_spectrogram_pipeline(
|
|
||||||
samplerate, config.spectrogram
|
|
||||||
),
|
|
||||||
input_samplerate=samplerate,
|
|
||||||
output_samplerate=output_samplerate,
|
|
||||||
min_freq=min_freq,
|
|
||||||
max_freq=max_freq,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,64 +1,34 @@
|
|||||||
"""Handles loading and initial preprocessing of audio waveforms."""
|
"""Handles loading and initial preprocessing of audio waveforms."""
|
||||||
|
|
||||||
from typing import Annotated, List, Literal, Optional, Union
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from numpy.typing import DTypeLike
|
from numpy.typing import DTypeLike
|
||||||
from pydantic import Field
|
|
||||||
from scipy.signal import resample, resample_poly
|
from scipy.signal import resample, resample_poly
|
||||||
from soundevent import audio, data
|
from soundevent import audio, data
|
||||||
from soundfile import LibsndfileError
|
from soundfile import LibsndfileError
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
|
||||||
from batdetect2.preprocess.common import CenterTensor, PeakNormalize
|
from batdetect2.preprocess.common import CenterTensor, PeakNormalize
|
||||||
|
from batdetect2.preprocess.config import (
|
||||||
|
TARGET_SAMPLERATE_HZ,
|
||||||
|
AudioConfig,
|
||||||
|
AudioTransform,
|
||||||
|
ResampleConfig,
|
||||||
|
)
|
||||||
from batdetect2.typing import AudioLoader
|
from batdetect2.typing import AudioLoader
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ResampleConfig",
|
|
||||||
"AudioConfig",
|
|
||||||
"SoundEventAudioLoader",
|
"SoundEventAudioLoader",
|
||||||
"build_audio_loader",
|
"build_audio_loader",
|
||||||
"load_file_audio",
|
"load_file_audio",
|
||||||
"load_recording_audio",
|
"load_recording_audio",
|
||||||
"load_clip_audio",
|
"load_clip_audio",
|
||||||
"resample_audio",
|
"resample_audio",
|
||||||
"TARGET_SAMPLERATE_HZ",
|
|
||||||
"SCALE_RAW_AUDIO",
|
|
||||||
"DEFAULT_DURATION",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
TARGET_SAMPLERATE_HZ = 256_000
|
|
||||||
"""Default target sample rate in Hz used if resampling is enabled."""
|
|
||||||
|
|
||||||
SCALE_RAW_AUDIO = False
|
class SoundEventAudioLoader(AudioLoader):
|
||||||
"""Default setting for whether to perform peak normalization."""
|
|
||||||
|
|
||||||
DEFAULT_DURATION = None
|
|
||||||
"""Default setting for target audio duration in seconds."""
|
|
||||||
|
|
||||||
|
|
||||||
class ResampleConfig(BaseConfig):
|
|
||||||
"""Configuration for audio resampling.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
samplerate : int, default=256000
|
|
||||||
The target sample rate in Hz to resample the audio to. Must be > 0.
|
|
||||||
method : str, default="poly"
|
|
||||||
The resampling algorithm to use. Options:
|
|
||||||
- "poly": Polyphase resampling using `scipy.signal.resample_poly`.
|
|
||||||
Generally fast.
|
|
||||||
- "fourier": Resampling via Fourier method using
|
|
||||||
`scipy.signal.resample`. May handle non-integer
|
|
||||||
resampling factors differently.
|
|
||||||
"""
|
|
||||||
|
|
||||||
enabled: bool = True
|
|
||||||
method: str = "poly"
|
|
||||||
|
|
||||||
|
|
||||||
class SoundEventAudioLoader:
|
|
||||||
"""Concrete implementation of the `AudioLoader`."""
|
"""Concrete implementation of the `AudioLoader`."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -294,19 +264,6 @@ def resample_audio_fourier(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CenterAudioConfig(BaseConfig):
|
|
||||||
name: Literal["center_audio"] = "center_audio"
|
|
||||||
|
|
||||||
|
|
||||||
class ScaleAudioConfig(BaseConfig):
|
|
||||||
name: Literal["scale_audio"] = "scale_audio"
|
|
||||||
|
|
||||||
|
|
||||||
class FixDurationConfig(BaseConfig):
|
|
||||||
name: Literal["fix_duration"] = "fix_duration"
|
|
||||||
duration: float = 0.5
|
|
||||||
|
|
||||||
|
|
||||||
class FixDuration(torch.nn.Module):
|
class FixDuration(torch.nn.Module):
|
||||||
def __init__(self, samplerate: int, duration: float):
|
def __init__(self, samplerate: int, duration: float):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -326,24 +283,6 @@ class FixDuration(torch.nn.Module):
|
|||||||
return torch.nn.functional.pad(wav, (0, self.length - length))
|
return torch.nn.functional.pad(wav, (0, self.length - length))
|
||||||
|
|
||||||
|
|
||||||
AudioTransform = Annotated[
|
|
||||||
Union[
|
|
||||||
FixDurationConfig,
|
|
||||||
ScaleAudioConfig,
|
|
||||||
CenterAudioConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class AudioConfig(BaseConfig):
|
|
||||||
"""Configuration for loading and initial audio preprocessing."""
|
|
||||||
|
|
||||||
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
|
|
||||||
resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
|
|
||||||
transforms: List[AudioTransform] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
def build_audio_loader(
|
def build_audio_loader(
|
||||||
config: Optional[AudioConfig] = None,
|
config: Optional[AudioConfig] = None,
|
||||||
) -> AudioLoader:
|
) -> AudioLoader:
|
||||||
|
|||||||
212
src/batdetect2/preprocess/config.py
Normal file
212
src/batdetect2/preprocess/config.py
Normal file
@ -0,0 +1,212 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Annotated, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"load_preprocessing_config",
|
||||||
|
"CenterAudioConfig",
|
||||||
|
"ScaleAudioConfig",
|
||||||
|
"FixDurationConfig",
|
||||||
|
"ResampleConfig",
|
||||||
|
"AudioTransform",
|
||||||
|
"AudioConfig",
|
||||||
|
"STFTConfig",
|
||||||
|
"FrequencyConfig",
|
||||||
|
"PcenConfig",
|
||||||
|
"ScaleAmplitudeConfig",
|
||||||
|
"SpectralMeanSubstractionConfig",
|
||||||
|
"ResizeConfig",
|
||||||
|
"PeakNormalizeConfig",
|
||||||
|
"SpectrogramTransform",
|
||||||
|
"SpectrogramConfig",
|
||||||
|
"PreprocessingConfig",
|
||||||
|
"TARGET_SAMPLERATE_HZ",
|
||||||
|
"MIN_FREQ",
|
||||||
|
"MAX_FREQ",
|
||||||
|
]
|
||||||
|
|
||||||
|
TARGET_SAMPLERATE_HZ = 256_000
|
||||||
|
"""Default target sample rate in Hz used if resampling is enabled."""
|
||||||
|
|
||||||
|
MIN_FREQ = 10_000
|
||||||
|
"""Default minimum frequency (Hz) for spectrogram frequency cropping."""
|
||||||
|
|
||||||
|
MAX_FREQ = 120_000
|
||||||
|
"""Default maximum frequency (Hz) for spectrogram frequency cropping."""
|
||||||
|
|
||||||
|
|
||||||
|
class CenterAudioConfig(BaseConfig):
|
||||||
|
name: Literal["center_audio"] = "center_audio"
|
||||||
|
|
||||||
|
|
||||||
|
class ScaleAudioConfig(BaseConfig):
|
||||||
|
name: Literal["scale_audio"] = "scale_audio"
|
||||||
|
|
||||||
|
|
||||||
|
class FixDurationConfig(BaseConfig):
|
||||||
|
name: Literal["fix_duration"] = "fix_duration"
|
||||||
|
duration: float = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class ResampleConfig(BaseConfig):
|
||||||
|
"""Configuration for audio resampling.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
samplerate : int, default=256000
|
||||||
|
The target sample rate in Hz to resample the audio to. Must be > 0.
|
||||||
|
method : str, default="poly"
|
||||||
|
The resampling algorithm to use. Options:
|
||||||
|
- "poly": Polyphase resampling using `scipy.signal.resample_poly`.
|
||||||
|
Generally fast.
|
||||||
|
- "fourier": Resampling via Fourier method using
|
||||||
|
`scipy.signal.resample`. May handle non-integer
|
||||||
|
resampling factors differently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
enabled: bool = True
|
||||||
|
method: str = "poly"
|
||||||
|
|
||||||
|
|
||||||
|
AudioTransform = Annotated[
|
||||||
|
Union[
|
||||||
|
FixDurationConfig,
|
||||||
|
ScaleAudioConfig,
|
||||||
|
CenterAudioConfig,
|
||||||
|
],
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class AudioConfig(BaseConfig):
|
||||||
|
"""Configuration for loading and initial audio preprocessing."""
|
||||||
|
|
||||||
|
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
|
||||||
|
resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
|
||||||
|
transforms: List[AudioTransform] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class STFTConfig(BaseConfig):
|
||||||
|
"""Configuration for the Short-Time Fourier Transform (STFT).
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
window_duration : float, default=0.002
|
||||||
|
Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be
|
||||||
|
> 0. Determines frequency resolution (longer window = finer frequency
|
||||||
|
resolution).
|
||||||
|
window_overlap : float, default=0.75
|
||||||
|
Fraction of overlap between consecutive STFT windows (e.g., 0.75
|
||||||
|
for 75%). Must be >= 0 and < 1. Determines time resolution
|
||||||
|
(higher overlap = finer time resolution).
|
||||||
|
window_fn : str, default="hann"
|
||||||
|
Name of the window function to apply before FFT calculation. Common
|
||||||
|
options include "hann", "hamming", "blackman". See
|
||||||
|
`scipy.signal.get_window`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
window_duration: float = Field(default=0.002, gt=0)
|
||||||
|
window_overlap: float = Field(default=0.75, ge=0, lt=1)
|
||||||
|
window_fn: str = "hann"
|
||||||
|
|
||||||
|
|
||||||
|
class FrequencyConfig(BaseConfig):
|
||||||
|
"""Configuration for frequency axis parameters.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
max_freq : int, default=120000
|
||||||
|
Maximum frequency in Hz to retain in the spectrogram after STFT.
|
||||||
|
Frequencies above this value will be cropped. Must be > 0.
|
||||||
|
min_freq : int, default=10000
|
||||||
|
Minimum frequency in Hz to retain in the spectrogram after STFT.
|
||||||
|
Frequencies below this value will be cropped. Must be >= 0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
max_freq: int = Field(default=120_000, ge=0)
|
||||||
|
min_freq: int = Field(default=10_000, ge=0)
|
||||||
|
|
||||||
|
|
||||||
|
class PcenConfig(BaseConfig):
|
||||||
|
"""Configuration for Per-Channel Energy Normalization (PCEN)."""
|
||||||
|
|
||||||
|
name: Literal["pcen"] = "pcen"
|
||||||
|
time_constant: float = 0.4
|
||||||
|
gain: float = 0.98
|
||||||
|
bias: float = 2
|
||||||
|
power: float = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class ScaleAmplitudeConfig(BaseConfig):
|
||||||
|
name: Literal["scale_amplitude"] = "scale_amplitude"
|
||||||
|
scale: Literal["power", "db"] = "db"
|
||||||
|
|
||||||
|
|
||||||
|
class SpectralMeanSubstractionConfig(BaseConfig):
|
||||||
|
name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction"
|
||||||
|
|
||||||
|
|
||||||
|
class ResizeConfig(BaseConfig):
|
||||||
|
name: Literal["resize_spec"] = "resize_spec"
|
||||||
|
height: int = 128
|
||||||
|
resize_factor: float = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class PeakNormalizeConfig(BaseConfig):
|
||||||
|
name: Literal["peak_normalize"] = "peak_normalize"
|
||||||
|
|
||||||
|
|
||||||
|
SpectrogramTransform = Annotated[
|
||||||
|
Union[
|
||||||
|
PcenConfig,
|
||||||
|
ScaleAmplitudeConfig,
|
||||||
|
SpectralMeanSubstractionConfig,
|
||||||
|
PeakNormalizeConfig,
|
||||||
|
],
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SpectrogramConfig(BaseConfig):
|
||||||
|
stft: STFTConfig = Field(default_factory=STFTConfig)
|
||||||
|
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
|
||||||
|
size: ResizeConfig = Field(default_factory=ResizeConfig)
|
||||||
|
transforms: Sequence[SpectrogramTransform] = Field(
|
||||||
|
default_factory=lambda: [
|
||||||
|
PcenConfig(),
|
||||||
|
SpectralMeanSubstractionConfig(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PreprocessingConfig(BaseConfig):
|
||||||
|
"""Unified configuration for the audio preprocessing pipeline.
|
||||||
|
|
||||||
|
Aggregates the configuration for both the initial audio processing stage
|
||||||
|
and the subsequent spectrogram generation stage.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
audio : AudioConfig
|
||||||
|
Configuration settings for the audio loading and initial waveform
|
||||||
|
processing steps (e.g., resampling, duration adjustment, scaling).
|
||||||
|
Defaults to default `AudioConfig` settings if omitted.
|
||||||
|
spectrogram : SpectrogramConfig
|
||||||
|
Configuration settings for the spectrogram generation process
|
||||||
|
(e.g., STFT parameters, frequency cropping, scaling, denoising,
|
||||||
|
resizing). Defaults to default `SpectrogramConfig` settings if omitted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||||
|
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
|
||||||
|
|
||||||
|
|
||||||
|
def load_preprocessing_config(
|
||||||
|
path: PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> PreprocessingConfig:
|
||||||
|
return load_config(path, schema=PreprocessingConfig, field=field)
|
||||||
86
src/batdetect2/preprocess/preprocessor.py
Normal file
86
src/batdetect2/preprocess/preprocessor.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from batdetect2.preprocess.audio import build_audio_pipeline
|
||||||
|
from batdetect2.preprocess.config import PreprocessingConfig
|
||||||
|
from batdetect2.preprocess.spectrogram import (
|
||||||
|
_spec_params_from_config,
|
||||||
|
build_spectrogram_pipeline,
|
||||||
|
)
|
||||||
|
from batdetect2.typing import PreprocessorProtocol, SpectrogramPipeline
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"StandardPreprocessor",
|
||||||
|
"build_preprocessor",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
|
||||||
|
"""Standard implementation of the `Preprocessor` protocol."""
|
||||||
|
|
||||||
|
input_samplerate: int
|
||||||
|
output_samplerate: float
|
||||||
|
|
||||||
|
max_freq: float
|
||||||
|
min_freq: float
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
audio_pipeline: torch.nn.Module,
|
||||||
|
spectrogram_pipeline: SpectrogramPipeline,
|
||||||
|
input_samplerate: int,
|
||||||
|
output_samplerate: float,
|
||||||
|
max_freq: float,
|
||||||
|
min_freq: float,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.audio_pipeline = audio_pipeline
|
||||||
|
self.spectrogram_pipeline = spectrogram_pipeline
|
||||||
|
|
||||||
|
self.max_freq = max_freq
|
||||||
|
self.min_freq = min_freq
|
||||||
|
|
||||||
|
self.input_samplerate = input_samplerate
|
||||||
|
self.output_samplerate = output_samplerate
|
||||||
|
|
||||||
|
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||||
|
wav = self.audio_pipeline(wav)
|
||||||
|
return self.spectrogram_pipeline(wav)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_output_samplerate(config: PreprocessingConfig) -> float:
|
||||||
|
samplerate = config.audio.samplerate
|
||||||
|
_, hop_size = _spec_params_from_config(samplerate, config.spectrogram.stft)
|
||||||
|
factor = config.spectrogram.size.resize_factor
|
||||||
|
return samplerate * factor / hop_size
|
||||||
|
|
||||||
|
|
||||||
|
def build_preprocessor(
|
||||||
|
config: Optional[PreprocessingConfig] = None,
|
||||||
|
) -> PreprocessorProtocol:
|
||||||
|
"""Factory function to build the standard preprocessor from configuration."""
|
||||||
|
config = config or PreprocessingConfig()
|
||||||
|
logger.opt(lazy=True).debug(
|
||||||
|
"Building preprocessor with config: \n{}",
|
||||||
|
lambda: config.to_yaml_string(),
|
||||||
|
)
|
||||||
|
|
||||||
|
samplerate = config.audio.samplerate
|
||||||
|
|
||||||
|
min_freq = config.spectrogram.frequencies.min_freq
|
||||||
|
max_freq = config.spectrogram.frequencies.max_freq
|
||||||
|
|
||||||
|
output_samplerate = compute_output_samplerate(config)
|
||||||
|
|
||||||
|
return StandardPreprocessor(
|
||||||
|
audio_pipeline=build_audio_pipeline(config.audio),
|
||||||
|
spectrogram_pipeline=build_spectrogram_pipeline(
|
||||||
|
samplerate, config.spectrogram
|
||||||
|
),
|
||||||
|
input_samplerate=samplerate,
|
||||||
|
output_samplerate=output_samplerate,
|
||||||
|
min_freq=min_freq,
|
||||||
|
max_freq=max_freq,
|
||||||
|
)
|
||||||
@ -1,63 +1,37 @@
|
|||||||
"""Computes spectrograms from audio waveforms with configurable parameters."""
|
"""Computes spectrograms from audio waveforms with configurable parameters."""
|
||||||
|
|
||||||
from typing import (
|
from typing import Callable, Optional
|
||||||
Annotated,
|
|
||||||
Callable,
|
|
||||||
List,
|
|
||||||
Literal,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
|
||||||
from batdetect2.preprocess.common import PeakNormalize
|
from batdetect2.preprocess.common import PeakNormalize
|
||||||
|
from batdetect2.preprocess.config import (
|
||||||
|
ScaleAmplitudeConfig,
|
||||||
|
SpectrogramConfig,
|
||||||
|
SpectrogramTransform,
|
||||||
|
STFTConfig,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"STFTConfig",
|
|
||||||
"FrequencyConfig",
|
|
||||||
"PcenConfig",
|
|
||||||
"SpectrogramConfig",
|
|
||||||
"build_spectrogram_builder",
|
"build_spectrogram_builder",
|
||||||
"MIN_FREQ",
|
"build_spectrogram_pipeline",
|
||||||
"MAX_FREQ",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
MIN_FREQ = 10_000
|
def build_spectrogram_builder(
|
||||||
"""Default minimum frequency (Hz) for spectrogram frequency cropping."""
|
samplerate: int,
|
||||||
|
conf: STFTConfig,
|
||||||
MAX_FREQ = 120_000
|
) -> torch.nn.Module:
|
||||||
"""Default maximum frequency (Hz) for spectrogram frequency cropping."""
|
n_fft, hop_length = _spec_params_from_config(samplerate, conf)
|
||||||
|
return torchaudio.transforms.Spectrogram(
|
||||||
|
n_fft=n_fft,
|
||||||
class STFTConfig(BaseConfig):
|
hop_length=hop_length,
|
||||||
"""Configuration for the Short-Time Fourier Transform (STFT).
|
window_fn=get_spectrogram_window(conf.window_fn),
|
||||||
|
center=True,
|
||||||
Attributes
|
power=1,
|
||||||
----------
|
)
|
||||||
window_duration : float, default=0.002
|
|
||||||
Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be
|
|
||||||
> 0. Determines frequency resolution (longer window = finer frequency
|
|
||||||
resolution).
|
|
||||||
window_overlap : float, default=0.75
|
|
||||||
Fraction of overlap between consecutive STFT windows (e.g., 0.75
|
|
||||||
for 75%). Must be >= 0 and < 1. Determines time resolution
|
|
||||||
(higher overlap = finer time resolution).
|
|
||||||
window_fn : str, default="hann"
|
|
||||||
Name of the window function to apply before FFT calculation. Common
|
|
||||||
options include "hann", "hamming", "blackman". See
|
|
||||||
`scipy.signal.get_window`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
window_duration: float = Field(default=0.002, gt=0)
|
|
||||||
window_overlap: float = Field(default=0.75, ge=0, lt=1)
|
|
||||||
window_fn: str = "hann"
|
|
||||||
|
|
||||||
|
|
||||||
def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
|
def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
|
||||||
@ -87,37 +61,6 @@ def _spec_params_from_config(samplerate: int, conf: STFTConfig):
|
|||||||
return n_fft, hop_length
|
return n_fft, hop_length
|
||||||
|
|
||||||
|
|
||||||
def build_spectrogram_builder(
|
|
||||||
samplerate: int,
|
|
||||||
conf: STFTConfig,
|
|
||||||
) -> torch.nn.Module:
|
|
||||||
n_fft, hop_length = _spec_params_from_config(samplerate, conf)
|
|
||||||
return torchaudio.transforms.Spectrogram(
|
|
||||||
n_fft=n_fft,
|
|
||||||
hop_length=hop_length,
|
|
||||||
window_fn=get_spectrogram_window(conf.window_fn),
|
|
||||||
center=True,
|
|
||||||
power=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FrequencyConfig(BaseConfig):
|
|
||||||
"""Configuration for frequency axis parameters.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
max_freq : int, default=120000
|
|
||||||
Maximum frequency in Hz to retain in the spectrogram after STFT.
|
|
||||||
Frequencies above this value will be cropped. Must be > 0.
|
|
||||||
min_freq : int, default=10000
|
|
||||||
Minimum frequency in Hz to retain in the spectrogram after STFT.
|
|
||||||
Frequencies below this value will be cropped. Must be >= 0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
max_freq: int = Field(default=120_000, ge=0)
|
|
||||||
min_freq: int = Field(default=10_000, ge=0)
|
|
||||||
|
|
||||||
|
|
||||||
def _frequency_to_index(
|
def _frequency_to_index(
|
||||||
freq: float,
|
freq: float,
|
||||||
samplerate: int,
|
samplerate: int,
|
||||||
@ -164,16 +107,6 @@ class FrequencyClip(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PcenConfig(BaseConfig):
|
|
||||||
"""Configuration for Per-Channel Energy Normalization (PCEN)."""
|
|
||||||
|
|
||||||
name: Literal["pcen"] = "pcen"
|
|
||||||
time_constant: float = 0.4
|
|
||||||
gain: float = 0.98
|
|
||||||
bias: float = 2
|
|
||||||
power: float = 0.5
|
|
||||||
|
|
||||||
|
|
||||||
class PCEN(torch.nn.Module):
|
class PCEN(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -231,11 +164,6 @@ def _compute_smoothing_constant(
|
|||||||
return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
|
return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
|
||||||
|
|
||||||
|
|
||||||
class ScaleAmplitudeConfig(BaseConfig):
|
|
||||||
name: Literal["scale_amplitude"] = "scale_amplitude"
|
|
||||||
scale: Literal["power", "db"] = "db"
|
|
||||||
|
|
||||||
|
|
||||||
class ToPower(torch.nn.Module):
|
class ToPower(torch.nn.Module):
|
||||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
return spec**2
|
return spec**2
|
||||||
@ -253,22 +181,12 @@ def _build_amplitude_scaler(conf: ScaleAmplitudeConfig) -> torch.nn.Module:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SpectralMeanSubstractionConfig(BaseConfig):
|
|
||||||
name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction"
|
|
||||||
|
|
||||||
|
|
||||||
class SpectralMeanSubstraction(torch.nn.Module):
|
class SpectralMeanSubstraction(torch.nn.Module):
|
||||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
mean = spec.mean(-1, keepdim=True)
|
mean = spec.mean(-1, keepdim=True)
|
||||||
return (spec - mean).clamp(min=0)
|
return (spec - mean).clamp(min=0)
|
||||||
|
|
||||||
|
|
||||||
class ResizeConfig(BaseConfig):
|
|
||||||
name: Literal["resize_spec"] = "resize_spec"
|
|
||||||
height: int = 128
|
|
||||||
resize_factor: float = 0.5
|
|
||||||
|
|
||||||
|
|
||||||
class ResizeSpec(torch.nn.Module):
|
class ResizeSpec(torch.nn.Module):
|
||||||
def __init__(self, height: int, time_factor: float):
|
def __init__(self, height: int, time_factor: float):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -295,33 +213,6 @@ class ResizeSpec(torch.nn.Module):
|
|||||||
return resized
|
return resized
|
||||||
|
|
||||||
|
|
||||||
class PeakNormalizeConfig(BaseConfig):
|
|
||||||
name: Literal["peak_normalize"] = "peak_normalize"
|
|
||||||
|
|
||||||
|
|
||||||
SpectrogramTransform = Annotated[
|
|
||||||
Union[
|
|
||||||
PcenConfig,
|
|
||||||
ScaleAmplitudeConfig,
|
|
||||||
SpectralMeanSubstractionConfig,
|
|
||||||
PeakNormalizeConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class SpectrogramConfig(BaseConfig):
|
|
||||||
stft: STFTConfig = Field(default_factory=STFTConfig)
|
|
||||||
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
|
|
||||||
size: ResizeConfig = Field(default_factory=ResizeConfig)
|
|
||||||
transforms: Sequence[SpectrogramTransform] = Field(
|
|
||||||
default_factory=lambda: [
|
|
||||||
PcenConfig(),
|
|
||||||
SpectralMeanSubstractionConfig(),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_spectrogram_transform_step(
|
def _build_spectrogram_transform_step(
|
||||||
step: SpectrogramTransform,
|
step: SpectrogramTransform,
|
||||||
samplerate: int,
|
samplerate: int,
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from loguru import logger
|
|||||||
from pydantic import Field, field_validator
|
from pydantic import Field, field_validator
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.data.conditions import build_sound_event_condition
|
from batdetect2.data.conditions import build_sound_event_condition
|
||||||
from batdetect2.targets.classes import (
|
from batdetect2.targets.classes import (
|
||||||
DEFAULT_CLASSES,
|
DEFAULT_CLASSES,
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from typing import Dict, List, Optional
|
|||||||
from pydantic import Field, PrivateAttr, computed_field, model_validator
|
from pydantic import Field, PrivateAttr, computed_field, model_validator
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.data.conditions import (
|
from batdetect2.data.conditions import (
|
||||||
AllOfConfig,
|
AllOfConfig,
|
||||||
HasAllTagsConfig,
|
HasAllTagsConfig,
|
||||||
|
|||||||
@ -26,12 +26,17 @@ import numpy as np
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.arrays import spec_to_xarray
|
||||||
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.preprocess.audio import build_audio_loader
|
from batdetect2.preprocess.audio import build_audio_loader
|
||||||
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
from batdetect2.typing import (
|
||||||
from batdetect2.typing.targets import Position, ROITargetMapper, Size
|
AudioLoader,
|
||||||
from batdetect2.utils.arrays import spec_to_xarray
|
Position,
|
||||||
|
PreprocessorProtocol,
|
||||||
|
ROITargetMapper,
|
||||||
|
Size,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Anchor",
|
"Anchor",
|
||||||
|
|||||||
@ -11,11 +11,10 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.geometry import scale_geometry, shift_geometry
|
from soundevent.geometry import scale_geometry, shift_geometry
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.core.arrays import adjust_width
|
||||||
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.train.clips import get_subclip_annotation
|
from batdetect2.train.clips import get_subclip_annotation
|
||||||
from batdetect2.typing import Augmentation
|
from batdetect2.typing import AudioLoader, Augmentation
|
||||||
from batdetect2.typing.preprocess import AudioLoader
|
|
||||||
from batdetect2.utils.arrays import adjust_width
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AugmentationConfig",
|
"AugmentationConfig",
|
||||||
|
|||||||
@ -10,10 +10,12 @@ from batdetect2.postprocess import get_raw_predictions
|
|||||||
from batdetect2.train.dataset import ValidationDataset
|
from batdetect2.train.dataset import ValidationDataset
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
from batdetect2.train.logging import get_image_plotter
|
from batdetect2.train.logging import get_image_plotter
|
||||||
from batdetect2.typing.evaluate import ClipEvaluation
|
from batdetect2.typing import (
|
||||||
from batdetect2.typing.models import ModelOutput
|
ClipEvaluation,
|
||||||
from batdetect2.typing.postprocess import RawPrediction
|
ModelOutput,
|
||||||
from batdetect2.typing.train import TrainExample
|
RawPrediction,
|
||||||
|
TrainExample,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ValidationMetrics(Callback):
|
class ValidationMetrics(Callback):
|
||||||
|
|||||||
@ -6,8 +6,7 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.geometry import compute_bounds, intervals_overlap
|
from soundevent.geometry import compute_bounds, intervals_overlap
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core import BaseConfig, Registry
|
||||||
from batdetect2.data._core import Registry
|
|
||||||
from batdetect2.typing import ClipperProtocol
|
from batdetect2.typing import ClipperProtocol
|
||||||
|
|
||||||
DEFAULT_TRAIN_CLIP_DURATION = 0.256
|
DEFAULT_TRAIN_CLIP_DURATION = 0.256
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from typing import Optional, Union
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.evaluate import EvaluationConfig
|
from batdetect2.evaluate import EvaluationConfig
|
||||||
from batdetect2.models import ModelConfig
|
from batdetect2.models import ModelConfig
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
@ -80,7 +80,6 @@ class OptimizerConfig(BaseConfig):
|
|||||||
class TrainingConfig(BaseConfig):
|
class TrainingConfig(BaseConfig):
|
||||||
train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
|
train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
|
||||||
val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
|
val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
|
||||||
|
|
||||||
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
|
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
|
||||||
loss: LossConfig = Field(default_factory=LossConfig)
|
loss: LossConfig = Field(default_factory=LossConfig)
|
||||||
cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
|
cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
|
||||||
|
|||||||
@ -5,8 +5,8 @@ from loguru import logger
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
from batdetect2.plotting.clips import build_audio_loader
|
from batdetect2.core.arrays import adjust_width
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_audio_loader, build_preprocessor
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
RandomAudioSource,
|
RandomAudioSource,
|
||||||
build_augmentations,
|
build_augmentations,
|
||||||
@ -14,10 +14,14 @@ from batdetect2.train.augmentations import (
|
|||||||
from batdetect2.train.clips import build_clipper
|
from batdetect2.train.clips import build_clipper
|
||||||
from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig
|
from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.typing import ClipperProtocol, TrainExample
|
from batdetect2.typing import (
|
||||||
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
AudioLoader,
|
||||||
from batdetect2.typing.train import Augmentation, ClipLabeller
|
Augmentation,
|
||||||
from batdetect2.utils.arrays import adjust_width
|
ClipLabeller,
|
||||||
|
ClipperProtocol,
|
||||||
|
PreprocessorProtocol,
|
||||||
|
TrainExample,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TrainingDataset",
|
"TrainingDataset",
|
||||||
|
|||||||
@ -13,14 +13,10 @@ import torch
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||||
from batdetect2.targets import build_targets, iterate_encoded_sound_events
|
from batdetect2.targets import build_targets, iterate_encoded_sound_events
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import ClipLabeller, Heatmaps, TargetProtocol
|
||||||
ClipLabeller,
|
|
||||||
Heatmaps,
|
|
||||||
TargetProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"LabelConfig",
|
"LabelConfig",
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from loguru import logger
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
|
||||||
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
|
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
|
||||||
|
|
||||||
|
|||||||
@ -27,7 +27,7 @@ from loguru import logger
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample
|
from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@ -1,4 +1,8 @@
|
|||||||
from batdetect2.typing.evaluate import MatchEvaluation, MetricsProtocol
|
from batdetect2.typing.evaluate import (
|
||||||
|
ClipEvaluation,
|
||||||
|
MatchEvaluation,
|
||||||
|
MetricsProtocol,
|
||||||
|
)
|
||||||
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
||||||
from batdetect2.typing.postprocess import (
|
from batdetect2.typing.postprocess import (
|
||||||
BatDetect2Prediction,
|
BatDetect2Prediction,
|
||||||
@ -10,9 +14,11 @@ from batdetect2.typing.preprocess import (
|
|||||||
AudioLoader,
|
AudioLoader,
|
||||||
PreprocessorProtocol,
|
PreprocessorProtocol,
|
||||||
SpectrogramBuilder,
|
SpectrogramBuilder,
|
||||||
|
SpectrogramPipeline,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.targets import (
|
from batdetect2.typing.targets import (
|
||||||
Position,
|
Position,
|
||||||
|
ROITargetMapper,
|
||||||
Size,
|
Size,
|
||||||
SoundEventDecoder,
|
SoundEventDecoder,
|
||||||
SoundEventEncoder,
|
SoundEventEncoder,
|
||||||
@ -34,6 +40,7 @@ __all__ = [
|
|||||||
"Augmentation",
|
"Augmentation",
|
||||||
"BackboneModel",
|
"BackboneModel",
|
||||||
"BatDetect2Prediction",
|
"BatDetect2Prediction",
|
||||||
|
"ClipEvaluation",
|
||||||
"ClipLabeller",
|
"ClipLabeller",
|
||||||
"ClipperProtocol",
|
"ClipperProtocol",
|
||||||
"DetectionModel",
|
"DetectionModel",
|
||||||
@ -47,12 +54,14 @@ __all__ = [
|
|||||||
"Position",
|
"Position",
|
||||||
"PostprocessorProtocol",
|
"PostprocessorProtocol",
|
||||||
"PreprocessorProtocol",
|
"PreprocessorProtocol",
|
||||||
|
"ROITargetMapper",
|
||||||
"RawPrediction",
|
"RawPrediction",
|
||||||
"Size",
|
"Size",
|
||||||
"SoundEventDecoder",
|
"SoundEventDecoder",
|
||||||
"SoundEventEncoder",
|
"SoundEventEncoder",
|
||||||
"SoundEventFilter",
|
"SoundEventFilter",
|
||||||
"SpectrogramBuilder",
|
"SpectrogramBuilder",
|
||||||
|
"SpectrogramPipeline",
|
||||||
"TargetProtocol",
|
"TargetProtocol",
|
||||||
"TrainExample",
|
"TrainExample",
|
||||||
]
|
]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user