mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Exported types at module level
This commit is contained in:
parent
6276a8884e
commit
0bf809e376
@ -7,35 +7,35 @@ import torch
|
||||
from soundevent import data
|
||||
from soundevent.audio.files import get_audio_files
|
||||
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.audio import AudioLoader, build_audio_loader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.core import merge_configs
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.data.datasets import Dataset
|
||||
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, run_evaluate
|
||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||
from batdetect2.data import Dataset, load_dataset_from_config
|
||||
from batdetect2.evaluate import (
|
||||
DEFAULT_EVAL_DIR,
|
||||
EvaluatorProtocol,
|
||||
build_evaluator,
|
||||
run_evaluate,
|
||||
)
|
||||
from batdetect2.inference import process_file_list, run_batch_inference
|
||||
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||
from batdetect2.models import Model, build_model
|
||||
from batdetect2.outputs import (
|
||||
OutputFormatConfig,
|
||||
OutputFormatterProtocol,
|
||||
OutputTransformProtocol,
|
||||
build_output_formatter,
|
||||
build_output_transform,
|
||||
get_output_formatter,
|
||||
)
|
||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.postprocess.types import (
|
||||
from batdetect2.postprocess import (
|
||||
ClipDetections,
|
||||
Detection,
|
||||
PostprocessorProtocol,
|
||||
build_postprocessor,
|
||||
)
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||
from batdetect2.targets import TargetProtocol, build_targets
|
||||
from batdetect2.train import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
load_model_from_checkpoint,
|
||||
@ -168,6 +168,9 @@ class BatDetect2API:
|
||||
def load_audio(self, path: data.PathLike) -> np.ndarray:
|
||||
return self.audio_loader.load_file(path)
|
||||
|
||||
def load_recording(self, recording: data.Recording) -> np.ndarray:
|
||||
return self.audio_loader.load_recording(recording)
|
||||
|
||||
def load_clip(self, clip: data.Clip) -> np.ndarray:
|
||||
return self.audio_loader.load_clip(clip)
|
||||
|
||||
@ -178,7 +181,7 @@ class BatDetect2API:
|
||||
tensor = torch.tensor(audio).unsqueeze(0)
|
||||
return self.preprocessor(tensor)
|
||||
|
||||
def process_file(self, audio_file: str) -> ClipDetections:
|
||||
def process_file(self, audio_file: data.PathLike) -> ClipDetections:
|
||||
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
||||
wav = self.audio_loader.load_recording(recording)
|
||||
detections = self.process_audio(wav)
|
||||
@ -230,6 +233,7 @@ class BatDetect2API:
|
||||
def process_files(
|
||||
self,
|
||||
audio_files: Sequence[data.PathLike],
|
||||
batch_size: int | None = None,
|
||||
num_workers: int = 0,
|
||||
) -> list[ClipDetections]:
|
||||
return process_file_list(
|
||||
|
||||
@ -5,8 +5,11 @@ from batdetect2.audio.loader import (
|
||||
SoundEventAudioLoader,
|
||||
build_audio_loader,
|
||||
)
|
||||
from batdetect2.audio.types import AudioLoader, ClipperProtocol
|
||||
|
||||
__all__ = [
|
||||
"AudioLoader",
|
||||
"ClipperProtocol",
|
||||
"TARGET_SAMPLERATE_HZ",
|
||||
"AudioConfig",
|
||||
"SoundEventAudioLoader",
|
||||
|
||||
@ -7,6 +7,7 @@ from batdetect2.data.annotations import (
|
||||
load_annotated_dataset,
|
||||
)
|
||||
from batdetect2.data.datasets import (
|
||||
Dataset,
|
||||
DatasetConfig,
|
||||
load_dataset,
|
||||
load_dataset_config,
|
||||
@ -19,6 +20,7 @@ from batdetect2.data.summary import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Dataset",
|
||||
"AOEFAnnotations",
|
||||
"AnnotatedDataset",
|
||||
"AnnotationFormats",
|
||||
|
||||
@ -2,14 +2,30 @@ from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
|
||||
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, run_evaluate
|
||||
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
|
||||
from batdetect2.evaluate.tasks import TaskConfig, build_task
|
||||
from batdetect2.evaluate.types import (
|
||||
AffinityFunction,
|
||||
ClipMatches,
|
||||
EvaluationTaskProtocol,
|
||||
EvaluatorProtocol,
|
||||
MetricsProtocol,
|
||||
PlotterProtocol,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AffinityFunction",
|
||||
"ClipMatches",
|
||||
"DEFAULT_EVAL_DIR",
|
||||
"EvaluationConfig",
|
||||
"EvaluationTaskProtocol",
|
||||
"Evaluator",
|
||||
"EvaluatorProtocol",
|
||||
"MatchEvaluation",
|
||||
"MatcherProtocol",
|
||||
"MetricsProtocol",
|
||||
"PlotterProtocol",
|
||||
"TaskConfig",
|
||||
"build_evaluator",
|
||||
"build_task",
|
||||
"run_evaluate",
|
||||
"load_evaluation_config",
|
||||
"DEFAULT_EVAL_DIR",
|
||||
"run_evaluate",
|
||||
]
|
||||
|
||||
@ -1,49 +1,52 @@
|
||||
from typing import TYPE_CHECKING, List, Optional, Sequence
|
||||
from typing import Sequence
|
||||
|
||||
from lightning import Trainer
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.audio.loader import build_audio_loader
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.inference.clips import get_clips_from_files
|
||||
from batdetect2.inference.dataset import build_inference_loader
|
||||
from batdetect2.inference.lightning import InferenceModule
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.outputs import (
|
||||
OutputsConfig,
|
||||
OutputTransformProtocol,
|
||||
build_output_transform,
|
||||
)
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.preprocess.preprocessor import build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.targets import build_targets
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from batdetect2.config import BatDetect2Config
|
||||
|
||||
|
||||
def run_batch_inference(
|
||||
model,
|
||||
model: Model,
|
||||
clips: Sequence[data.Clip],
|
||||
targets: Optional["TargetProtocol"] = None,
|
||||
audio_loader: Optional["AudioLoader"] = None,
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
config: Optional["BatDetect2Config"] = None,
|
||||
output_transform: Optional[OutputTransformProtocol] = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
output_transform: OutputTransformProtocol | None = None,
|
||||
output_config: OutputsConfig | None = None,
|
||||
inference_config: InferenceConfig | None = None,
|
||||
num_workers: int = 1,
|
||||
batch_size: int | None = None,
|
||||
) -> List[ClipDetections]:
|
||||
from batdetect2.config import BatDetect2Config
|
||||
|
||||
config = config or BatDetect2Config()
|
||||
|
||||
audio_loader = audio_loader or build_audio_loader()
|
||||
|
||||
preprocessor = preprocessor or build_preprocessor(
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
) -> list[ClipDetections]:
|
||||
audio_config = audio_config or AudioConfig(
|
||||
samplerate=model.preprocessor.input_samplerate,
|
||||
)
|
||||
output_config = output_config or OutputsConfig()
|
||||
inference_config = inference_config or InferenceConfig()
|
||||
|
||||
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
||||
|
||||
preprocessor = preprocessor or model.preprocessor
|
||||
targets = targets or model.targets
|
||||
|
||||
targets = targets or build_targets()
|
||||
output_transform = output_transform or build_output_transform(
|
||||
config=config.outputs.transform,
|
||||
config=output_config.transform,
|
||||
targets=targets,
|
||||
)
|
||||
|
||||
@ -51,7 +54,7 @@ def run_batch_inference(
|
||||
clips,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
config=config.inference.loader,
|
||||
config=inference_config.loader,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
@ -72,13 +75,18 @@ def run_batch_inference(
|
||||
def process_file_list(
|
||||
model: Model,
|
||||
paths: Sequence[data.PathLike],
|
||||
config: "BatDetect2Config",
|
||||
targets: Optional["TargetProtocol"] = None,
|
||||
audio_loader: Optional["AudioLoader"] = None,
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
inference_config: InferenceConfig | None = None,
|
||||
output_config: OutputsConfig | None = None,
|
||||
output_transform: OutputTransformProtocol | None = None,
|
||||
batch_size: int | None = None,
|
||||
num_workers: int = 0,
|
||||
) -> List[ClipDetections]:
|
||||
clip_config = config.inference.clipping
|
||||
) -> list[ClipDetections]:
|
||||
inference_config = inference_config or InferenceConfig()
|
||||
clip_config = inference_config.clipping
|
||||
clips = get_clips_from_files(
|
||||
paths,
|
||||
duration=clip_config.duration,
|
||||
@ -92,6 +100,10 @@ def process_file_list(
|
||||
targets=targets,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
config=config,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
output_config=output_config,
|
||||
audio_config=audio_config,
|
||||
output_transform=output_transform,
|
||||
inference_config=inference_config,
|
||||
)
|
||||
|
||||
@ -13,11 +13,15 @@ from batdetect2.outputs.transforms import (
|
||||
OutputTransformConfig,
|
||||
build_output_transform,
|
||||
)
|
||||
from batdetect2.outputs.types import OutputTransformProtocol
|
||||
from batdetect2.outputs.types import (
|
||||
OutputFormatterProtocol,
|
||||
OutputTransformProtocol,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BatDetect2OutputConfig",
|
||||
"OutputFormatConfig",
|
||||
"OutputFormatterProtocol",
|
||||
"OutputTransformConfig",
|
||||
"OutputTransformProtocol",
|
||||
"OutputsConfig",
|
||||
|
||||
@ -9,10 +9,26 @@ from batdetect2.postprocess.postprocessor import (
|
||||
Postprocessor,
|
||||
build_postprocessor,
|
||||
)
|
||||
from batdetect2.postprocess.types import (
|
||||
ClipDetections,
|
||||
ClipDetectionsArray,
|
||||
ClipDetectionsTensor,
|
||||
ClipPrediction,
|
||||
Detection,
|
||||
GeometryDecoder,
|
||||
PostprocessorProtocol,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ClipDetections",
|
||||
"ClipDetectionsArray",
|
||||
"ClipDetectionsTensor",
|
||||
"ClipPrediction",
|
||||
"Detection",
|
||||
"GeometryDecoder",
|
||||
"PostprocessConfig",
|
||||
"Postprocessor",
|
||||
"PostprocessorProtocol",
|
||||
"build_postprocessor",
|
||||
"load_postprocess_config",
|
||||
"non_max_suppression",
|
||||
|
||||
@ -7,8 +7,10 @@ from batdetect2.preprocess.config import (
|
||||
)
|
||||
from batdetect2.preprocess.preprocessor import Preprocessor, build_preprocessor
|
||||
from batdetect2.preprocess.spectrogram import MAX_FREQ, MIN_FREQ
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
|
||||
__all__ = [
|
||||
"PreprocessorProtocol",
|
||||
"MAX_FREQ",
|
||||
"MIN_FREQ",
|
||||
"PreprocessingConfig",
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
"""BatDetect2 Target Definition system."""
|
||||
|
||||
from batdetect2.targets.classes import (
|
||||
SoundEventDecoder,
|
||||
SoundEventEncoder,
|
||||
TargetClassConfig,
|
||||
build_sound_event_decoder,
|
||||
build_sound_event_encoder,
|
||||
@ -12,7 +10,6 @@ from batdetect2.targets.config import TargetConfig, load_target_config
|
||||
from batdetect2.targets.rois import (
|
||||
AnchorBBoxMapperConfig,
|
||||
ROIMapperConfig,
|
||||
ROITargetMapper,
|
||||
build_roi_mapper,
|
||||
)
|
||||
from batdetect2.targets.targets import (
|
||||
@ -27,15 +24,27 @@ from batdetect2.targets.terms import (
|
||||
generic_class,
|
||||
individual,
|
||||
)
|
||||
from batdetect2.targets.types import (
|
||||
Position,
|
||||
ROITargetMapper,
|
||||
Size,
|
||||
SoundEventDecoder,
|
||||
SoundEventEncoder,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AnchorBBoxMapperConfig",
|
||||
"Position",
|
||||
"ROIMapperConfig",
|
||||
"ROITargetMapper",
|
||||
"Size",
|
||||
"SoundEventDecoder",
|
||||
"SoundEventEncoder",
|
||||
"SoundEventFilter",
|
||||
"TargetClassConfig",
|
||||
"TargetConfig",
|
||||
"TargetProtocol",
|
||||
"Targets",
|
||||
"build_roi_mapper",
|
||||
"build_sound_event_decoder",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user