diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 8a63d7b..4ced7fc 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -32,11 +32,11 @@ from batdetect2.train import ( ) from batdetect2.typing import ( AudioLoader, - BatDetect2Prediction, + ClipDetections, EvaluatorProtocol, PostprocessorProtocol, PreprocessorProtocol, - RawPrediction, + Detection, TargetProtocol, ) @@ -110,7 +110,7 @@ class BatDetect2API: experiment_name: str | None = None, run_name: str | None = None, save_predictions: bool = True, - ) -> Tuple[Dict[str, float], List[List[RawPrediction]]]: + ) -> Tuple[Dict[str, float], List[List[Detection]]]: return evaluate( self.model, test_annotations, @@ -128,7 +128,7 @@ class BatDetect2API: def evaluate_predictions( self, annotations: Sequence[data.ClipAnnotation], - predictions: Sequence[BatDetect2Prediction], + predictions: Sequence[ClipDetections], output_dir: data.PathLike | None = None, ): clip_evals = self.evaluator.evaluate( @@ -170,24 +170,24 @@ class BatDetect2API: tensor = torch.tensor(audio).unsqueeze(0) return self.preprocessor(tensor) - def process_file(self, audio_file: str) -> BatDetect2Prediction: + def process_file(self, audio_file: str) -> ClipDetections: recording = data.Recording.from_file(audio_file, compute_hash=False) wav = self.audio_loader.load_recording(recording) detections = self.process_audio(wav) - return BatDetect2Prediction( + return ClipDetections( clip=data.Clip( uuid=recording.uuid, recording=recording, start_time=0, end_time=recording.duration, ), - predictions=detections, + detections=detections, ) def process_audio( self, audio: np.ndarray, - ) -> List[RawPrediction]: + ) -> List[Detection]: spec = self.generate_spectrogram(audio) return self.process_spectrogram(spec) @@ -195,7 +195,7 @@ class BatDetect2API: self, spec: torch.Tensor, start_time: float = 0, - ) -> List[RawPrediction]: + ) -> List[Detection]: if spec.ndim == 4 and spec.shape[0] > 1: raise ValueError("Batched spectrograms not supported.") @@ -214,7 +214,7 @@ class BatDetect2API: def process_directory( self, audio_dir: data.PathLike, - ) -> List[BatDetect2Prediction]: + ) -> List[ClipDetections]: files = list(get_audio_files(audio_dir)) return self.process_files(files) @@ -222,7 +222,7 @@ class BatDetect2API: self, audio_files: Sequence[data.PathLike], num_workers: int | None = None, - ) -> List[BatDetect2Prediction]: + ) -> List[ClipDetections]: return process_file_list( self.model, audio_files, @@ -238,7 +238,7 @@ class BatDetect2API: clips: Sequence[data.Clip], batch_size: int | None = None, num_workers: int | None = None, - ) -> List[BatDetect2Prediction]: + ) -> List[ClipDetections]: return run_batch_inference( self.model, clips, @@ -252,7 +252,7 @@ class BatDetect2API: def save_predictions( self, - predictions: Sequence[BatDetect2Prediction], + predictions: Sequence[ClipDetections], path: data.PathLike, audio_dir: data.PathLike | None = None, format: str | None = None, @@ -274,7 +274,7 @@ class BatDetect2API: def load_predictions( self, path: data.PathLike, - ) -> List[BatDetect2Prediction]: + ) -> List[ClipDetections]: return self.formatter.load(path) @classmethod diff --git a/src/batdetect2/audio/loader.py b/src/batdetect2/audio/loader.py index 0149853..e8c46a0 100644 --- a/src/batdetect2/audio/loader.py +++ b/src/batdetect2/audio/loader.py @@ -1,4 +1,3 @@ - import numpy as np from numpy.typing import DTypeLike from pydantic import Field diff --git a/src/batdetect2/core/arrays.py b/src/batdetect2/core/arrays.py index 4bcc5d5..8469b08 100644 --- a/src/batdetect2/core/arrays.py +++ b/src/batdetect2/core/arrays.py @@ -1,4 +1,3 @@ - import numpy as np import torch import xarray as xr diff --git a/src/batdetect2/data/conditions.py b/src/batdetect2/data/conditions.py index 6556d10..015ea2f 100644 --- a/src/batdetect2/data/conditions.py +++ b/src/batdetect2/data/conditions.py @@ -264,7 +264,14 @@ class Not: SoundEventConditionConfig = Annotated[ - HasTagConfig | HasAllTagsConfig | HasAnyTagConfig | DurationConfig | FrequencyConfig | AllOfConfig | AnyOfConfig | NotConfig, + HasTagConfig + | HasAllTagsConfig + | HasAnyTagConfig + | DurationConfig + | FrequencyConfig + | AllOfConfig + | AnyOfConfig + | NotConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/data/predictions/batdetect2.py b/src/batdetect2/data/predictions/batdetect2.py index 7dc58ad..6139181 100644 --- a/src/batdetect2/data/predictions/batdetect2.py +++ b/src/batdetect2/data/predictions/batdetect2.py @@ -13,9 +13,9 @@ from batdetect2.data.predictions.base import ( ) from batdetect2.targets import terms from batdetect2.typing import ( - BatDetect2Prediction, + ClipDetections, OutputFormatterProtocol, - RawPrediction, + Detection, TargetProtocol, ) @@ -113,7 +113,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): self.annotation_note = annotation_note def format( - self, predictions: Sequence[BatDetect2Prediction] + self, predictions: Sequence[ClipDetections] ) -> List[FileAnnotation]: return [ self.format_prediction(prediction) for prediction in predictions @@ -164,14 +164,12 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): highest_scoring = max(annotations, key=lambda x: x["class_prob"]) return highest_scoring["class"] - def format_prediction( - self, prediction: BatDetect2Prediction - ) -> FileAnnotation: + def format_prediction(self, prediction: ClipDetections) -> FileAnnotation: recording = prediction.clip.recording annotations = [ self.format_sound_event_prediction(pred) - for pred in prediction.predictions + for pred in prediction.detections ] return FileAnnotation( @@ -196,7 +194,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): ) # type: ignore def format_sound_event_prediction( - self, prediction: RawPrediction + self, prediction: Detection ) -> Annotation: start_time, low_freq, end_time, high_freq = compute_bounds( prediction.geometry diff --git a/src/batdetect2/data/predictions/parquet.py b/src/batdetect2/data/predictions/parquet.py index 52bdc71..d01d547 100644 --- a/src/batdetect2/data/predictions/parquet.py +++ b/src/batdetect2/data/predictions/parquet.py @@ -14,9 +14,9 @@ from batdetect2.data.predictions.base import ( prediction_formatters, ) from batdetect2.typing import ( - BatDetect2Prediction, + ClipDetections, OutputFormatterProtocol, - RawPrediction, + Detection, TargetProtocol, ) @@ -29,7 +29,7 @@ class ParquetOutputConfig(BaseConfig): include_geometry: bool = True -class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]): +class ParquetFormatter(OutputFormatterProtocol[ClipDetections]): def __init__( self, targets: TargetProtocol, @@ -44,13 +44,13 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]): def format( self, - predictions: Sequence[BatDetect2Prediction], - ) -> List[BatDetect2Prediction]: + predictions: Sequence[ClipDetections], + ) -> List[ClipDetections]: return list(predictions) def save( self, - predictions: Sequence[BatDetect2Prediction], + predictions: Sequence[ClipDetections], path: data.PathLike, audio_dir: data.PathLike | None = None, ) -> None: @@ -61,24 +61,26 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]): # Ensure the file has .parquet extension if it's a file path if path.suffix != ".parquet": - # If it's a directory, we might want to save as a partitioned dataset or a single file inside - # For now, let's assume the user provides a full file path or a directory where we save 'predictions.parquet' - if path.is_dir() or not path.suffix: - path = path / "predictions.parquet" + # If it's a directory, we might want to save as a partitioned dataset or a single file inside + # For now, let's assume the user provides a full file path or a directory where we save 'predictions.parquet' + if path.is_dir() or not path.suffix: + path = path / "predictions.parquet" rows = [] for prediction in predictions: clip = prediction.clip recording = clip.recording - + if audio_dir is not None: recording = recording.model_copy( - update=dict(path=make_path_relative(recording.path, audio_dir)) + update=dict( + path=make_path_relative(recording.path, audio_dir) + ) ) recording_json = recording.model_dump_json(exclude_none=True) - for pred in prediction.predictions: + for pred in prediction.detections: row = { "clip_uuid": str(clip.uuid), "clip_start_time": clip.start_time, @@ -96,18 +98,18 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]): row["low_freq"] = low_freq row["end_time"] = end_time row["high_freq"] = high_freq - + # Store full geometry as JSON row["geometry"] = pred.geometry.model_dump_json() if self.include_class_scores: row["class_scores"] = pred.class_scores.tolist() - + if self.include_features: row["features"] = pred.features.tolist() rows.append(row) - + if not rows: logger.warning("No predictions to save.") return @@ -116,16 +118,16 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]): logger.info(f"Saving {len(df)} predictions to {path}") df.to_parquet(path, index=False) - def load(self, path: data.PathLike) -> List[BatDetect2Prediction]: + def load(self, path: data.PathLike) -> List[ClipDetections]: path = Path(path) if path.is_dir(): - # Try to find parquet files - files = list(path.glob("*.parquet")) - if not files: - return [] - # Read all and concatenate - dfs = [pd.read_parquet(f) for f in files] - df = pd.concat(dfs, ignore_index=True) + # Try to find parquet files + files = list(path.glob("*.parquet")) + if not files: + return [] + # Read all and concatenate + dfs = [pd.read_parquet(f) for f in files] + df = pd.concat(dfs, ignore_index=True) else: df = pd.read_parquet(path) @@ -133,20 +135,19 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]): for _, row in df.iterrows(): clip_uuid = row["clip_uuid"] - + if clip_uuid not in predictions_by_clip: - recording = data.Recording.model_validate_json(row["recording_info"]) + recording = data.Recording.model_validate_json( + row["recording_info"] + ) clip = data.Clip( uuid=UUID(clip_uuid), recording=recording, start_time=row["clip_start_time"], end_time=row["clip_end_time"], ) - predictions_by_clip[clip_uuid] = { - "clip": clip, - "preds": [] - } - + predictions_by_clip[clip_uuid] = {"clip": clip, "preds": []} + # Reconstruct geometry if "geometry" in row and row["geometry"]: geometry = data.geometry_validate(row["geometry"]) @@ -156,14 +157,20 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]): row["start_time"], row["low_freq"], row["end_time"], - row["high_freq"] + row["high_freq"], ] ) - class_scores = np.array(row["class_scores"]) if "class_scores" in row else np.zeros(len(self.targets.class_names)) - features = np.array(row["features"]) if "features" in row else np.zeros(0) + class_scores = ( + np.array(row["class_scores"]) + if "class_scores" in row + else np.zeros(len(self.targets.class_names)) + ) + features = ( + np.array(row["features"]) if "features" in row else np.zeros(0) + ) - pred = RawPrediction( + pred = Detection( geometry=geometry, detection_score=row["detection_score"], class_scores=class_scores, @@ -174,12 +181,11 @@ class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]): results = [] for clip_data in predictions_by_clip.values(): results.append( - BatDetect2Prediction( - clip=clip_data["clip"], - predictions=clip_data["preds"] + ClipDetections( + clip=clip_data["clip"], detections=clip_data["preds"] ) ) - + return results @prediction_formatters.register(ParquetOutputConfig) diff --git a/src/batdetect2/data/predictions/raw.py b/src/batdetect2/data/predictions/raw.py index e509e84..60ece24 100644 --- a/src/batdetect2/data/predictions/raw.py +++ b/src/batdetect2/data/predictions/raw.py @@ -15,9 +15,9 @@ from batdetect2.data.predictions.base import ( prediction_formatters, ) from batdetect2.typing import ( - BatDetect2Prediction, + ClipDetections, OutputFormatterProtocol, - RawPrediction, + Detection, TargetProtocol, ) @@ -30,7 +30,7 @@ class RawOutputConfig(BaseConfig): include_geometry: bool = True -class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]): +class RawFormatter(OutputFormatterProtocol[ClipDetections]): def __init__( self, targets: TargetProtocol, @@ -47,13 +47,13 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]): def format( self, - predictions: Sequence[BatDetect2Prediction], - ) -> List[BatDetect2Prediction]: + predictions: Sequence[ClipDetections], + ) -> List[ClipDetections]: return list(predictions) def save( self, - predictions: Sequence[BatDetect2Prediction], + predictions: Sequence[ClipDetections], path: data.PathLike, audio_dir: data.PathLike | None = None, ) -> None: @@ -68,10 +68,10 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]): dataset = self.pred_to_xr(prediction, audio_dir) dataset.to_netcdf(path / f"{clip.uuid}.nc") - def load(self, path: data.PathLike) -> List[BatDetect2Prediction]: + def load(self, path: data.PathLike) -> List[ClipDetections]: path = Path(path) files = list(path.glob("*.nc")) - predictions: List[BatDetect2Prediction] = [] + predictions: List[ClipDetections] = [] for filepath in files: logger.debug(f"Loading clip predictions {filepath}") @@ -83,7 +83,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]): def pred_to_xr( self, - prediction: BatDetect2Prediction, + prediction: ClipDetections, audio_dir: data.PathLike | None = None, ) -> xr.Dataset: clip = prediction.clip @@ -97,7 +97,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]): data = defaultdict(list) - for pred in prediction.predictions: + for pred in prediction.detections: detection_id = str(uuid4()) data["detection_id"].append(detection_id) @@ -167,7 +167,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]): }, ) - def pred_from_xr(self, dataset: xr.Dataset) -> BatDetect2Prediction: + def pred_from_xr(self, dataset: xr.Dataset) -> ClipDetections: clip_data = dataset clip_id = clip_data.clip_id.item() @@ -219,7 +219,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]): features = np.zeros(0) sound_events.append( - RawPrediction( + Detection( geometry=geometry, detection_score=score, class_scores=class_scores, @@ -227,9 +227,9 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]): ) ) - return BatDetect2Prediction( + return ClipDetections( clip=clip, - predictions=sound_events, + detections=sound_events, ) @prediction_formatters.register(RawOutputConfig) diff --git a/src/batdetect2/data/predictions/soundevent.py b/src/batdetect2/data/predictions/soundevent.py index e851685..d42adfe 100644 --- a/src/batdetect2/data/predictions/soundevent.py +++ b/src/batdetect2/data/predictions/soundevent.py @@ -9,9 +9,9 @@ from batdetect2.data.predictions.base import ( prediction_formatters, ) from batdetect2.typing import ( - BatDetect2Prediction, + ClipDetections, OutputFormatterProtocol, - RawPrediction, + Detection, TargetProtocol, ) @@ -35,7 +35,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]): def format( self, - predictions: Sequence[BatDetect2Prediction], + predictions: Sequence[ClipDetections], ) -> List[data.ClipPrediction]: return [ self.format_prediction(prediction) for prediction in predictions @@ -63,20 +63,20 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]): def format_prediction( self, - prediction: BatDetect2Prediction, + prediction: ClipDetections, ) -> data.ClipPrediction: recording = prediction.clip.recording return data.ClipPrediction( clip=prediction.clip, sound_events=[ self.format_sound_event_prediction(pred, recording) - for pred in prediction.predictions + for pred in prediction.detections ], ) def format_sound_event_prediction( self, - prediction: RawPrediction, + prediction: Detection, recording: data.Recording, ) -> data.SoundEventPrediction: return data.SoundEventPrediction( @@ -89,7 +89,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]): ) def get_sound_event_tags( - self, prediction: RawPrediction + self, prediction: Detection ) -> List[data.PredictedTag]: sorted_indices = np.argsort(prediction.class_scores)[::-1] diff --git a/src/batdetect2/data/transforms.py b/src/batdetect2/data/transforms.py index a19250a..c8c99cd 100644 --- a/src/batdetect2/data/transforms.py +++ b/src/batdetect2/data/transforms.py @@ -221,7 +221,11 @@ class ApplyAll: SoundEventTransformConfig = Annotated[ - SetFrequencyBoundConfig | ReplaceTagConfig | MapTagValueConfig | ApplyIfConfig | ApplyAllConfig, + SetFrequencyBoundConfig + | ReplaceTagConfig + | MapTagValueConfig + | ApplyIfConfig + | ApplyAllConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/affinity.py b/src/batdetect2/evaluate/affinity.py index 8f5084e..976cb89 100644 --- a/src/batdetect2/evaluate/affinity.py +++ b/src/batdetect2/evaluate/affinity.py @@ -11,7 +11,7 @@ from soundevent.geometry import ( ) from batdetect2.core import BaseConfig, Registry -from batdetect2.typing import AffinityFunction, RawPrediction +from batdetect2.typing import AffinityFunction, Detection affinity_functions: Registry[AffinityFunction, []] = Registry( "affinity_function" @@ -42,7 +42,7 @@ class TimeAffinity(AffinityFunction): def __call__( self, - detection: RawPrediction, + detection: Detection, ground_truth: data.SoundEventAnnotation, ) -> float: target_geometry = ground_truth.sound_event.geometry @@ -77,7 +77,7 @@ class IntervalIOU(AffinityFunction): def __call__( self, - detection: RawPrediction, + detection: Detection, ground_truth: data.SoundEventAnnotation, ) -> float: target_geometry = ground_truth.sound_event.geometry @@ -120,7 +120,7 @@ class BBoxIOU(AffinityFunction): def __call__( self, - prediction: RawPrediction, + prediction: Detection, gt: data.SoundEventAnnotation, ): target_geometry = gt.sound_event.geometry @@ -168,7 +168,7 @@ class GeometricIOU(AffinityFunction): def __call__( self, - prediction: RawPrediction, + prediction: Detection, gt: data.SoundEventAnnotation, ): target_geometry = gt.sound_event.geometry diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index 3a400f6..e69bb2b 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -12,7 +12,7 @@ from batdetect2.logging import build_logger from batdetect2.models import Model from batdetect2.preprocess import build_preprocessor from batdetect2.targets import build_targets -from batdetect2.typing.postprocess import RawPrediction +from batdetect2.typing import Detection if TYPE_CHECKING: from batdetect2.config import BatDetect2Config @@ -38,7 +38,7 @@ def evaluate( output_dir: data.PathLike = DEFAULT_EVAL_DIR, experiment_name: str | None = None, run_name: str | None = None, -) -> Tuple[Dict[str, float], List[List[RawPrediction]]]: +) -> Tuple[Dict[str, float], List[List[Detection]]]: from batdetect2.config import BatDetect2Config config = config or BatDetect2Config() diff --git a/src/batdetect2/evaluate/evaluator.py b/src/batdetect2/evaluate/evaluator.py index 6c5009c..39e960f 100644 --- a/src/batdetect2/evaluate/evaluator.py +++ b/src/batdetect2/evaluate/evaluator.py @@ -7,7 +7,7 @@ from batdetect2.evaluate.config import EvaluationConfig from batdetect2.evaluate.tasks import build_task from batdetect2.targets import build_targets from batdetect2.typing import EvaluatorProtocol, TargetProtocol -from batdetect2.typing.postprocess import BatDetect2Prediction +from batdetect2.typing.postprocess import ClipDetections __all__ = [ "Evaluator", @@ -27,7 +27,7 @@ class Evaluator: def evaluate( self, clip_annotations: Sequence[data.ClipAnnotation], - predictions: Sequence[BatDetect2Prediction], + predictions: Sequence[ClipDetections], ) -> List[Any]: return [ task.evaluate(clip_annotations, predictions) for task in self.tasks diff --git a/src/batdetect2/evaluate/lightning.py b/src/batdetect2/evaluate/lightning.py index 231e65d..f8d51af 100644 --- a/src/batdetect2/evaluate/lightning.py +++ b/src/batdetect2/evaluate/lightning.py @@ -9,7 +9,7 @@ from batdetect2.logging import get_image_logger from batdetect2.models import Model from batdetect2.postprocess import to_raw_predictions from batdetect2.typing import EvaluatorProtocol -from batdetect2.typing.postprocess import BatDetect2Prediction +from batdetect2.typing.postprocess import ClipDetections class EvaluationModule(LightningModule): @@ -24,7 +24,7 @@ class EvaluationModule(LightningModule): self.evaluator = evaluator self.clip_annotations: List[data.ClipAnnotation] = [] - self.predictions: List[BatDetect2Prediction] = [] + self.predictions: List[ClipDetections] = [] def test_step(self, batch: TestExample, batch_idx: int): dataset = self.get_dataset() @@ -39,9 +39,9 @@ class EvaluationModule(LightningModule): start_times=[ca.clip.start_time for ca in clip_annotations], ) predictions = [ - BatDetect2Prediction( + ClipDetections( clip=clip_annotation.clip, - predictions=to_raw_predictions( + detections=to_raw_predictions( clip_dets.numpy(), targets=self.evaluator.targets, ), diff --git a/src/batdetect2/evaluate/metrics/classification.py b/src/batdetect2/evaluate/metrics/classification.py index 13c7779..97b4134 100644 --- a/src/batdetect2/evaluate/metrics/classification.py +++ b/src/batdetect2/evaluate/metrics/classification.py @@ -21,7 +21,7 @@ from batdetect2.evaluate.metrics.common import ( average_precision, compute_precision_recall, ) -from batdetect2.typing import RawPrediction, TargetProtocol +from batdetect2.typing import Detection, TargetProtocol __all__ = [ "ClassificationMetric", @@ -35,7 +35,7 @@ __all__ = [ class MatchEval: clip: data.Clip gt: data.SoundEventAnnotation | None - pred: RawPrediction | None + pred: Detection | None is_prediction: bool is_ground_truth: bool diff --git a/src/batdetect2/evaluate/metrics/clip_detection.py b/src/batdetect2/evaluate/metrics/clip_detection.py index 1bab7b3..7613228 100644 --- a/src/batdetect2/evaluate/metrics/clip_detection.py +++ b/src/batdetect2/evaluate/metrics/clip_detection.py @@ -159,7 +159,10 @@ class ClipDetectionPrecision: ClipDetectionMetricConfig = Annotated[ - ClipDetectionAveragePrecisionConfig | ClipDetectionROCAUCConfig | ClipDetectionRecallConfig | ClipDetectionPrecisionConfig, + ClipDetectionAveragePrecisionConfig + | ClipDetectionROCAUCConfig + | ClipDetectionRecallConfig + | ClipDetectionPrecisionConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/metrics/detection.py b/src/batdetect2/evaluate/metrics/detection.py index 2aba6fa..d2e5a15 100644 --- a/src/batdetect2/evaluate/metrics/detection.py +++ b/src/batdetect2/evaluate/metrics/detection.py @@ -15,7 +15,7 @@ from soundevent import data from batdetect2.core import BaseConfig, Registry from batdetect2.evaluate.metrics.common import average_precision -from batdetect2.typing import RawPrediction +from batdetect2.typing import Detection __all__ = [ "DetectionMetricConfig", @@ -27,7 +27,7 @@ __all__ = [ @dataclass class MatchEval: gt: data.SoundEventAnnotation | None - pred: RawPrediction | None + pred: Detection | None is_prediction: bool is_ground_truth: bool diff --git a/src/batdetect2/evaluate/metrics/top_class.py b/src/batdetect2/evaluate/metrics/top_class.py index c4d6514..e1f00d1 100644 --- a/src/batdetect2/evaluate/metrics/top_class.py +++ b/src/batdetect2/evaluate/metrics/top_class.py @@ -15,7 +15,7 @@ from soundevent import data from batdetect2.core import BaseConfig, Registry from batdetect2.evaluate.metrics.common import average_precision -from batdetect2.typing import RawPrediction +from batdetect2.typing import Detection from batdetect2.typing.targets import TargetProtocol __all__ = [ @@ -29,7 +29,7 @@ __all__ = [ class MatchEval: clip: data.Clip gt: data.SoundEventAnnotation | None - pred: RawPrediction | None + pred: Detection | None is_ground_truth: bool is_generic: bool @@ -298,7 +298,11 @@ class BalancedAccuracy: TopClassMetricConfig = Annotated[ - TopClassAveragePrecisionConfig | TopClassROCAUCConfig | TopClassRecallConfig | TopClassPrecisionConfig | BalancedAccuracyConfig, + TopClassAveragePrecisionConfig + | TopClassROCAUCConfig + | TopClassRecallConfig + | TopClassPrecisionConfig + | BalancedAccuracyConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/plots/base.py b/src/batdetect2/evaluate/plots/base.py index 10c74e0..beceb42 100644 --- a/src/batdetect2/evaluate/plots/base.py +++ b/src/batdetect2/evaluate/plots/base.py @@ -1,4 +1,3 @@ - import matplotlib.pyplot as plt from matplotlib.figure import Figure diff --git a/src/batdetect2/evaluate/plots/classification.py b/src/batdetect2/evaluate/plots/classification.py index a84d08c..881ed93 100644 --- a/src/batdetect2/evaluate/plots/classification.py +++ b/src/batdetect2/evaluate/plots/classification.py @@ -322,7 +322,10 @@ class ROCCurve(BasePlot): ClassificationPlotConfig = Annotated[ - PRCurveConfig | ROCCurveConfig | ThresholdPrecisionCurveConfig | ThresholdRecallCurveConfig, + PRCurveConfig + | ROCCurveConfig + | ThresholdPrecisionCurveConfig + | ThresholdRecallCurveConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/plots/detection.py b/src/batdetect2/evaluate/plots/detection.py index 22eb252..de0e309 100644 --- a/src/batdetect2/evaluate/plots/detection.py +++ b/src/batdetect2/evaluate/plots/detection.py @@ -290,7 +290,10 @@ class ExampleDetectionPlot(BasePlot): DetectionPlotConfig = Annotated[ - PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig | ExampleDetectionPlotConfig, + PRCurveConfig + | ROCCurveConfig + | ScoreDistributionPlotConfig + | ExampleDetectionPlotConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/evaluate/plots/top_class.py b/src/batdetect2/evaluate/plots/top_class.py index 2ca9bf9..5678b0f 100644 --- a/src/batdetect2/evaluate/plots/top_class.py +++ b/src/batdetect2/evaluate/plots/top_class.py @@ -346,7 +346,10 @@ class ExampleClassificationPlot(BasePlot): TopClassPlotConfig = Annotated[ - PRCurveConfig | ROCCurveConfig | ConfusionMatrixConfig | ExampleClassificationPlotConfig, + PRCurveConfig + | ROCCurveConfig + | ConfusionMatrixConfig + | ExampleClassificationPlotConfig, Field(discriminator="name"), ] @@ -403,7 +406,8 @@ def get_binned_sample(matches: List[MatchEval], n_examples: int = 5): return matches indices, pred_scores = zip( - *[(index, match.score) for index, match in enumerate(matches)], strict=False + *[(index, match.score) for index, match in enumerate(matches)], + strict=False, ) bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop") diff --git a/src/batdetect2/evaluate/tasks/__init__.py b/src/batdetect2/evaluate/tasks/__init__.py index 3c2e74e..625524e 100644 --- a/src/batdetect2/evaluate/tasks/__init__.py +++ b/src/batdetect2/evaluate/tasks/__init__.py @@ -13,7 +13,7 @@ from batdetect2.evaluate.tasks.detection import DetectionTaskConfig from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig from batdetect2.targets import build_targets from batdetect2.typing import ( - BatDetect2Prediction, + ClipDetections, EvaluatorProtocol, TargetProtocol, ) @@ -45,7 +45,7 @@ def build_task( def evaluate_task( clip_annotations: Sequence[data.ClipAnnotation], - predictions: Sequence[BatDetect2Prediction], + predictions: Sequence[ClipDetections], task: Optional["str"] = None, targets: TargetProtocol | None = None, config: TaskConfig | dict | None = None, diff --git a/src/batdetect2/evaluate/tasks/base.py b/src/batdetect2/evaluate/tasks/base.py index 9523ff5..975bdb8 100644 --- a/src/batdetect2/evaluate/tasks/base.py +++ b/src/batdetect2/evaluate/tasks/base.py @@ -22,9 +22,9 @@ from batdetect2.evaluate.affinity import ( ) from batdetect2.typing import ( AffinityFunction, - BatDetect2Prediction, + ClipDetections, EvaluatorProtocol, - RawPrediction, + Detection, TargetProtocol, ) @@ -96,7 +96,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): def evaluate( self, clip_annotations: Sequence[data.ClipAnnotation], - predictions: Sequence[BatDetect2Prediction], + predictions: Sequence[ClipDetections], ) -> List[T_Output]: return [ self.evaluate_clip(clip_annotation, preds) @@ -108,7 +108,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): def evaluate_clip( self, clip_annotation: data.ClipAnnotation, - prediction: BatDetect2Prediction, + prediction: ClipDetections, ) -> T_Output: ... def include_sound_event_annotation( @@ -128,7 +128,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): def include_prediction( self, - prediction: RawPrediction, + prediction: Detection, clip: data.Clip, ) -> bool: return is_in_bounds( diff --git a/src/batdetect2/evaluate/tasks/classification.py b/src/batdetect2/evaluate/tasks/classification.py index ca7185b..1da934d 100644 --- a/src/batdetect2/evaluate/tasks/classification.py +++ b/src/batdetect2/evaluate/tasks/classification.py @@ -22,8 +22,8 @@ from batdetect2.evaluate.tasks.base import ( tasks_registry, ) from batdetect2.typing import ( - BatDetect2Prediction, - RawPrediction, + ClipDetections, + Detection, TargetProtocol, ) @@ -51,13 +51,13 @@ class ClassificationTask(BaseSEDTask[ClipEval]): def evaluate_clip( self, clip_annotation: data.ClipAnnotation, - prediction: BatDetect2Prediction, + prediction: ClipDetections, ) -> ClipEval: clip = clip_annotation.clip preds = [ pred - for pred in prediction.predictions + for pred in prediction.detections if self.include_prediction(pred, clip) ] @@ -140,7 +140,7 @@ class ClassificationTask(BaseSEDTask[ClipEval]): ) -def get_class_score(pred: RawPrediction, class_idx: int) -> float: +def get_class_score(pred: Detection, class_idx: int) -> float: return pred.class_scores[class_idx] diff --git a/src/batdetect2/evaluate/tasks/clip_classification.py b/src/batdetect2/evaluate/tasks/clip_classification.py index 67b2c71..32e2383 100644 --- a/src/batdetect2/evaluate/tasks/clip_classification.py +++ b/src/batdetect2/evaluate/tasks/clip_classification.py @@ -19,7 +19,7 @@ from batdetect2.evaluate.tasks.base import ( BaseTaskConfig, tasks_registry, ) -from batdetect2.typing import BatDetect2Prediction, TargetProtocol +from batdetect2.typing import ClipDetections, TargetProtocol class ClipClassificationTaskConfig(BaseTaskConfig): @@ -37,7 +37,7 @@ class ClipClassificationTask(BaseTask[ClipEval]): def evaluate_clip( self, clip_annotation: data.ClipAnnotation, - prediction: BatDetect2Prediction, + prediction: ClipDetections, ) -> ClipEval: clip = clip_annotation.clip @@ -54,7 +54,7 @@ class ClipClassificationTask(BaseTask[ClipEval]): gt_classes.add(class_name) pred_scores = defaultdict(float) - for pred in prediction.predictions: + for pred in prediction.detections: if not self.include_prediction(pred, clip): continue diff --git a/src/batdetect2/evaluate/tasks/clip_detection.py b/src/batdetect2/evaluate/tasks/clip_detection.py index ddb9d1a..b19efa0 100644 --- a/src/batdetect2/evaluate/tasks/clip_detection.py +++ b/src/batdetect2/evaluate/tasks/clip_detection.py @@ -18,7 +18,7 @@ from batdetect2.evaluate.tasks.base import ( BaseTaskConfig, tasks_registry, ) -from batdetect2.typing import BatDetect2Prediction, TargetProtocol +from batdetect2.typing import ClipDetections, TargetProtocol class ClipDetectionTaskConfig(BaseTaskConfig): @@ -36,7 +36,7 @@ class ClipDetectionTask(BaseTask[ClipEval]): def evaluate_clip( self, clip_annotation: data.ClipAnnotation, - prediction: BatDetect2Prediction, + prediction: ClipDetections, ) -> ClipEval: clip = clip_annotation.clip @@ -46,7 +46,7 @@ class ClipDetectionTask(BaseTask[ClipEval]): ) pred_score = 0 - for pred in prediction.predictions: + for pred in prediction.detections: if not self.include_prediction(pred, clip): continue diff --git a/src/batdetect2/evaluate/tasks/detection.py b/src/batdetect2/evaluate/tasks/detection.py index 9ac9535..2c4e83f 100644 --- a/src/batdetect2/evaluate/tasks/detection.py +++ b/src/batdetect2/evaluate/tasks/detection.py @@ -21,7 +21,7 @@ from batdetect2.evaluate.tasks.base import ( tasks_registry, ) from batdetect2.typing import TargetProtocol -from batdetect2.typing.postprocess import BatDetect2Prediction +from batdetect2.typing.postprocess import ClipDetections class DetectionTaskConfig(BaseSEDTaskConfig): @@ -37,7 +37,7 @@ class DetectionTask(BaseSEDTask[ClipEval]): def evaluate_clip( self, clip_annotation: data.ClipAnnotation, - prediction: BatDetect2Prediction, + prediction: ClipDetections, ) -> ClipEval: clip = clip_annotation.clip @@ -48,7 +48,7 @@ class DetectionTask(BaseSEDTask[ClipEval]): ] preds = [ pred - for pred in prediction.predictions + for pred in prediction.detections if self.include_prediction(pred, clip) ] diff --git a/src/batdetect2/evaluate/tasks/top_class.py b/src/batdetect2/evaluate/tasks/top_class.py index 00b1a73..0745891 100644 --- a/src/batdetect2/evaluate/tasks/top_class.py +++ b/src/batdetect2/evaluate/tasks/top_class.py @@ -20,7 +20,7 @@ from batdetect2.evaluate.tasks.base import ( BaseSEDTaskConfig, tasks_registry, ) -from batdetect2.typing import BatDetect2Prediction, TargetProtocol +from batdetect2.typing import ClipDetections, TargetProtocol class TopClassDetectionTaskConfig(BaseSEDTaskConfig): @@ -36,7 +36,7 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]): def evaluate_clip( self, clip_annotation: data.ClipAnnotation, - prediction: BatDetect2Prediction, + prediction: ClipDetections, ) -> ClipEval: clip = clip_annotation.clip @@ -47,7 +47,7 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]): ] preds = [ pred - for pred in prediction.predictions + for pred in prediction.detections if self.include_prediction(pred, clip) ] diff --git a/src/batdetect2/finetune/finetune_model.py b/src/batdetect2/finetune/finetune_model.py index 315b02a..5e4db80 100644 --- a/src/batdetect2/finetune/finetune_model.py +++ b/src/batdetect2/finetune/finetune_model.py @@ -88,7 +88,8 @@ def select_device(warn=True) -> str: if warn: warnings.warn( "No GPU available, using the CPU instead. Please consider using a GPU " - "to speed up training.", stacklevel=2 + "to speed up training.", + stacklevel=2, ) return "cpu" diff --git a/src/batdetect2/inference/batch.py b/src/batdetect2/inference/batch.py index b97848d..a32549f 100644 --- a/src/batdetect2/inference/batch.py +++ b/src/batdetect2/inference/batch.py @@ -10,7 +10,7 @@ from batdetect2.inference.lightning import InferenceModule from batdetect2.models import Model from batdetect2.preprocess.preprocessor import build_preprocessor from batdetect2.targets.targets import build_targets -from batdetect2.typing.postprocess import BatDetect2Prediction +from batdetect2.typing.postprocess import ClipDetections if TYPE_CHECKING: from batdetect2.config import BatDetect2Config @@ -30,7 +30,7 @@ def run_batch_inference( config: Optional["BatDetect2Config"] = None, num_workers: int | None = None, batch_size: int | None = None, -) -> List[BatDetect2Prediction]: +) -> List[ClipDetections]: from batdetect2.config import BatDetect2Config config = config or BatDetect2Config() @@ -70,7 +70,7 @@ def process_file_list( audio_loader: Optional["AudioLoader"] = None, preprocessor: Optional["PreprocessorProtocol"] = None, num_workers: int | None = None, -) -> List[BatDetect2Prediction]: +) -> List[ClipDetections]: clip_config = config.inference.clipping clips = get_clips_from_files( paths, diff --git a/src/batdetect2/inference/lightning.py b/src/batdetect2/inference/lightning.py index e36224b..1cd07bd 100644 --- a/src/batdetect2/inference/lightning.py +++ b/src/batdetect2/inference/lightning.py @@ -6,7 +6,7 @@ from torch.utils.data import DataLoader from batdetect2.inference.dataset import DatasetItem, InferenceDataset from batdetect2.models import Model from batdetect2.postprocess import to_raw_predictions -from batdetect2.typing.postprocess import BatDetect2Prediction +from batdetect2.typing.postprocess import ClipDetections class InferenceModule(LightningModule): @@ -19,7 +19,7 @@ class InferenceModule(LightningModule): batch: DatasetItem, batch_idx: int, dataloader_idx: int = 0, - ) -> Sequence[BatDetect2Prediction]: + ) -> Sequence[ClipDetections]: dataset = self.get_dataset() clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx] @@ -32,9 +32,9 @@ class InferenceModule(LightningModule): ) predictions = [ - BatDetect2Prediction( + ClipDetections( clip=clip, - predictions=to_raw_predictions( + detections=to_raw_predictions( clip_dets.numpy(), targets=self.model.targets, ), diff --git a/src/batdetect2/logging.py b/src/batdetect2/logging.py index aadd92c..62dcedd 100644 --- a/src/batdetect2/logging.py +++ b/src/batdetect2/logging.py @@ -76,7 +76,10 @@ class MLFlowLoggerConfig(BaseLoggerConfig): LoggerConfig = Annotated[ - DVCLiveConfig | CSVLoggerConfig | TensorBoardLoggerConfig | MLFlowLoggerConfig, + DVCLiveConfig + | CSVLoggerConfig + | TensorBoardLoggerConfig + | MLFlowLoggerConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/models/config.py b/src/batdetect2/models/config.py index 84ebef1..999b777 100644 --- a/src/batdetect2/models/config.py +++ b/src/batdetect2/models/config.py @@ -1,4 +1,3 @@ - from soundevent import data from batdetect2.core.configs import BaseConfig, load_config diff --git a/src/batdetect2/models/decoder.py b/src/batdetect2/models/decoder.py index a79524b..e7c3e91 100644 --- a/src/batdetect2/models/decoder.py +++ b/src/batdetect2/models/decoder.py @@ -41,7 +41,10 @@ __all__ = [ ] DecoderLayerConfig = Annotated[ - ConvConfig | FreqCoordConvUpConfig | StandardConvUpConfig | LayerGroupConfig, + ConvConfig + | FreqCoordConvUpConfig + | StandardConvUpConfig + | LayerGroupConfig, Field(discriminator="name"), ] """Type alias for the discriminated union of block configs usable in Decoder.""" diff --git a/src/batdetect2/models/detectors.py b/src/batdetect2/models/detectors.py index 56c4ea9..37eeec6 100644 --- a/src/batdetect2/models/detectors.py +++ b/src/batdetect2/models/detectors.py @@ -14,7 +14,6 @@ logic for preprocessing inputs and postprocessing/decoding outputs resides in the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively. """ - import torch from loguru import logger diff --git a/src/batdetect2/models/encoder.py b/src/batdetect2/models/encoder.py index 4081260..a302992 100644 --- a/src/batdetect2/models/encoder.py +++ b/src/batdetect2/models/encoder.py @@ -43,7 +43,10 @@ __all__ = [ ] EncoderLayerConfig = Annotated[ - ConvConfig | FreqCoordConvDownConfig | StandardConvDownConfig | LayerGroupConfig, + ConvConfig + | FreqCoordConvDownConfig + | StandardConvDownConfig + | LayerGroupConfig, Field(discriminator="name"), ] """Type alias for the discriminated union of block configs usable in Encoder.""" diff --git a/src/batdetect2/plotting/detections.py b/src/batdetect2/plotting/detections.py index d338483..797da6d 100644 --- a/src/batdetect2/plotting/detections.py +++ b/src/batdetect2/plotting/detections.py @@ -1,4 +1,3 @@ - from matplotlib import axes, patches from soundevent.plot import plot_geometry diff --git a/src/batdetect2/plotting/gallery.py b/src/batdetect2/plotting/gallery.py index a6fc906..4cc9eeb 100644 --- a/src/batdetect2/plotting/gallery.py +++ b/src/batdetect2/plotting/gallery.py @@ -36,7 +36,9 @@ def plot_match_gallery( sharey="row", ) - for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples], strict=False): + for tp_ax, tp_match in zip( + axes[0], true_positives[:n_examples], strict=False + ): try: plot_true_positive_match( tp_match, @@ -53,7 +55,9 @@ def plot_match_gallery( ): continue - for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples], strict=False): + for fp_ax, fp_match in zip( + axes[1], false_positives[:n_examples], strict=False + ): try: plot_false_positive_match( fp_match, @@ -70,7 +74,9 @@ def plot_match_gallery( ): continue - for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples], strict=False): + for fn_ax, fn_match in zip( + axes[2], false_negatives[:n_examples], strict=False + ): try: plot_false_negative_match( fn_match, @@ -87,7 +93,9 @@ def plot_match_gallery( ): continue - for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples], strict=False): + for ct_ax, ct_match in zip( + axes[3], cross_triggers[:n_examples], strict=False + ): try: plot_cross_trigger_match( ct_match, diff --git a/src/batdetect2/plotting/matches.py b/src/batdetect2/plotting/matches.py index c2733d9..266ff36 100644 --- a/src/batdetect2/plotting/matches.py +++ b/src/batdetect2/plotting/matches.py @@ -8,7 +8,7 @@ from batdetect2.plotting.clips import plot_clip from batdetect2.typing import ( AudioLoader, PreprocessorProtocol, - RawPrediction, + Detection, ) __all__ = [ @@ -22,7 +22,7 @@ __all__ = [ class MatchProtocol(Protocol): clip: data.Clip gt: data.SoundEventAnnotation | None - pred: RawPrediction | None + pred: Detection | None score: float true_class: str | None @@ -341,4 +341,3 @@ def plot_cross_trigger_match( ax.set_title("Cross Trigger") return ax - diff --git a/src/batdetect2/postprocess/clips.py b/src/batdetect2/postprocess/clips.py new file mode 100644 index 0000000..5b46bbd --- /dev/null +++ b/src/batdetect2/postprocess/clips.py @@ -0,0 +1,6 @@ +from batdetect2.typing import ClipDetections + + +class ClipTransform: + def __init__(self, clip: ClipDetections): + pass diff --git a/src/batdetect2/postprocess/config.py b/src/batdetect2/postprocess/config.py index 4046d59..7f7e297 100644 --- a/src/batdetect2/postprocess/config.py +++ b/src/batdetect2/postprocess/config.py @@ -1,4 +1,3 @@ - from pydantic import Field from soundevent import data diff --git a/src/batdetect2/postprocess/decoding.py b/src/batdetect2/postprocess/decoding.py index 2974771..d779280 100644 --- a/src/batdetect2/postprocess/decoding.py +++ b/src/batdetect2/postprocess/decoding.py @@ -7,7 +7,7 @@ from soundevent import data from batdetect2.typing.postprocess import ( ClipDetectionsArray, - RawPrediction, + Detection, ) from batdetect2.typing.targets import TargetProtocol @@ -30,7 +30,7 @@ decoding. def to_raw_predictions( detections: ClipDetectionsArray, targets: TargetProtocol, -) -> List[RawPrediction]: +) -> List[Detection]: predictions = [] for score, class_scores, time, freq, dims, feats in zip( @@ -39,7 +39,8 @@ def to_raw_predictions( detections.times, detections.frequencies, detections.sizes, - detections.features, strict=False, + detections.features, + strict=False, ): highest_scoring_class = targets.class_names[class_scores.argmax()] @@ -50,7 +51,7 @@ def to_raw_predictions( ) predictions.append( - RawPrediction( + Detection( detection_score=score, geometry=geom, class_scores=class_scores, @@ -62,7 +63,7 @@ def to_raw_predictions( def convert_raw_predictions_to_clip_prediction( - raw_predictions: List[RawPrediction], + raw_predictions: List[Detection], clip: data.Clip, targets: TargetProtocol, classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD, @@ -85,7 +86,7 @@ def convert_raw_predictions_to_clip_prediction( def convert_raw_prediction_to_sound_event_prediction( - raw_prediction: RawPrediction, + raw_prediction: Detection, recording: data.Recording, targets: TargetProtocol, classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD, diff --git a/src/batdetect2/preprocess/preprocessor.py b/src/batdetect2/preprocess/preprocessor.py index 47569a8..f5f351b 100644 --- a/src/batdetect2/preprocess/preprocessor.py +++ b/src/batdetect2/preprocess/preprocessor.py @@ -1,4 +1,3 @@ - import torch from loguru import logger diff --git a/src/batdetect2/preprocess/spectrogram.py b/src/batdetect2/preprocess/spectrogram.py index 926055b..d248fa1 100644 --- a/src/batdetect2/preprocess/spectrogram.py +++ b/src/batdetect2/preprocess/spectrogram.py @@ -377,7 +377,10 @@ class PeakNormalize(torch.nn.Module): SpectrogramTransform = Annotated[ - PcenConfig | ScaleAmplitudeConfig | SpectralMeanSubstractionConfig | PeakNormalizeConfig, + PcenConfig + | ScaleAmplitudeConfig + | SpectralMeanSubstractionConfig + | PeakNormalizeConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/train/augmentations.py b/src/batdetect2/train/augmentations.py index 8deff36..ba2e9ca 100644 --- a/src/batdetect2/train/augmentations.py +++ b/src/batdetect2/train/augmentations.py @@ -394,7 +394,8 @@ class MaskTime(torch.nn.Module): size=num_masks, ) masks = [ - (start, start + size) for start, size in zip(mask_start, mask_size, strict=False) + (start, start + size) + for start, size in zip(mask_start, mask_size, strict=False) ] return mask_time(spec, masks), clip_annotation @@ -460,7 +461,8 @@ class MaskFrequency(torch.nn.Module): size=num_masks, ) masks = [ - (start, start + size) for start, size in zip(mask_start, mask_size, strict=False) + (start, start + size) + for start, size in zip(mask_start, mask_size, strict=False) ] return mask_frequency(spec, masks), clip_annotation @@ -498,7 +500,12 @@ SpectrogramAugmentationConfig = Annotated[ ] AugmentationConfig = Annotated[ - MixAudioConfig | AddEchoConfig | ScaleVolumeConfig | WarpConfig | MaskFrequencyConfig | MaskTimeConfig, + MixAudioConfig + | AddEchoConfig + | ScaleVolumeConfig + | WarpConfig + | MaskFrequencyConfig + | MaskTimeConfig, Field(discriminator="name"), ] """Type alias for the discriminated union of individual augmentation config.""" diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index 2be538f..51c1c59 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -10,7 +10,7 @@ from batdetect2.postprocess import to_raw_predictions from batdetect2.train.dataset import ValidationDataset from batdetect2.train.lightning import TrainingModule from batdetect2.typing import ( - BatDetect2Prediction, + ClipDetections, EvaluatorProtocol, ModelOutput, TrainExample, @@ -24,7 +24,7 @@ class ValidationMetrics(Callback): self.evaluator = evaluator self._clip_annotations: List[data.ClipAnnotation] = [] - self._predictions: List[BatDetect2Prediction] = [] + self._predictions: List[ClipDetections] = [] def get_dataset(self, trainer: Trainer) -> ValidationDataset: dataloaders = trainer.val_dataloaders @@ -100,9 +100,9 @@ class ValidationMetrics(Callback): start_times=[ca.clip.start_time for ca in clip_annotations], ) predictions = [ - BatDetect2Prediction( + ClipDetections( clip=clip_annotation.clip, - predictions=to_raw_predictions( + detections=to_raw_predictions( clip_dets.numpy(), targets=model.targets ), ) diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py index 0fc1797..083a63e 100644 --- a/src/batdetect2/train/config.py +++ b/src/batdetect2/train/config.py @@ -1,4 +1,3 @@ - from pydantic import Field from soundevent import data diff --git a/src/batdetect2/train/losses.py b/src/batdetect2/train/losses.py index 7b36f17..119a654 100644 --- a/src/batdetect2/train/losses.py +++ b/src/batdetect2/train/losses.py @@ -18,7 +18,6 @@ The primary entry points are: - `LossConfig`: The Pydantic model for configuring loss weights and parameters. """ - import numpy as np import torch import torch.nn.functional as F diff --git a/src/batdetect2/typing/__init__.py b/src/batdetect2/typing/__init__.py index 6269528..bd24fae 100644 --- a/src/batdetect2/typing/__init__.py +++ b/src/batdetect2/typing/__init__.py @@ -10,11 +10,11 @@ from batdetect2.typing.evaluate import ( ) from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput from batdetect2.typing.postprocess import ( - BatDetect2Prediction, + ClipDetections, ClipDetectionsTensor, GeometryDecoder, PostprocessorProtocol, - RawPrediction, + Detection, ) from batdetect2.typing.preprocess import ( AudioLoader, @@ -44,7 +44,7 @@ __all__ = [ "AudioLoader", "Augmentation", "BackboneModel", - "BatDetect2Prediction", + "ClipDetections", "ClipDetectionsTensor", "ClipLabeller", "ClipMatches", @@ -65,7 +65,7 @@ __all__ = [ "PostprocessorProtocol", "PreprocessorProtocol", "ROITargetMapper", - "RawPrediction", + "Detection", "Size", "SoundEventDecoder", "SoundEventEncoder", diff --git a/src/batdetect2/typing/data.py b/src/batdetect2/typing/data.py index 77f534d..12f13fd 100644 --- a/src/batdetect2/typing/data.py +++ b/src/batdetect2/typing/data.py @@ -2,7 +2,7 @@ from typing import Generic, List, Protocol, Sequence, TypeVar from soundevent.data import PathLike -from batdetect2.typing.postprocess import BatDetect2Prediction +from batdetect2.typing.postprocess import ClipDetections __all__ = [ "OutputFormatterProtocol", @@ -12,9 +12,7 @@ T = TypeVar("T") class OutputFormatterProtocol(Protocol, Generic[T]): - def format( - self, predictions: Sequence[BatDetect2Prediction] - ) -> List[T]: ... + def format(self, predictions: Sequence[ClipDetections]) -> List[T]: ... def save( self, diff --git a/src/batdetect2/typing/evaluate.py b/src/batdetect2/typing/evaluate.py index 56438ff..0d1c1d3 100644 --- a/src/batdetect2/typing/evaluate.py +++ b/src/batdetect2/typing/evaluate.py @@ -13,7 +13,7 @@ from typing import ( from matplotlib.figure import Figure from soundevent import data -from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction +from batdetect2.typing.postprocess import ClipDetections, Detection from batdetect2.typing.targets import TargetProtocol __all__ = [ @@ -84,7 +84,7 @@ Geom = TypeVar("Geom", bound=data.Geometry, contravariant=True) class AffinityFunction(Protocol): def __call__( self, - detection: RawPrediction, + detection: Detection, ground_truth: data.SoundEventAnnotation, ) -> float: ... @@ -93,7 +93,7 @@ class MetricsProtocol(Protocol): def __call__( self, clip_annotations: Sequence[data.ClipAnnotation], - predictions: Sequence[Sequence[RawPrediction]], + predictions: Sequence[Sequence[Detection]], ) -> Dict[str, float]: ... @@ -101,7 +101,7 @@ class PlotterProtocol(Protocol): def __call__( self, clip_annotations: Sequence[data.ClipAnnotation], - predictions: Sequence[Sequence[RawPrediction]], + predictions: Sequence[Sequence[Detection]], ) -> Iterable[Tuple[str, Figure]]: ... @@ -114,7 +114,7 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]): def evaluate( self, clip_annotations: Sequence[data.ClipAnnotation], - predictions: Sequence[BatDetect2Prediction], + predictions: Sequence[ClipDetections], ) -> EvaluationOutput: ... def compute_metrics( diff --git a/src/batdetect2/typing/postprocess.py b/src/batdetect2/typing/postprocess.py index ed7b9fb..5c95308 100644 --- a/src/batdetect2/typing/postprocess.py +++ b/src/batdetect2/typing/postprocess.py @@ -22,7 +22,7 @@ from batdetect2.typing.models import ModelOutput from batdetect2.typing.targets import Position, Size __all__ = [ - "RawPrediction", + "Detection", "PostprocessorProtocol", "GeometryDecoder", ] @@ -46,7 +46,8 @@ class GeometryDecoder(Protocol): ) -> data.Geometry: ... -class RawPrediction(NamedTuple): +@dataclass +class Detection: geometry: data.Geometry detection_score: float class_scores: np.ndarray @@ -82,9 +83,16 @@ class ClipDetectionsTensor(NamedTuple): @dataclass -class BatDetect2Prediction: +class ClipDetections: clip: data.Clip - predictions: List[RawPrediction] + detections: List[Detection] + + +@dataclass +class ClipPrediction: + clip: data.Clip + detection_score: float + class_scores: np.ndarray class PostprocessorProtocol(Protocol): diff --git a/src/batdetect2/utils/detector_utils.py b/src/batdetect2/utils/detector_utils.py index c1b8e72..629f669 100644 --- a/src/batdetect2/utils/detector_utils.py +++ b/src/batdetect2/utils/detector_utils.py @@ -220,7 +220,8 @@ def get_annotations_from_preds( predictions["high_freqs"], class_ind_best, class_prob_best, - predictions["det_probs"], strict=False, + predictions["det_probs"], + strict=False, ) ] return annotations diff --git a/src/batdetect2/utils/plot_utils.py b/src/batdetect2/utils/plot_utils.py index e43d4d5..359f088 100644 --- a/src/batdetect2/utils/plot_utils.py +++ b/src/batdetect2/utils/plot_utils.py @@ -87,9 +87,7 @@ def save_ann_spec( y_extent = [0, duration, min_freq, max_freq] plt.close("all") - plt.figure( - 0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100 - ) + plt.figure(0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100) plt.imshow( spec, aspect="auto", diff --git a/src/batdetect2/utils/tensors.py b/src/batdetect2/utils/tensors.py index 4301851..6436c4d 100644 --- a/src/batdetect2/utils/tensors.py +++ b/src/batdetect2/utils/tensors.py @@ -1,4 +1,3 @@ - import numpy as np import torch from torch.nn import functional as F diff --git a/tests/test_data/test_predictions/test_parquet.py b/tests/test_data/test_predictions/test_parquet.py index 75ea244..d4e38e0 100644 --- a/tests/test_data/test_predictions/test_parquet.py +++ b/tests/test_data/test_predictions/test_parquet.py @@ -10,8 +10,8 @@ from batdetect2.data.predictions import ( build_output_formatter, ) from batdetect2.typing import ( - BatDetect2Prediction, - RawPrediction, + ClipDetections, + Detection, TargetProtocol, ) @@ -31,7 +31,7 @@ def test_roundtrip( tmp_path: Path, ): detections = [ - RawPrediction( + Detection( geometry=data.BoundingBox( coordinates=list(np.random.uniform(size=[4])) ), @@ -44,7 +44,7 @@ def test_roundtrip( for _ in range(10) ] - prediction = BatDetect2Prediction(clip=clip, predictions=detections) + prediction = ClipDetections(clip=clip, detections=detections) path = tmp_path / "predictions.parquet" @@ -58,7 +58,7 @@ def test_roundtrip( assert recovered[0].clip == prediction.clip for recovered_prediction, detection in zip( - recovered[0].predictions, detections + recovered[0].detections, detections, strict=True ): assert ( recovered_prediction.detection_score == detection.detection_score @@ -79,9 +79,9 @@ def test_multiple_clips( ): # Create a second clip clip2 = clip.model_copy(update={"uuid": uuid4()}) - + detections1 = [ - RawPrediction( + Detection( geometry=data.BoundingBox( coordinates=list(np.random.uniform(size=[4])) ), @@ -92,9 +92,9 @@ def test_multiple_clips( features=np.random.uniform(size=32), ) ] - + detections2 = [ - RawPrediction( + Detection( geometry=data.BoundingBox( coordinates=list(np.random.uniform(size=[4])) ), @@ -107,19 +107,19 @@ def test_multiple_clips( ] predictions = [ - BatDetect2Prediction(clip=clip, predictions=detections1), - BatDetect2Prediction(clip=clip2, predictions=detections2), + ClipDetections(clip=clip, detections=detections1), + ClipDetections(clip=clip2, detections=detections2), ] path = tmp_path / "multi_predictions.parquet" sample_formatter.save(predictions=predictions, path=path) - + recovered = sample_formatter.load(path=path) - + assert len(recovered) == 2 # Order might not be preserved if we don't sort, but implementation appends so it should be # However, let's sort by clip uuid to be safe if needed, or just check existence - + recovered_uuids = {p.clip.uuid for p in recovered} expected_uuids = {clip.uuid, clip2.uuid} assert recovered_uuids == expected_uuids @@ -133,16 +133,18 @@ def test_complex_geometry( ): # Create a polygon geometry polygon = data.Polygon( - coordinates=[[ - [0.0, 10000.0], - [0.1, 20000.0], - [0.2, 10000.0], - [0.0, 10000.0], - ]] + coordinates=[ + [ + [0.0, 10000.0], + [0.1, 20000.0], + [0.2, 10000.0], + [0.0, 10000.0], + ] + ] ) - + detections = [ - RawPrediction( + Detection( geometry=polygon, detection_score=0.95, class_scores=np.random.uniform( @@ -152,18 +154,18 @@ def test_complex_geometry( ) ] - prediction = BatDetect2Prediction(clip=clip, predictions=detections) + prediction = ClipDetections(clip=clip, detections=detections) path = tmp_path / "complex_geometry.parquet" sample_formatter.save(predictions=[prediction], path=path) - + recovered = sample_formatter.load(path=path) - + assert len(recovered) == 1 - assert len(recovered[0].predictions) == 1 - - recovered_pred = recovered[0].predictions[0] - + assert len(recovered[0].detections) == 1 + + recovered_pred = recovered[0].detections[0] + # Check if geometry is recovered correctly as a Polygon assert isinstance(recovered_pred.geometry, data.Polygon) assert recovered_pred.geometry == polygon diff --git a/tests/test_data/test_predictions/test_raw.py b/tests/test_data/test_predictions/test_raw.py index cdfcafe..8e2f88d 100644 --- a/tests/test_data/test_predictions/test_raw.py +++ b/tests/test_data/test_predictions/test_raw.py @@ -6,8 +6,8 @@ from soundevent import data from batdetect2.data.predictions import RawOutputConfig, build_output_formatter from batdetect2.typing import ( - BatDetect2Prediction, - RawPrediction, + ClipDetections, + Detection, TargetProtocol, ) @@ -27,7 +27,7 @@ def test_roundtrip( tmp_path: Path, ): detections = [ - RawPrediction( + Detection( geometry=data.BoundingBox( coordinates=list(np.random.uniform(size=[4])) ), @@ -40,7 +40,7 @@ def test_roundtrip( for _ in range(10) ] - prediction = BatDetect2Prediction(clip=clip, predictions=detections) + prediction = ClipDetections(clip=clip, detections=detections) path = tmp_path / "predictions" @@ -52,7 +52,9 @@ def test_roundtrip( assert recovered[0].clip == prediction.clip for recovered_prediction, detection in zip( - recovered[0].predictions, detections + recovered[0].detections, + detections, + strict=True, ): assert ( recovered_prediction.detection_score == detection.detection_score diff --git a/tests/test_evaluate/__init__.py b/tests/test_evaluate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_evaluate/test_tasks/__init__.py b/tests/test_evaluate/test_tasks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_evaluate/test_tasks/conftest.py b/tests/test_evaluate/test_tasks/conftest.py new file mode 100644 index 0000000..8999d83 --- /dev/null +++ b/tests/test_evaluate/test_tasks/conftest.py @@ -0,0 +1,86 @@ +from typing import Literal + +import numpy as np +import pytest +from soundevent import data + +from batdetect2.typing import Detection + + +@pytest.fixture +def clip(recording: data.Recording) -> data.Clip: + return data.Clip(recording=recording, start_time=0, end_time=100) + + +@pytest.fixture +def create_detection(): + def factory( + detection_score: float = 0.5, + start_time: float = 0.1, + duration: float = 0.01, + low_freq: float = 40_000, + bandwidth: float = 30_000, + pip_score: float = 0, + myo_score: float = 0, + ): + return Detection( + detection_score=detection_score, + class_scores=np.array( + [ + pip_score, + myo_score, + ] + ), + features=np.zeros([32]), + geometry=data.BoundingBox( + coordinates=[ + start_time, + low_freq, + start_time + duration, + low_freq + bandwidth, + ] + ), + ) + + return factory + + +@pytest.fixture +def create_annotation( + clip: data.Clip, + bat_tag: data.Tag, + myomyo_tag: data.Tag, + pippip_tag: data.Tag, +): + def factory( + start_time: float = 0.1, + duration: float = 0.01, + low_freq: float = 40_000, + bandwidth: float = 30_000, + is_target: bool = True, + class_name: Literal["pippip", "myomyo"] | None = None, + ): + tags = [bat_tag] if is_target else [] + + if class_name is not None: + if class_name == "pippip": + tags.append(pippip_tag) + elif class_name == "myomyo": + tags.append(myomyo_tag) + + return data.SoundEventAnnotation( + sound_event=data.SoundEvent( + geometry=data.BoundingBox( + coordinates=[ + start_time, + low_freq, + start_time + duration, + low_freq + bandwidth, + ] + ), + recording=clip.recording, + ), + tags=tags, + ) + + return factory diff --git a/tests/test_evaluate/test_tasks/test_classification.py b/tests/test_evaluate/test_tasks/test_classification.py new file mode 100644 index 0000000..0647240 --- /dev/null +++ b/tests/test_evaluate/test_tasks/test_classification.py @@ -0,0 +1,77 @@ +import numpy as np +import pytest +from soundevent import data + +from batdetect2.evaluate.tasks import build_task +from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig +from batdetect2.typing import ClipDetections +from batdetect2.typing.targets import TargetProtocol + + +def test_classification( + clip: data.Clip, + sample_targets: TargetProtocol, + create_detection, + create_annotation, +): + config = ClassificationTaskConfig.model_validate( + { + "name": "sound_event_classification", + "metrics": [{"name": "average_precision"}], + } + ) + + evaluator = build_task(config, targets=sample_targets) + + # Create a dummy prediction + prediction = ClipDetections( + clip=clip, + detections=[ + create_detection( + start_time=1 + 0.1 * index, + pip_score=score, + ) + for index, score in enumerate(np.linspace(0, 1, 100)) + ] + + [ + create_detection( + start_time=1.05 + 0.1 * index, + myo_score=score, + ) + for index, score in enumerate(np.linspace(1, 0, 100)) + ], + ) + + # Create a dummy annotation + gt = data.ClipAnnotation( + clip=clip, + sound_events=[ + create_annotation( + start_time=1 + 0.1 * index, + is_target=index % 2 == 0, + class_name="pippip", + ) + for index in range(100) + ] + + [ + create_annotation( + start_time=1.05 + 0.1 * index, + is_target=index % 3 == 0, + class_name="myomyo", + ) + for index in range(100) + ], + ) + + evals = evaluator.evaluate([gt], [prediction]) + metrics = evaluator.compute_metrics(evals) + + assert metrics["classification/average_precision/pippip"] == pytest.approx( + 0.5, abs=0.005 + ) + assert metrics["classification/average_precision/myomyo"] == pytest.approx( + 0.371, abs=0.005 + ) + assert metrics["classification/mean_average_precision"] == pytest.approx( + (0.5 + 0.371) / 2, abs=0.005 + ) diff --git a/tests/test_evaluate/test_tasks/test_detection.py b/tests/test_evaluate/test_tasks/test_detection.py new file mode 100644 index 0000000..5bbaa71 --- /dev/null +++ b/tests/test_evaluate/test_tasks/test_detection.py @@ -0,0 +1,50 @@ +import numpy as np +import pytest +from soundevent import data + +from batdetect2.evaluate.tasks import build_task +from batdetect2.evaluate.tasks.detection import DetectionTaskConfig +from batdetect2.typing import ClipDetections +from batdetect2.typing.targets import TargetProtocol + + +def test_detection( + clip: data.Clip, + sample_targets: TargetProtocol, + create_detection, + create_annotation, +): + config = DetectionTaskConfig.model_validate( + { + "name": "sound_event_detection", + "metrics": [{"name": "average_precision"}], + } + ) + evaluator = build_task(config, targets=sample_targets) + + # Create a dummy prediction + prediction = ClipDetections( + clip=clip, + detections=[ + create_detection(start_time=1 + 0.1 * index, detection_score=score) + for index, score in enumerate(np.linspace(0, 1, 100)) + ], + ) + + # Create a dummy annotation + gt = data.ClipAnnotation( + clip=clip, + sound_events=[ + # Only half of the annotations are targets + create_annotation( + start_time=1 + 0.1 * index, + is_target=index % 2 == 0, + ) + for index in range(100) + ], + ) + + # Run the task + evals = evaluator.evaluate([gt], [prediction]) + metrics = evaluator.compute_metrics(evals) + assert metrics["detection/average_precision"] == pytest.approx(0.5) diff --git a/tests/test_postprocessing/test_decoding.py b/tests/test_postprocessing/test_decoding.py index fa57732..ae79818 100644 --- a/tests/test_postprocessing/test_decoding.py +++ b/tests/test_postprocessing/test_decoding.py @@ -14,7 +14,7 @@ from batdetect2.postprocess.decoding import ( get_generic_tags, get_prediction_features, ) -from batdetect2.typing import RawPrediction, TargetProtocol +from batdetect2.typing import Detection, TargetProtocol @pytest.fixture @@ -209,7 +209,7 @@ def empty_detection_dataset() -> xr.Dataset: @pytest.fixture -def sample_raw_predictions() -> List[RawPrediction]: +def sample_raw_predictions() -> List[Detection]: """Manually crafted RawPrediction objects using the actual type.""" pred1_classes = xr.DataArray( @@ -220,7 +220,7 @@ def sample_raw_predictions() -> List[RawPrediction]: coords={"feature": ["f0", "f1", "f2", "f3"]}, dims=["feature"], ) - pred1 = RawPrediction( + pred1 = Detection( detection_score=0.9, geometry=data.BoundingBox( coordinates=[ @@ -242,7 +242,7 @@ def sample_raw_predictions() -> List[RawPrediction]: coords={"feature": ["f0", "f1", "f2", "f3"]}, dims=["feature"], ) - pred2 = RawPrediction( + pred2 = Detection( detection_score=0.8, geometry=data.BoundingBox( coordinates=[ @@ -264,7 +264,7 @@ def sample_raw_predictions() -> List[RawPrediction]: coords={"feature": ["f0", "f1", "f2", "f3"]}, dims=["feature"], ) - pred3 = RawPrediction( + pred3 = Detection( detection_score=0.15, geometry=data.BoundingBox( coordinates=[ @@ -281,7 +281,7 @@ def sample_raw_predictions() -> List[RawPrediction]: def test_convert_raw_to_sound_event_basic( - sample_raw_predictions: List[RawPrediction], + sample_raw_predictions: List[Detection], sample_recording: data.Recording, dummy_targets: TargetProtocol, ): @@ -324,7 +324,7 @@ def test_convert_raw_to_sound_event_basic( def test_convert_raw_to_sound_event_thresholding( - sample_raw_predictions: List[RawPrediction], + sample_raw_predictions: List[Detection], sample_recording: data.Recording, dummy_targets: TargetProtocol, ): @@ -352,7 +352,7 @@ def test_convert_raw_to_sound_event_thresholding( def test_convert_raw_to_sound_event_no_threshold( - sample_raw_predictions: List[RawPrediction], + sample_raw_predictions: List[Detection], sample_recording: data.Recording, dummy_targets: TargetProtocol, ): @@ -380,7 +380,7 @@ def test_convert_raw_to_sound_event_no_threshold( def test_convert_raw_to_sound_event_top_class( - sample_raw_predictions: List[RawPrediction], + sample_raw_predictions: List[Detection], sample_recording: data.Recording, dummy_targets: TargetProtocol, ): @@ -407,7 +407,7 @@ def test_convert_raw_to_sound_event_top_class( def test_convert_raw_to_sound_event_all_below_threshold( - sample_raw_predictions: List[RawPrediction], + sample_raw_predictions: List[Detection], sample_recording: data.Recording, dummy_targets: TargetProtocol, ): @@ -433,7 +433,7 @@ def test_convert_raw_to_sound_event_all_below_threshold( def test_convert_raw_list_to_clip_basic( - sample_raw_predictions: List[RawPrediction], + sample_raw_predictions: List[Detection], sample_clip: data.Clip, dummy_targets: TargetProtocol, ): @@ -485,7 +485,7 @@ def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets): def test_convert_raw_list_to_clip_passes_args( - sample_raw_predictions: List[RawPrediction], + sample_raw_predictions: List[Detection], sample_clip: data.Clip, dummy_targets: TargetProtocol, ): diff --git a/tests/test_preprocessing/test_spectrogram.py b/tests/test_preprocessing/test_spectrogram.py index b79fa48..f61bcff 100644 --- a/tests/test_preprocessing/test_spectrogram.py +++ b/tests/test_preprocessing/test_spectrogram.py @@ -1,4 +1,3 @@ - import numpy as np import pytest import xarray as xr