Exported types at module level

This commit is contained in:
mbsantiago 2026-03-18 13:08:28 +00:00
parent 6276a8884e
commit 0bf809e376
9 changed files with 120 additions and 52 deletions

View File

@ -7,35 +7,35 @@ import torch
from soundevent import data
from soundevent.audio.files import get_audio_files
from batdetect2.audio import build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.audio import AudioLoader, build_audio_loader
from batdetect2.config import BatDetect2Config
from batdetect2.core import merge_configs
from batdetect2.data import load_dataset_from_config
from batdetect2.data.datasets import Dataset
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, run_evaluate
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.data import Dataset, load_dataset_from_config
from batdetect2.evaluate import (
DEFAULT_EVAL_DIR,
EvaluatorProtocol,
build_evaluator,
run_evaluate,
)
from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR
from batdetect2.models import Model, build_model
from batdetect2.outputs import (
OutputFormatConfig,
OutputFormatterProtocol,
OutputTransformProtocol,
build_output_formatter,
build_output_transform,
get_output_formatter,
)
from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess import build_postprocessor
from batdetect2.postprocess.types import (
from batdetect2.postprocess import (
ClipDetections,
Detection,
PostprocessorProtocol,
build_postprocessor,
)
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import build_targets
from batdetect2.targets.types import TargetProtocol
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.targets import TargetProtocol, build_targets
from batdetect2.train import (
DEFAULT_CHECKPOINT_DIR,
load_model_from_checkpoint,
@ -168,6 +168,9 @@ class BatDetect2API:
def load_audio(self, path: data.PathLike) -> np.ndarray:
return self.audio_loader.load_file(path)
def load_recording(self, recording: data.Recording) -> np.ndarray:
return self.audio_loader.load_recording(recording)
def load_clip(self, clip: data.Clip) -> np.ndarray:
return self.audio_loader.load_clip(clip)
@ -178,7 +181,7 @@ class BatDetect2API:
tensor = torch.tensor(audio).unsqueeze(0)
return self.preprocessor(tensor)
def process_file(self, audio_file: str) -> ClipDetections:
def process_file(self, audio_file: data.PathLike) -> ClipDetections:
recording = data.Recording.from_file(audio_file, compute_hash=False)
wav = self.audio_loader.load_recording(recording)
detections = self.process_audio(wav)
@ -230,6 +233,7 @@ class BatDetect2API:
def process_files(
self,
audio_files: Sequence[data.PathLike],
batch_size: int | None = None,
num_workers: int = 0,
) -> list[ClipDetections]:
return process_file_list(

View File

@ -5,8 +5,11 @@ from batdetect2.audio.loader import (
SoundEventAudioLoader,
build_audio_loader,
)
from batdetect2.audio.types import AudioLoader, ClipperProtocol
__all__ = [
"AudioLoader",
"ClipperProtocol",
"TARGET_SAMPLERATE_HZ",
"AudioConfig",
"SoundEventAudioLoader",

View File

@ -7,6 +7,7 @@ from batdetect2.data.annotations import (
load_annotated_dataset,
)
from batdetect2.data.datasets import (
Dataset,
DatasetConfig,
load_dataset,
load_dataset_config,
@ -19,6 +20,7 @@ from batdetect2.data.summary import (
)
__all__ = [
"Dataset",
"AOEFAnnotations",
"AnnotatedDataset",
"AnnotationFormats",

View File

@ -2,14 +2,30 @@ from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, run_evaluate
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
from batdetect2.evaluate.tasks import TaskConfig, build_task
from batdetect2.evaluate.types import (
AffinityFunction,
ClipMatches,
EvaluationTaskProtocol,
EvaluatorProtocol,
MetricsProtocol,
PlotterProtocol,
)
__all__ = [
"AffinityFunction",
"ClipMatches",
"DEFAULT_EVAL_DIR",
"EvaluationConfig",
"EvaluationTaskProtocol",
"Evaluator",
"EvaluatorProtocol",
"MatchEvaluation",
"MatcherProtocol",
"MetricsProtocol",
"PlotterProtocol",
"TaskConfig",
"build_evaluator",
"build_task",
"run_evaluate",
"load_evaluation_config",
"DEFAULT_EVAL_DIR",
"run_evaluate",
]

View File

@ -1,49 +1,52 @@
from typing import TYPE_CHECKING, List, Optional, Sequence
from typing import Sequence
from lightning import Trainer
from soundevent import data
from batdetect2.audio import AudioConfig
from batdetect2.audio.loader import build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.inference import InferenceConfig
from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.dataset import build_inference_loader
from batdetect2.inference.lightning import InferenceModule
from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.outputs import (
OutputsConfig,
OutputTransformProtocol,
build_output_transform,
)
from batdetect2.postprocess.types import ClipDetections
from batdetect2.preprocess.preprocessor import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.targets import build_targets
from batdetect2.targets.types import TargetProtocol
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
def run_batch_inference(
model,
model: Model,
clips: Sequence[data.Clip],
targets: Optional["TargetProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None,
output_transform: Optional[OutputTransformProtocol] = None,
targets: TargetProtocol | None = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
audio_config: AudioConfig | None = None,
output_transform: OutputTransformProtocol | None = None,
output_config: OutputsConfig | None = None,
inference_config: InferenceConfig | None = None,
num_workers: int = 1,
batch_size: int | None = None,
) -> List[ClipDetections]:
from batdetect2.config import BatDetect2Config
config = config or BatDetect2Config()
audio_loader = audio_loader or build_audio_loader()
preprocessor = preprocessor or build_preprocessor(
input_samplerate=audio_loader.samplerate,
) -> list[ClipDetections]:
audio_config = audio_config or AudioConfig(
samplerate=model.preprocessor.input_samplerate,
)
output_config = output_config or OutputsConfig()
inference_config = inference_config or InferenceConfig()
audio_loader = audio_loader or build_audio_loader(config=audio_config)
preprocessor = preprocessor or model.preprocessor
targets = targets or model.targets
targets = targets or build_targets()
output_transform = output_transform or build_output_transform(
config=config.outputs.transform,
config=output_config.transform,
targets=targets,
)
@ -51,7 +54,7 @@ def run_batch_inference(
clips,
audio_loader=audio_loader,
preprocessor=preprocessor,
config=config.inference.loader,
config=inference_config.loader,
num_workers=num_workers,
batch_size=batch_size,
)
@ -72,13 +75,18 @@ def run_batch_inference(
def process_file_list(
model: Model,
paths: Sequence[data.PathLike],
config: "BatDetect2Config",
targets: Optional["TargetProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
targets: TargetProtocol | None = None,
audio_loader: AudioLoader | None = None,
audio_config: AudioConfig | None = None,
preprocessor: PreprocessorProtocol | None = None,
inference_config: InferenceConfig | None = None,
output_config: OutputsConfig | None = None,
output_transform: OutputTransformProtocol | None = None,
batch_size: int | None = None,
num_workers: int = 0,
) -> List[ClipDetections]:
clip_config = config.inference.clipping
) -> list[ClipDetections]:
inference_config = inference_config or InferenceConfig()
clip_config = inference_config.clipping
clips = get_clips_from_files(
paths,
duration=clip_config.duration,
@ -92,6 +100,10 @@ def process_file_list(
targets=targets,
audio_loader=audio_loader,
preprocessor=preprocessor,
config=config,
batch_size=batch_size,
num_workers=num_workers,
output_config=output_config,
audio_config=audio_config,
output_transform=output_transform,
inference_config=inference_config,
)

View File

@ -13,11 +13,15 @@ from batdetect2.outputs.transforms import (
OutputTransformConfig,
build_output_transform,
)
from batdetect2.outputs.types import OutputTransformProtocol
from batdetect2.outputs.types import (
OutputFormatterProtocol,
OutputTransformProtocol,
)
__all__ = [
"BatDetect2OutputConfig",
"OutputFormatConfig",
"OutputFormatterProtocol",
"OutputTransformConfig",
"OutputTransformProtocol",
"OutputsConfig",

View File

@ -9,10 +9,26 @@ from batdetect2.postprocess.postprocessor import (
Postprocessor,
build_postprocessor,
)
from batdetect2.postprocess.types import (
ClipDetections,
ClipDetectionsArray,
ClipDetectionsTensor,
ClipPrediction,
Detection,
GeometryDecoder,
PostprocessorProtocol,
)
__all__ = [
"ClipDetections",
"ClipDetectionsArray",
"ClipDetectionsTensor",
"ClipPrediction",
"Detection",
"GeometryDecoder",
"PostprocessConfig",
"Postprocessor",
"PostprocessorProtocol",
"build_postprocessor",
"load_postprocess_config",
"non_max_suppression",

View File

@ -7,8 +7,10 @@ from batdetect2.preprocess.config import (
)
from batdetect2.preprocess.preprocessor import Preprocessor, build_preprocessor
from batdetect2.preprocess.spectrogram import MAX_FREQ, MIN_FREQ
from batdetect2.preprocess.types import PreprocessorProtocol
__all__ = [
"PreprocessorProtocol",
"MAX_FREQ",
"MIN_FREQ",
"PreprocessingConfig",

View File

@ -1,8 +1,6 @@
"""BatDetect2 Target Definition system."""
from batdetect2.targets.classes import (
SoundEventDecoder,
SoundEventEncoder,
TargetClassConfig,
build_sound_event_decoder,
build_sound_event_encoder,
@ -12,7 +10,6 @@ from batdetect2.targets.config import TargetConfig, load_target_config
from batdetect2.targets.rois import (
AnchorBBoxMapperConfig,
ROIMapperConfig,
ROITargetMapper,
build_roi_mapper,
)
from batdetect2.targets.targets import (
@ -27,15 +24,27 @@ from batdetect2.targets.terms import (
generic_class,
individual,
)
from batdetect2.targets.types import (
Position,
ROITargetMapper,
Size,
SoundEventDecoder,
SoundEventEncoder,
TargetProtocol,
)
__all__ = [
"AnchorBBoxMapperConfig",
"Position",
"ROIMapperConfig",
"ROITargetMapper",
"Size",
"SoundEventDecoder",
"SoundEventEncoder",
"SoundEventFilter",
"TargetClassConfig",
"TargetConfig",
"TargetProtocol",
"Targets",
"build_roi_mapper",
"build_sound_event_decoder",