Exported types at module level

This commit is contained in:
mbsantiago 2026-03-18 13:08:28 +00:00
parent 6276a8884e
commit 0bf809e376
9 changed files with 120 additions and 52 deletions

View File

@ -7,35 +7,35 @@ import torch
from soundevent import data from soundevent import data
from soundevent.audio.files import get_audio_files from soundevent.audio.files import get_audio_files
from batdetect2.audio import build_audio_loader from batdetect2.audio import AudioLoader, build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config
from batdetect2.core import merge_configs from batdetect2.core import merge_configs
from batdetect2.data import load_dataset_from_config from batdetect2.data import Dataset, load_dataset_from_config
from batdetect2.data.datasets import Dataset from batdetect2.evaluate import (
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, run_evaluate DEFAULT_EVAL_DIR,
from batdetect2.evaluate.types import EvaluatorProtocol EvaluatorProtocol,
build_evaluator,
run_evaluate,
)
from batdetect2.inference import process_file_list, run_batch_inference from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR from batdetect2.logging import DEFAULT_LOGS_DIR
from batdetect2.models import Model, build_model from batdetect2.models import Model, build_model
from batdetect2.outputs import ( from batdetect2.outputs import (
OutputFormatConfig, OutputFormatConfig,
OutputFormatterProtocol,
OutputTransformProtocol, OutputTransformProtocol,
build_output_formatter, build_output_formatter,
build_output_transform, build_output_transform,
get_output_formatter, get_output_formatter,
) )
from batdetect2.outputs.types import OutputFormatterProtocol from batdetect2.postprocess import (
from batdetect2.postprocess import build_postprocessor
from batdetect2.postprocess.types import (
ClipDetections, ClipDetections,
Detection, Detection,
PostprocessorProtocol, PostprocessorProtocol,
build_postprocessor,
) )
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.targets import TargetProtocol, build_targets
from batdetect2.targets import build_targets
from batdetect2.targets.types import TargetProtocol
from batdetect2.train import ( from batdetect2.train import (
DEFAULT_CHECKPOINT_DIR, DEFAULT_CHECKPOINT_DIR,
load_model_from_checkpoint, load_model_from_checkpoint,
@ -168,6 +168,9 @@ class BatDetect2API:
def load_audio(self, path: data.PathLike) -> np.ndarray: def load_audio(self, path: data.PathLike) -> np.ndarray:
return self.audio_loader.load_file(path) return self.audio_loader.load_file(path)
def load_recording(self, recording: data.Recording) -> np.ndarray:
return self.audio_loader.load_recording(recording)
def load_clip(self, clip: data.Clip) -> np.ndarray: def load_clip(self, clip: data.Clip) -> np.ndarray:
return self.audio_loader.load_clip(clip) return self.audio_loader.load_clip(clip)
@ -178,7 +181,7 @@ class BatDetect2API:
tensor = torch.tensor(audio).unsqueeze(0) tensor = torch.tensor(audio).unsqueeze(0)
return self.preprocessor(tensor) return self.preprocessor(tensor)
def process_file(self, audio_file: str) -> ClipDetections: def process_file(self, audio_file: data.PathLike) -> ClipDetections:
recording = data.Recording.from_file(audio_file, compute_hash=False) recording = data.Recording.from_file(audio_file, compute_hash=False)
wav = self.audio_loader.load_recording(recording) wav = self.audio_loader.load_recording(recording)
detections = self.process_audio(wav) detections = self.process_audio(wav)
@ -230,6 +233,7 @@ class BatDetect2API:
def process_files( def process_files(
self, self,
audio_files: Sequence[data.PathLike], audio_files: Sequence[data.PathLike],
batch_size: int | None = None,
num_workers: int = 0, num_workers: int = 0,
) -> list[ClipDetections]: ) -> list[ClipDetections]:
return process_file_list( return process_file_list(

View File

@ -5,8 +5,11 @@ from batdetect2.audio.loader import (
SoundEventAudioLoader, SoundEventAudioLoader,
build_audio_loader, build_audio_loader,
) )
from batdetect2.audio.types import AudioLoader, ClipperProtocol
__all__ = [ __all__ = [
"AudioLoader",
"ClipperProtocol",
"TARGET_SAMPLERATE_HZ", "TARGET_SAMPLERATE_HZ",
"AudioConfig", "AudioConfig",
"SoundEventAudioLoader", "SoundEventAudioLoader",

View File

@ -7,6 +7,7 @@ from batdetect2.data.annotations import (
load_annotated_dataset, load_annotated_dataset,
) )
from batdetect2.data.datasets import ( from batdetect2.data.datasets import (
Dataset,
DatasetConfig, DatasetConfig,
load_dataset, load_dataset,
load_dataset_config, load_dataset_config,
@ -19,6 +20,7 @@ from batdetect2.data.summary import (
) )
__all__ = [ __all__ = [
"Dataset",
"AOEFAnnotations", "AOEFAnnotations",
"AnnotatedDataset", "AnnotatedDataset",
"AnnotationFormats", "AnnotationFormats",

View File

@ -2,14 +2,30 @@ from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, run_evaluate from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, run_evaluate
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
from batdetect2.evaluate.tasks import TaskConfig, build_task from batdetect2.evaluate.tasks import TaskConfig, build_task
from batdetect2.evaluate.types import (
AffinityFunction,
ClipMatches,
EvaluationTaskProtocol,
EvaluatorProtocol,
MetricsProtocol,
PlotterProtocol,
)
__all__ = [ __all__ = [
"AffinityFunction",
"ClipMatches",
"DEFAULT_EVAL_DIR",
"EvaluationConfig", "EvaluationConfig",
"EvaluationTaskProtocol",
"Evaluator", "Evaluator",
"EvaluatorProtocol",
"MatchEvaluation",
"MatcherProtocol",
"MetricsProtocol",
"PlotterProtocol",
"TaskConfig", "TaskConfig",
"build_evaluator", "build_evaluator",
"build_task", "build_task",
"run_evaluate",
"load_evaluation_config", "load_evaluation_config",
"DEFAULT_EVAL_DIR", "run_evaluate",
] ]

View File

@ -1,49 +1,52 @@
from typing import TYPE_CHECKING, List, Optional, Sequence from typing import Sequence
from lightning import Trainer from lightning import Trainer
from soundevent import data from soundevent import data
from batdetect2.audio import AudioConfig
from batdetect2.audio.loader import build_audio_loader from batdetect2.audio.loader import build_audio_loader
from batdetect2.audio.types import AudioLoader from batdetect2.audio.types import AudioLoader
from batdetect2.inference import InferenceConfig
from batdetect2.inference.clips import get_clips_from_files from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.dataset import build_inference_loader from batdetect2.inference.dataset import build_inference_loader
from batdetect2.inference.lightning import InferenceModule from batdetect2.inference.lightning import InferenceModule
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.outputs import (
OutputsConfig,
OutputTransformProtocol,
build_output_transform,
)
from batdetect2.postprocess.types import ClipDetections from batdetect2.postprocess.types import ClipDetections
from batdetect2.preprocess.preprocessor import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.targets import build_targets
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
def run_batch_inference( def run_batch_inference(
model, model: Model,
clips: Sequence[data.Clip], clips: Sequence[data.Clip],
targets: Optional["TargetProtocol"] = None, targets: TargetProtocol | None = None,
audio_loader: Optional["AudioLoader"] = None, audio_loader: AudioLoader | None = None,
preprocessor: Optional["PreprocessorProtocol"] = None, preprocessor: PreprocessorProtocol | None = None,
config: Optional["BatDetect2Config"] = None, audio_config: AudioConfig | None = None,
output_transform: Optional[OutputTransformProtocol] = None, output_transform: OutputTransformProtocol | None = None,
output_config: OutputsConfig | None = None,
inference_config: InferenceConfig | None = None,
num_workers: int = 1, num_workers: int = 1,
batch_size: int | None = None, batch_size: int | None = None,
) -> List[ClipDetections]: ) -> list[ClipDetections]:
from batdetect2.config import BatDetect2Config audio_config = audio_config or AudioConfig(
samplerate=model.preprocessor.input_samplerate,
config = config or BatDetect2Config()
audio_loader = audio_loader or build_audio_loader()
preprocessor = preprocessor or build_preprocessor(
input_samplerate=audio_loader.samplerate,
) )
output_config = output_config or OutputsConfig()
inference_config = inference_config or InferenceConfig()
audio_loader = audio_loader or build_audio_loader(config=audio_config)
preprocessor = preprocessor or model.preprocessor
targets = targets or model.targets
targets = targets or build_targets()
output_transform = output_transform or build_output_transform( output_transform = output_transform or build_output_transform(
config=config.outputs.transform, config=output_config.transform,
targets=targets, targets=targets,
) )
@ -51,7 +54,7 @@ def run_batch_inference(
clips, clips,
audio_loader=audio_loader, audio_loader=audio_loader,
preprocessor=preprocessor, preprocessor=preprocessor,
config=config.inference.loader, config=inference_config.loader,
num_workers=num_workers, num_workers=num_workers,
batch_size=batch_size, batch_size=batch_size,
) )
@ -72,13 +75,18 @@ def run_batch_inference(
def process_file_list( def process_file_list(
model: Model, model: Model,
paths: Sequence[data.PathLike], paths: Sequence[data.PathLike],
config: "BatDetect2Config", targets: TargetProtocol | None = None,
targets: Optional["TargetProtocol"] = None, audio_loader: AudioLoader | None = None,
audio_loader: Optional["AudioLoader"] = None, audio_config: AudioConfig | None = None,
preprocessor: Optional["PreprocessorProtocol"] = None, preprocessor: PreprocessorProtocol | None = None,
inference_config: InferenceConfig | None = None,
output_config: OutputsConfig | None = None,
output_transform: OutputTransformProtocol | None = None,
batch_size: int | None = None,
num_workers: int = 0, num_workers: int = 0,
) -> List[ClipDetections]: ) -> list[ClipDetections]:
clip_config = config.inference.clipping inference_config = inference_config or InferenceConfig()
clip_config = inference_config.clipping
clips = get_clips_from_files( clips = get_clips_from_files(
paths, paths,
duration=clip_config.duration, duration=clip_config.duration,
@ -92,6 +100,10 @@ def process_file_list(
targets=targets, targets=targets,
audio_loader=audio_loader, audio_loader=audio_loader,
preprocessor=preprocessor, preprocessor=preprocessor,
config=config, batch_size=batch_size,
num_workers=num_workers, num_workers=num_workers,
output_config=output_config,
audio_config=audio_config,
output_transform=output_transform,
inference_config=inference_config,
) )

View File

@ -13,11 +13,15 @@ from batdetect2.outputs.transforms import (
OutputTransformConfig, OutputTransformConfig,
build_output_transform, build_output_transform,
) )
from batdetect2.outputs.types import OutputTransformProtocol from batdetect2.outputs.types import (
OutputFormatterProtocol,
OutputTransformProtocol,
)
__all__ = [ __all__ = [
"BatDetect2OutputConfig", "BatDetect2OutputConfig",
"OutputFormatConfig", "OutputFormatConfig",
"OutputFormatterProtocol",
"OutputTransformConfig", "OutputTransformConfig",
"OutputTransformProtocol", "OutputTransformProtocol",
"OutputsConfig", "OutputsConfig",

View File

@ -9,10 +9,26 @@ from batdetect2.postprocess.postprocessor import (
Postprocessor, Postprocessor,
build_postprocessor, build_postprocessor,
) )
from batdetect2.postprocess.types import (
ClipDetections,
ClipDetectionsArray,
ClipDetectionsTensor,
ClipPrediction,
Detection,
GeometryDecoder,
PostprocessorProtocol,
)
__all__ = [ __all__ = [
"ClipDetections",
"ClipDetectionsArray",
"ClipDetectionsTensor",
"ClipPrediction",
"Detection",
"GeometryDecoder",
"PostprocessConfig", "PostprocessConfig",
"Postprocessor", "Postprocessor",
"PostprocessorProtocol",
"build_postprocessor", "build_postprocessor",
"load_postprocess_config", "load_postprocess_config",
"non_max_suppression", "non_max_suppression",

View File

@ -7,8 +7,10 @@ from batdetect2.preprocess.config import (
) )
from batdetect2.preprocess.preprocessor import Preprocessor, build_preprocessor from batdetect2.preprocess.preprocessor import Preprocessor, build_preprocessor
from batdetect2.preprocess.spectrogram import MAX_FREQ, MIN_FREQ from batdetect2.preprocess.spectrogram import MAX_FREQ, MIN_FREQ
from batdetect2.preprocess.types import PreprocessorProtocol
__all__ = [ __all__ = [
"PreprocessorProtocol",
"MAX_FREQ", "MAX_FREQ",
"MIN_FREQ", "MIN_FREQ",
"PreprocessingConfig", "PreprocessingConfig",

View File

@ -1,8 +1,6 @@
"""BatDetect2 Target Definition system.""" """BatDetect2 Target Definition system."""
from batdetect2.targets.classes import ( from batdetect2.targets.classes import (
SoundEventDecoder,
SoundEventEncoder,
TargetClassConfig, TargetClassConfig,
build_sound_event_decoder, build_sound_event_decoder,
build_sound_event_encoder, build_sound_event_encoder,
@ -12,7 +10,6 @@ from batdetect2.targets.config import TargetConfig, load_target_config
from batdetect2.targets.rois import ( from batdetect2.targets.rois import (
AnchorBBoxMapperConfig, AnchorBBoxMapperConfig,
ROIMapperConfig, ROIMapperConfig,
ROITargetMapper,
build_roi_mapper, build_roi_mapper,
) )
from batdetect2.targets.targets import ( from batdetect2.targets.targets import (
@ -27,15 +24,27 @@ from batdetect2.targets.terms import (
generic_class, generic_class,
individual, individual,
) )
from batdetect2.targets.types import (
Position,
ROITargetMapper,
Size,
SoundEventDecoder,
SoundEventEncoder,
TargetProtocol,
)
__all__ = [ __all__ = [
"AnchorBBoxMapperConfig", "AnchorBBoxMapperConfig",
"Position",
"ROIMapperConfig", "ROIMapperConfig",
"ROITargetMapper", "ROITargetMapper",
"Size",
"SoundEventDecoder", "SoundEventDecoder",
"SoundEventEncoder", "SoundEventEncoder",
"SoundEventFilter",
"TargetClassConfig", "TargetClassConfig",
"TargetConfig", "TargetConfig",
"TargetProtocol",
"Targets", "Targets",
"build_roi_mapper", "build_roi_mapper",
"build_sound_event_decoder", "build_sound_event_decoder",